CUTLASS 2.6 (#298)

CUTLASS 2.6
This commit is contained in:
Manish Gupta
2021-07-22 21:40:53 -07:00
committed by GitHub
parent 6c29fe20ba
commit e5d51840e8
308 changed files with 32408 additions and 4722 deletions

View File

@ -292,10 +292,11 @@ public:
int64_t batch_stride_C = int64_t(problem_size.m()) * problem_size.n() * 2;
int64_t batch_stride_D = int64_t(problem_size.m()) * problem_size.n() * 2;
int lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
int ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
int ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
typename LayoutA::Stride::Index lda = LayoutA::packed({problem_size.m(), problem_size.k()}).stride(0);
typename LayoutB::Stride::Index ldb = LayoutB::packed({problem_size.k(), problem_size.n()}).stride(0);
typename LayoutC::Stride::Index ldc = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
typename LayoutC::Stride::Index ldd = LayoutC::packed({problem_size.m(), problem_size.n()}).stride(0);
int64_t imag_stride_A = int64_t(problem_size.m()) * problem_size.k();
int64_t imag_stride_B = int64_t(problem_size.k()) * problem_size.n();