Updates for 3.4 release. (#1305)

This commit is contained in:
ANIKET SHIVAM
2024-01-16 10:42:51 -08:00
committed by GitHub
parent acba5beee5
commit 2f589ffa76
166 changed files with 5996 additions and 4702 deletions

View File

@ -278,6 +278,7 @@ execute_process(
--architectures "${CUTLASS_NVCC_ARCHS_ENABLED}"
--kernels "${CUTLASS_LIBRARY_KERNELS}"
--ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}"
--kernel-filter-file "${KERNEL_FILTER_FILE}"
--selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}"
--cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}"
--log-level DEBUG

View File

@ -207,7 +207,10 @@ public:
void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits);
/// Uniformly fills a tensor with a value when provided o.w. zero
void fill(double value);
void fill_device(double value);
/// Uniformly fills a host allocation with a value when provided o.w. zero
void fill_host(double value);
/// Copies from an equivalent-sized tensor in device memory
void copy_from_device(void const *ptr);

View File

@ -2160,7 +2160,7 @@ static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) {
}
/// Fills a tensor uniformly with a value (most frequently used to clear the tensor)
void DeviceAllocation::fill(double val = 0.0) {
void DeviceAllocation::fill_device(double val = 0.0) {
switch (this->type()) {
case library::NumericTypeID::kFE4M3:
@ -2259,6 +2259,180 @@ void DeviceAllocation::fill(double val = 0.0) {
}
}
/// Fills a tensor uniformly with a value (most frequently used to clear the tensor)
void DeviceAllocation::fill_host(double val = 0.0) {
std::vector<uint8_t> host_data(bytes());
switch (this->type()) {
case library::NumericTypeID::kFE4M3:
cutlass::reference::host::BlockFill<float_e4m3_t>(
reinterpret_cast<float_e4m3_t *>(host_data.data()),
capacity_,
static_cast<float_e4m3_t>(val)
);
break;
case library::NumericTypeID::kFE5M2:
cutlass::reference::host::BlockFill<float_e5m2_t>(
reinterpret_cast<float_e5m2_t *>(host_data.data()),
capacity_,
static_cast<float_e5m2_t>(val)
);
break;
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFill<half_t>(
reinterpret_cast<half_t *>(host_data.data()),
capacity_,
static_cast<half_t>(val)
);
break;
case library::NumericTypeID::kBF16:
cutlass::reference::host::BlockFill<bfloat16_t>(
reinterpret_cast<bfloat16_t *>(host_data.data()),
capacity_,
static_cast<bfloat16_t>(val)
);
break;
case library::NumericTypeID::kTF32:
cutlass::reference::host::BlockFill<tfloat32_t>(
reinterpret_cast<tfloat32_t *>(host_data.data()),
capacity_,
static_cast<tfloat32_t>(val)
);
break;
case library::NumericTypeID::kF32:
cutlass::reference::host::BlockFill<float>(
reinterpret_cast<float *>(host_data.data()),
capacity_,
static_cast<float>(val)
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::host::BlockFill<double>(
reinterpret_cast<double *>(host_data.data()),
capacity_,
static_cast<double>(val)
);
break;
case library::NumericTypeID::kS2:
cutlass::reference::host::BlockFill<int2b_t>(
reinterpret_cast<int2b_t *>(host_data.data()),
capacity_,
static_cast<int2b_t>(val)
);
break;
case library::NumericTypeID::kS4:
cutlass::reference::host::BlockFill<int4b_t>(
reinterpret_cast<int4b_t *>(host_data.data()),
capacity_,
static_cast<int4b_t>(val)
);
break;
case library::NumericTypeID::kS8:
cutlass::reference::host::BlockFill<int8_t>(
reinterpret_cast<int8_t *>(host_data.data()),
capacity_,
static_cast<int8_t>(val)
);
break;
case library::NumericTypeID::kS16:
cutlass::reference::host::BlockFill<int16_t>(
reinterpret_cast<int16_t *>(host_data.data()),
capacity_,
static_cast<int16_t>(val)
);
break;
case library::NumericTypeID::kS32:
cutlass::reference::host::BlockFill<int32_t>(
reinterpret_cast<int32_t *>(host_data.data()),
capacity_,
static_cast<int32_t>(val)
);
break;
case library::NumericTypeID::kS64:
cutlass::reference::host::BlockFill<int64_t>(
reinterpret_cast<int64_t *>(host_data.data()),
capacity_,
static_cast<int64_t>(val)
);
break;
case library::NumericTypeID::kB1:
cutlass::reference::host::BlockFill<uint1b_t>(
reinterpret_cast<uint1b_t *>(host_data.data()),
capacity_,
static_cast<uint1b_t>(val)
);
break;
case library::NumericTypeID::kU2:
cutlass::reference::host::BlockFill<uint2b_t>(
reinterpret_cast<uint2b_t *>(host_data.data()),
capacity_,
static_cast<uint2b_t>(val)
);
break;
case library::NumericTypeID::kU4:
cutlass::reference::host::BlockFill<uint4b_t>(
reinterpret_cast<uint4b_t *>(host_data.data()),
capacity_,
static_cast<uint4b_t>(val)
);
break;
case library::NumericTypeID::kU8:
cutlass::reference::host::BlockFill<uint8_t>(
reinterpret_cast<uint8_t *>(host_data.data()),
capacity_,
static_cast<uint8_t>(val)
);
break;
case library::NumericTypeID::kU16:
cutlass::reference::host::BlockFill<uint16_t>(
reinterpret_cast<uint16_t *>(host_data.data()),
capacity_,
static_cast<uint16_t>(val)
);
break;
case library::NumericTypeID::kU32:
cutlass::reference::host::BlockFill<uint32_t>(
reinterpret_cast<uint32_t *>(host_data.data()),
capacity_,
static_cast<uint32_t>(val)
);
break;
case library::NumericTypeID::kU64:
cutlass::reference::host::BlockFill<uint64_t>(
reinterpret_cast<uint64_t *>(host_data.data()),
capacity_,
static_cast<uint64_t>(val)
);
break;
default:
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()));
}
copy_from_host(host_data.data());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler

View File

@ -77,6 +77,7 @@ struct GettMainloopParams {
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -126,6 +127,7 @@ struct GettEpilogueParams {
TensorAux Aux{};
VectorAlpha Valpha{};
VectorBeta Vbeta{};
ElementCompute st = ElementCompute(1);
ElementAccumulator* abs_max_D = nullptr;
ElementAccumulator* abs_max_Aux = nullptr;
@ -204,6 +206,7 @@ void gett_mainloop(
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
a_frag[m_b] = conj(a_frag[m_b]);
}
@ -218,6 +221,7 @@ void gett_mainloop(
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
b_frag[n_b] = conj(b_frag[n_b]);
}
@ -325,6 +329,8 @@ void gett_epilogue(
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
converted_beta = mul(converted_beta, converted_scale_c);
ElementCompute inter_accum[kBlockM][kBlockN];
for (int m_b = 0; m_b < kBlockM; ++m_b) {
ElementCompute local_dBias = ElementCompute(0);
@ -391,7 +397,7 @@ void gett_epilogue(
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
}
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output);
inter_accum[m_b][n_b] = ElementCompute(output);
}
} // n_b
@ -403,6 +409,13 @@ void gett_epilogue(
}
}
} // m_b
for (int m_b = 0; m_b < kBlockM; ++m_b) {
for (int n_b = 0; n_b < kBlockN; ++n_b) {
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
}
}
}
#if defined(_OPENMP)
#pragma omp critical(Abs_Max_Data_Update)
#endif

View File

@ -947,6 +947,20 @@ void TensorFillPadDiagonalRandomUniform(
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Fills a tensor with a uniform value
template <
typename Element ///< Element type
>
void BlockFill(
Element *ptr,
size_t capacity,
Element val
) {
for (size_t i = 0; i < capacity; ++i) {
ReferenceFactory<Element>::get(ptr, i) = val;
}
}
/// Fills a tensor with random values with a uniform random distribution.
template <
typename Element ///< Element type