Updates for CUTLASS 3.5.0 (#1468)

This commit is contained in:
Vijay Thakkar
2024-04-11 21:33:40 -04:00
committed by GitHub
parent a40e08e9d5
commit 7d49e6c7e2
171 changed files with 7526 additions and 1888 deletions

View File

@ -108,8 +108,8 @@ __global__ void convert_with_scale_factor(
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Destination, typename Source, typename ScaleFactor, int Count, int Range = 4>
void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) {
template <typename Destination, typename Source, typename ScaleFactor, int Count>
void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[], const int range = 4, const int offset = 0) {
const int kN = Count;
dim3 grid(1, 1);
@ -124,7 +124,7 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[],
for (int i = 0; i < kN; ++i) {
source_ref.at({0, i}) = Source(i % Range);
source_ref.at({0, i}) = Source(i % range + offset);
}
for (int i = 0; i < kN; ++i) {
@ -144,10 +144,12 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[],
for (int i = 0; i < kN; ++i) {
float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i}));
EXPECT_TRUE(float(destination_ref.at({0, i})) == ref)
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i}))
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i}))
<< ", Count: " << Count;
bool pass = float(destination_ref.at({0, i})) == ref;
EXPECT_TRUE(pass)
<< "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) << std::endl
<< ", Source type: " << source_name << " " << float(source_ref.at({0, i})) << std::endl
<< ", Scalefactor type: " << source_name << " " << float(scale_factor_ref.at({0, i})) << std::endl
<< ", idx: " << i << std::endl;
}
}