Refactor attention kernels (#53)
This commit is contained in:
2
setup.py
2
setup.py
@ -18,7 +18,7 @@ ext_modules.append(cache_extension)
|
||||
# Attention kernels.
|
||||
attention_extension = cpp_extension.CUDAExtension(
|
||||
name='cacheflow.attention_ops',
|
||||
sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
|
||||
sources=['csrc/attention.cpp', 'csrc/attention/attention_kernels.cu'],
|
||||
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||
)
|
||||
ext_modules.append(attention_extension)
|
||||
|
||||
Reference in New Issue
Block a user