Updates for 3.4 release. (#1305)
This commit is contained in:
@ -278,6 +278,7 @@ execute_process(
|
||||
--architectures "${CUTLASS_NVCC_ARCHS_ENABLED}"
|
||||
--kernels "${CUTLASS_LIBRARY_KERNELS}"
|
||||
--ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}"
|
||||
--kernel-filter-file "${KERNEL_FILTER_FILE}"
|
||||
--selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}"
|
||||
--cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}"
|
||||
--log-level DEBUG
|
||||
|
||||
@ -207,7 +207,10 @@ public:
|
||||
void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits);
|
||||
|
||||
/// Uniformly fills a tensor with the given value, or with zero when no value is provided
|
||||
void fill(double value);
|
||||
void fill_device(double value);
|
||||
|
||||
/// Uniformly fills a host allocation with the given value, or with zero when no value is provided
|
||||
void fill_host(double value);
|
||||
|
||||
/// Copies from an equivalent-sized tensor in device memory
|
||||
void copy_from_device(void const *ptr);
|
||||
|
||||
@ -2160,7 +2160,7 @@ static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) {
|
||||
}
|
||||
|
||||
/// Fills a tensor uniformly with a value (most frequently used to clear the tensor)
|
||||
void DeviceAllocation::fill(double val = 0.0) {
|
||||
void DeviceAllocation::fill_device(double val = 0.0) {
|
||||
|
||||
switch (this->type()) {
|
||||
case library::NumericTypeID::kFE4M3:
|
||||
@ -2259,6 +2259,180 @@ void DeviceAllocation::fill(double val = 0.0) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills a tensor uniformly with a value (most frequently used to clear the tensor)
|
||||
void DeviceAllocation::fill_host(double val = 0.0) {
|
||||
|
||||
std::vector<uint8_t> host_data(bytes());
|
||||
|
||||
switch (this->type()) {
|
||||
case library::NumericTypeID::kFE4M3:
|
||||
cutlass::reference::host::BlockFill<float_e4m3_t>(
|
||||
reinterpret_cast<float_e4m3_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_e4m3_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFE5M2:
|
||||
cutlass::reference::host::BlockFill<float_e5m2_t>(
|
||||
reinterpret_cast<float_e5m2_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float_e5m2_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::host::BlockFill<half_t>(
|
||||
reinterpret_cast<half_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<half_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kBF16:
|
||||
cutlass::reference::host::BlockFill<bfloat16_t>(
|
||||
reinterpret_cast<bfloat16_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<bfloat16_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kTF32:
|
||||
cutlass::reference::host::BlockFill<tfloat32_t>(
|
||||
reinterpret_cast<tfloat32_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<tfloat32_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF32:
|
||||
cutlass::reference::host::BlockFill<float>(
|
||||
reinterpret_cast<float *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<float>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF64:
|
||||
cutlass::reference::host::BlockFill<double>(
|
||||
reinterpret_cast<double *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<double>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS2:
|
||||
cutlass::reference::host::BlockFill<int2b_t>(
|
||||
reinterpret_cast<int2b_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int2b_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS4:
|
||||
cutlass::reference::host::BlockFill<int4b_t>(
|
||||
reinterpret_cast<int4b_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int4b_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS8:
|
||||
cutlass::reference::host::BlockFill<int8_t>(
|
||||
reinterpret_cast<int8_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int8_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS16:
|
||||
cutlass::reference::host::BlockFill<int16_t>(
|
||||
reinterpret_cast<int16_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int16_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS32:
|
||||
cutlass::reference::host::BlockFill<int32_t>(
|
||||
reinterpret_cast<int32_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int32_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kS64:
|
||||
cutlass::reference::host::BlockFill<int64_t>(
|
||||
reinterpret_cast<int64_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<int64_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kB1:
|
||||
cutlass::reference::host::BlockFill<uint1b_t>(
|
||||
reinterpret_cast<uint1b_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint1b_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU2:
|
||||
cutlass::reference::host::BlockFill<uint2b_t>(
|
||||
reinterpret_cast<uint2b_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint2b_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU4:
|
||||
cutlass::reference::host::BlockFill<uint4b_t>(
|
||||
reinterpret_cast<uint4b_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint4b_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU8:
|
||||
cutlass::reference::host::BlockFill<uint8_t>(
|
||||
reinterpret_cast<uint8_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint8_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU16:
|
||||
cutlass::reference::host::BlockFill<uint16_t>(
|
||||
reinterpret_cast<uint16_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint16_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU32:
|
||||
cutlass::reference::host::BlockFill<uint32_t>(
|
||||
reinterpret_cast<uint32_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint32_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kU64:
|
||||
cutlass::reference::host::BlockFill<uint64_t>(
|
||||
reinterpret_cast<uint64_t *>(host_data.data()),
|
||||
capacity_,
|
||||
static_cast<uint64_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()));
|
||||
}
|
||||
|
||||
copy_from_host(host_data.data());
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
@ -77,6 +77,7 @@ struct GettMainloopParams {
|
||||
|
||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -126,6 +127,7 @@ struct GettEpilogueParams {
|
||||
TensorAux Aux{};
|
||||
VectorAlpha Valpha{};
|
||||
VectorBeta Vbeta{};
|
||||
ElementCompute st = ElementCompute(1);
|
||||
|
||||
ElementAccumulator* abs_max_D = nullptr;
|
||||
ElementAccumulator* abs_max_Aux = nullptr;
|
||||
@ -204,6 +206,7 @@ void gett_mainloop(
|
||||
if (m + m_b < cute::size<0>(mainloop_params.A.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type.
|
||||
a_frag[m_b] = static_cast<ElementAccumulator>(ElementA(mainloop_params.A(m + m_b, k, l)));
|
||||
|
||||
if (mainloop_params.transform_A == ComplexTransform::kConjugate) {
|
||||
a_frag[m_b] = conj(a_frag[m_b]);
|
||||
}
|
||||
@ -218,6 +221,7 @@ void gett_mainloop(
|
||||
if (n + n_b < cute::size<0>(mainloop_params.B.layout())) {
|
||||
// Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
|
||||
b_frag[n_b] = static_cast<ElementAccumulator>(ElementB(mainloop_params.B(n + n_b, k, l)));
|
||||
|
||||
if (mainloop_params.transform_B == ComplexTransform::kConjugate) {
|
||||
b_frag[n_b] = conj(b_frag[n_b]);
|
||||
}
|
||||
@ -325,6 +329,8 @@ void gett_epilogue(
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
ElementCompute inter_accum[kBlockM][kBlockN];
|
||||
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
ElementCompute local_dBias = ElementCompute(0);
|
||||
|
||||
@ -391,7 +397,7 @@ void gett_epilogue(
|
||||
output = epilogue_fma(converted_scale_d, output, ElementCompute(0));
|
||||
}
|
||||
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output);
|
||||
inter_accum[m_b][n_b] = ElementCompute(output);
|
||||
}
|
||||
} // n_b
|
||||
|
||||
@ -403,6 +409,13 @@ void gett_epilogue(
|
||||
}
|
||||
}
|
||||
} // m_b
|
||||
for (int m_b = 0; m_b < kBlockM; ++m_b) {
|
||||
for (int n_b = 0; n_b < kBlockN; ++n_b) {
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp critical(Abs_Max_Data_Update)
|
||||
#endif
|
||||
|
||||
@ -947,6 +947,20 @@ void TensorFillPadDiagonalRandomUniform(
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Fills a tensor with a uniform value
|
||||
template <
|
||||
typename Element ///< Element type
|
||||
>
|
||||
void BlockFill(
|
||||
Element *ptr,
|
||||
size_t capacity,
|
||||
Element val
|
||||
) {
|
||||
for (size_t i = 0; i < capacity; ++i) {
|
||||
ReferenceFactory<Element>::get(ptr, i) = val;
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills a tensor with random values with a uniform random distribution.
|
||||
template <
|
||||
typename Element ///< Element type
|
||||
|
||||
Reference in New Issue
Block a user