[Kernel] Initial Machete W4A8 support + Refactors (#9855)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@@ -1,496 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
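//
// A minimal usage sketch (illustrative only; `scales_dev` and `scale_dev` are
// hypothetical device pointers): the same compiled visitor covers both
// quantization granularities by flipping the `row_broadcast` flag.
//
//   using Visitor =
//       VisitorRowOrScalarBroadcast<ThreadMap, float, Stride<_0,_1,_0>>;
//   // Per-channel scales: an n-length device row vector.
//   typename Visitor::Arguments per_channel{scales_dev, /*row_broadcast=*/true, {}};
//   // Per-tensor scale: a single device-resident value broadcast to all columns.
//   typename Visitor::Arguments per_tensor{scale_dev, /*row_broadcast=*/false, {}};
//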
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;

template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are loading from a scalar and broadcasting
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  // This struct has been modified to remove null_default (because it's always 0)
  struct Arguments {
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are broadcasting 0
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};


/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;

    // This function is modified from VisitorColBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rCol);

      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // In this case we are loading from a column vector and broadcasting
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // In this case we are loading from a scalar and broadcasting
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Generate the pred tensor
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

}
@@ -1,447 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
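//
// A minimal sketch (hypothetical host-side helper, assuming float scales):
// because `ptr_row` always points at device memory, a per-tensor scale stays
// on the GPU as a 1-element tensor and only the `row_broadcast` flag changes,
// so no device-to-host scalar copy (and no torch.compile graph break) occurs.
//
//   template <typename Visitor>
//   typename Visitor::Arguments make_row_scale_args(torch::Tensor const& s) {
//     return {static_cast<float const*>(s.data_ptr()),
//             /*row_broadcast=*/s.numel() != 1, {}};
//   }
//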
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params(params)
    , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
          GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
          GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
          SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
          CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , params(params_) {}

    GS_GTensor tGS_gRow;       // (CPY,CPY_M,CPY_N)
    GS_STensor tGS_sRow;       // (CPY,CPY_M,CPY_N)
    GS_CTensor tGS_cRow;       // (CPY,CPY_M,CPY_N)
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;       // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    SR_RTensor tSR_rRow;       // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    CTensor tCcRow;            // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    ThrResidue residue_tCcRow; // (m, n)
    ThrNum thr_num;
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue; // OOB of SMEM
        }
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
        }
      }
      synchronize();
    }

    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) { // Assumes M-major subtile loop
        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
    //// G2S: Gmem to Smem
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout< Shape<_1, ThreadCount>,
                                            Stride<_0,          _1>>{},
                                     Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    //// G2S: Coord
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    //// S2R: Smem to Reg
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
      GTensor&& tCgCol,
      RTensor&& tCrCol,
      CTensor&& tCcCol,
      ProblemShape problem_shape,
      Params const& params
    ):
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      m(get<0>(problem_shape)),
      params(params) {}

    GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;
    CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;
    int m;

    CUTLASS_DEVICE void
    begin() {
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Generate an identity tensor matching the shape of the global tensor and
    // partition the same way, this will be used to generate the predicate
    // tensor for loading
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      params
    );
  }
};

}
@@ -8,6 +8,10 @@
 #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
 
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
+
+using namespace vllm;
+
 /*
    This file defines quantized GEMM operations using the CUTLASS 2.x API, for
    NVIDIA GPUs with SM versions prior to sm90 (Hopper).
@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
-                                            Epilogue>(
+    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
+    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
+    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
         out, a, b, a_scales, b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
+    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
         out, a, b, a_scales, b_scales);
   }
 }
@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   if (azp) {
-    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
+    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
         out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
   } else {
-    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
+    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
   }
 }
@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
-                                            Epilogue>(
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
+    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
+    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
   }
 }
@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   if (azp) {
-    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
+    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
   } else {
-    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
+    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
   }
 }
@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
-                                                 Epilogue>(
+    return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                           Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     assert(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
-                                                 Epilogue>(
+    return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 } else {
@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_sm89_fp8_dispatch<
-        cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
+    return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
+                                          cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
-                                                cutlass::half_t, Epilogue>(
+    return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
+                                          cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
+    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
+    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
   }
 }
@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
   if (azp) {
-    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
+    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
   } else {
-    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
+    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
   }
 }

@@ -21,7 +21,6 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
 
-#include "broadcast_load_epilogue_c2x.hpp"
 #include "common.hpp"
 // clang-format on
 
@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif
  }
};

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // This would technically work, but there is no use case since data_ptr
      // is never nullptr here.
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
 This epilogue function defines a quantized GEMM operation similar to
 torch._scaled_mm.

 A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
 per-row. B can be quantized per-tensor or per-column.
 Any combination of per-tensor and per-row or column is supported.
 A and B must have symmetric quantization (zero point == 0).

 So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
 scales are applied elementwise with numpy-style broadcasting.

 ScaleA and ScaleB define the epilogue functions that apply the scales for
 the A and B operands respectively. These scales may be either per-tensor or
 per row or column.
*/
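//
// Broadcast-shape sketch (illustrative): with D of shape (m, n),
//   a_scales: (m, 1) per-token   or (1, 1) per-tensor -> ColOrScalarLoad
//   b_scales: (1, n) per-channel or (1, 1) per-tensor -> RowOrScalarLoad
// so the visitor tree evaluates D = a_scales * (b_scales * accum)
// elementwise, matching D = (a_scales * A) (b_scales * B) above.
//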
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
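//
// Worked algebra (follows from the comment above): if the int8 activation was
// quantized as A_q = A_int + azp * J, with J the all-ones (m, k) matrix, then
//   (A_q - azp * J) @ B = A_q @ B - azp * (J @ B),
// so subtracting the precomputed azp_adj = azp * (J @ B), shape (1, n), from
// the int32 accumulator removes the zero-point contribution.
//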
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
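//
// Worked sketch: with per-token zero points azp of shape (m, 1) and
// azp_adj = J @ B of shape (1, n), the full correction azp @ azp_adj is the
// rank-1 outer product of shape (m, n). Forming each element on the fly as
// azp(row) * azp_adj(col) inside the epilogue needs only the two O(m + n)
// vectors instead of an (m, n) buffer.
//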
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,

@@ -23,11 +23,12 @@
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
-#include "broadcast_load_epilogue_c3x.hpp"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 #include "common.hpp"
 // clang-format on
 
 using namespace cute;
+using namespace vllm;
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch.scaled_mm_.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be
|
||||
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
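The broadcasting identity the comment relies on can be checked directly; a small numpy sketch (illustrative only, with hypothetical shapes):

import numpy as np

m, k, n = 4, 16, 8
A = np.random.randint(-128, 128, (m, k)).astype(np.float32)
B = np.random.randint(-128, 128, (k, n)).astype(np.float32)
a_scales = np.random.rand(m, 1).astype(np.float32)  # per-row, or (1, 1) for per-tensor
b_scales = np.random.rand(1, n).astype(np.float32)  # per-column, or (1, 1) for per-tensor

# (a_scales * A) @ (b_scales * B) == a_scales * (A @ B) * b_scales,
# which is why the scales can be applied in the epilogue after the GEMM.
assert np.allclose((a_scales * A) @ (b_scales * B),
                   a_scales * (A @ B) * b_scales)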
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
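Why subtracting a precomputed azp_adj removes a per-tensor zero point: if the quantized activations are Aq = A + azp, then Aq @ B = A @ B + azp * (J @ B), where J is the all-ones matrix. A small numpy check (illustrative only, hypothetical shapes):

import numpy as np

m, k, n = 4, 16, 8
A = np.random.randint(-100, 100, (m, k)).astype(np.int32)
B = np.random.randint(-8, 8, (k, n)).astype(np.int32)
azp = 12  # scalar (per-tensor) zero point
Aq = A + azp
azp_adj = azp * (np.ones((1, k), dtype=np.int32) @ B)  # azp * J @ B, shape (1,n)
assert (Aq @ B - azp_adj == A @ B).all()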

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
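The per-token case is the same identity with a vector of zero points, i.e. a rank-1 correction outer(azp, J @ B); keeping azp (m values) and azp_adj (n values) instead of materializing the m-by-n term is where the O(m+n) space bound in the comment comes from. A numpy check (illustrative only, hypothetical shapes):

import numpy as np

m, k, n = 4, 16, 8
A = np.random.randint(-100, 100, (m, k)).astype(np.int32)
B = np.random.randint(-8, 8, (k, n)).astype(np.int32)
azp = np.random.randint(0, 16, (m, 1)).astype(np.int32)  # per-token zero points
Aq = A + azp
azp_adj = np.ones((1, k), dtype=np.int32) @ B            # J @ B, shape (1,n)
assert (Aq @ B - azp * azp_adj == A @ B).all()           # rank-1 correction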

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
                                                           b_scales);
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        c, a, b, a_scales, b_scales);
  }
}

@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

@ -3,8 +3,10 @@ import math
import os
import shutil
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Dict, List, Optional, Tuple, Union

import jinja2
# yapf conflicts with isort for this block
@ -14,7 +16,10 @@ from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
                                            MixedInputKernelScheduleType,
                                            TileSchedulerTag,
                                            TileSchedulerType, VLLMDataType,
                                            VLLMDataTypeNames, VLLMDataTypeTag,
                                            VLLMDataTypeNames,
                                            VLLMDataTypeSize, VLLMDataTypeTag,
                                            VLLMDataTypeTorchDataTypeTag,
                                            VLLMDataTypeVLLMScalarTypeTag,
                                            VLLMKernelScheduleTag)

# yapf: enable
@ -27,49 +32,125 @@ DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
using GemmDispatcher_ = GemmDispatcher<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}},  // Accumulator
    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
    {{DataTypeTag[type_config.element_b_zeropoint]}}>;  // Zeropoints

{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}

torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
  [[maybe_unused]] auto M = args.A.size(0);
  [[maybe_unused]] auto N = args.B.size(1);
  [[maybe_unused]] auto K = args.A.size(1);

  if (!args.schedule) {
    {%- for cond, s in heuristic %}
  if (!args.maybe_schedule) {
    {%- for cond, s in impl_config.heuristic %}
    {%if cond is not none%}if ({{cond}})
    {%- else %}else
    {%- endif %}
      return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
      return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
  }

  {% for s in schedules %}
  if (*args.schedule == "{{ gen_sch_name(s) }}") {
    return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
  }
  {% endfor %}
  {%- for s in impl_config.schedules %}
  if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
    return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
  {%- endfor %}
  TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
                              "schedule = ", *args.schedule);
                              "schedule = ", *args.maybe_schedule);
}
{%- endfor %}


static inline std::optional<at::ScalarType> maybe_scalartype(
    c10::optional<at::Tensor> const& t) {
  if (!t) {
    return std::nullopt;
  } else {
    return t->scalar_type();
  };
}

template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
  return {
    {% for s in schedules -%}
    "{{ gen_sch_name(s) }}"{{ ",
    " if not loop.last }}{%- endfor %}
  };
torch::Tensor mm_dispatch(MMArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
  auto a_type = args.A.scalar_type();
  auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
  auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
  auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
  auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set type_sig = gen_type_sig(t) -%}
  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
      && a_type == {{TorchTypeTag[t.a]}}
      && out_type == {{TorchTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
         maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
         {%- else %}!maybe_g_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void -%}
         maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
         {%- else %}!maybe_g_zeros_type{%endif%}
      && {%if t.b_channel_scale != void -%}
         maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
         {%- else %}!maybe_ch_scales_type{%endif%}
      && {%if t.a_token_scale != void -%}
         maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
         {%- else %}!maybe_tok_scales_type{%endif%}
  ) {
    return mm_dispatch_{{type_sig}}(args);
  }
  {%- endfor %}

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "machete_mm(..) is not implemented for "
      "a_type=", args.A.scalar_type(),
      ", b_type=", args.b_type.str(),
      ", out_type=", out_type,
      ", with_group_scale_type=", maybe_g_scales_type
          ? toString(*maybe_g_scales_type) : "None",
      ", with_group_zeropoint_type=", maybe_g_zeros_type
          ? toString(*maybe_g_zeros_type) : "None",
      ", with_channel_scale_type=", maybe_ch_scales_type
          ? toString(*maybe_ch_scales_type) : "None",
      ", with_token_scale_type=", maybe_tok_scales_type
          ? toString(*maybe_tok_scales_type) : "None",
      "; implemented types are: \\n",
      {%- for impl_config in impl_configs %}
      {% set t = impl_config.types -%}
      "\\t{{gen_type_option_name(t)}}\\n",
      {%- endfor %}
      "");
}

std::vector<std::string> supported_schedules_dispatch(
    SupportedSchedulesArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.a_type);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set schs = impl_config.schedules -%}
  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
      && args.a_type == {{TorchTypeTag[t.a]}}
      && out_type == {{TorchTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
         args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
         {%- else %}!args.maybe_group_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void-%}
         args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
         {%- else %}!args.maybe_group_zeros_type{%endif%}
  ) {
    return {
      {%- for s in impl_config.schedules %}
      "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
      {%- endfor %}
    };
  }
  {%- endfor %}

  return {};
};

};  // namespace machete
"""

@ -77,20 +158,10 @@ IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = MacheteKernelTemplate<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}},  // Accumulator
    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
    {{DataTypeTag[type_config.element_b_zeropoint]}},  // Zeropoints
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
    Config, with_C, with_scales, with_zeropoints>;

{% for sch in schedules %}
{% set schedule_name = gen_sch_name(sch) -%}
struct sch_{{schedule_name}} {

{% for sch in unique_schedules(impl_configs) %}
{% set sch_sig = gen_sch_sig(sch) -%}
struct sch_{{sch_sig}} {
  using TileShapeNM = Shape<{{
      to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
  using ClusterShape = Shape<{{
@ -101,27 +172,34 @@ struct sch_{{schedule_name}} {
  using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};

torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
  bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
       with_zeropoints = args.zeros.has_value();

  {% for s in specializations %}
  if (with_C == {{s.with_C|lower}}
      && with_zeropoints == {{s.with_zeropoints|lower}}
      && with_scales == {{s.with_scales|lower}}) {
    return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
        {{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
  }{% endfor %}

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "for the sake of compile times and binary size machete_mm(..) is "
      " not implemented for with_C=", with_C, ", with_scales=", with_scales,
      ", with_zeropoints=", with_zeropoints,
      " (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}

{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
{% set type_sig = gen_type_sig(t) -%}

template<typename Sch>
using Kernel_{{type_sig}} = MacheteKernelTemplate<
    {{DataTypeTag[t.a]}},  // ElementA
    {{DataTypeTag[t.b]}},  // ElementB
    {{DataTypeTag[t.out]}},  // ElementD
    {{DataTypeTag[t.accumulator]}},  // Accumulator
    {{DataTypeTag[t.b_group_scale]}},  // GroupScaleT
    {{DataTypeTag[t.b_group_zeropoint]}},  // GroupZeroT
    {{DataTypeTag[t.b_channel_scale]}},  // ChannelScaleT
    {{DataTypeTag[t.a_token_scale]}},  // TokenScaleT
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
    Sch>;

{% for sch in schs %}
{% set sch_sig = gen_sch_sig(sch) -%}
torch::Tensor
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
  return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
}
{%- endfor %}
{%- endfor %}

};  // namespace machete
"""
@ -130,26 +208,34 @@ PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"

namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}},  // Accumulator
    {{DataTypeTag[type_config.element_b_scale]}},  // Scales
    {{DataTypeTag[type_config.element_b_zeropoint]}}>;  // Zeropoints

using PrepackedLayoutB = PrepackedLayoutBTemplate<
    {{DataTypeTag[type_config.element_a]}},  // ElementA
    {{DataTypeTag[type_config.element_b]}},  // ElementB
    {{DataTypeTag[type_config.element_d]}},  // ElementD
    {{DataTypeTag[type_config.accumulator]}},  // Accumulator
    cutlass::layout::ColumnMajor,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;

template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
  return prepack_impl<PrepackedLayoutB>(B);
torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
  auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
  {%- for t in types %}
  {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
  if (args.a_type == {{TorchTypeTag[t.a]}}
      && args.b_type.size_bits() == {{t.b_num_bits}}
      && convert_type == {{TorchTypeTag[t.convert]}}) {
    return prepack_impl<
        PrepackedLayoutBTemplate<
            {{DataTypeTag[t.a]}},  // ElementA
            {{DataTypeTag[b_type]}},  // ElementB
            {{DataTypeTag[t.convert]}},  // ElementConvert
            {{DataTypeTag[t.accumulator]}},  // Accumulator
            cutlass::layout::ColumnMajor,
            cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
    >(args.B);
  }
  {%- endfor %}

  TORCH_CHECK_NOT_IMPLEMENTED(false,
      "prepack_B_dispatch(..) is not implemented for "
      "atype = ", args.a_type,
      ", b_type = ", args.b_type.str(),
      ", with_group_scales_type= ", args.maybe_group_scales_type ?
          toString(*args.maybe_group_scales_type) : "None");
}

};  // namespace machete
"""

@ -166,32 +252,34 @@ class ScheduleConfig:
    tile_scheduler: TileSchedulerType


@dataclass
@dataclass(frozen=True)
class TypeConfig:
    element_a: DataType
    element_b: Union[DataType, VLLMDataType]
    element_b_scale: DataType
    element_b_zeropoint: DataType
    element_d: DataType
    a: DataType
    b: Union[DataType, VLLMDataType]
    b_group_scale: DataType
    b_group_zeropoint: DataType
    b_channel_scale: DataType
    a_token_scale: DataType
    out: DataType
    accumulator: DataType


@dataclass(frozen=True)
class PrepackTypeConfig:
    a: DataType
    b_num_bits: int
    convert: DataType
    accumulator: DataType


@dataclass
class Specialization:
    with_C: bool
    with_zeropoints: bool
    with_scales: bool


@dataclass
class ImplConfig:
    type_config: TypeConfig
    schedule_configs: List[ScheduleConfig]
    specializations: List[Specialization]
    types: TypeConfig
    schedules: List[ScheduleConfig]
    heuristic: List[Tuple[Optional[str], ScheduleConfig]]


def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    tile_shape = (
        f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
    )
@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
            f"_{epilogue_schedule}_{tile_scheduler}")


# mostly unique shorter schedule_name
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    kernel_terse_names_replace = {
        "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
        "TmaWarpSpecializedCooperative_": "TmaCoop_",
        "StreamKScheduler": "streamK",
    }

    schedule_name = generate_schedule_name(schedule_config)
    sch_sig = generate_sch_sig(schedule_config)
    for orig, terse in kernel_terse_names_replace.items():
        schedule_name = schedule_name.replace(orig, terse)
    return schedule_name
        sch_sig = sch_sig.replace(orig, terse)
    return sch_sig


# unique type_name
def generate_type_signature(kernel_type_config: TypeConfig):
    element_a = VLLMDataTypeNames[kernel_type_config.element_a]
    element_b = VLLMDataTypeNames[kernel_type_config.element_b]
    element_d = VLLMDataTypeNames[kernel_type_config.element_d]
    accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
    element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
    element_zeropoint = VLLMDataTypeNames[
        kernel_type_config.element_b_zeropoint]

    return (f"{element_a}{element_b}{element_d}"
            f"{accumulator}{element_scale}{element_zeropoint}")
def generate_type_signature(kernel_types: TypeConfig):
    return str("".join([
        VLLMDataTypeNames[getattr(kernel_types, field.name)]
        for field in fields(TypeConfig)
    ]))


# non-unique shorter type_name
def generate_terse_type_signature(kernel_type_config: TypeConfig):
    element_a = VLLMDataTypeNames[kernel_type_config.element_a]
    element_b = VLLMDataTypeNames[kernel_type_config.element_b]

    return f"{element_a}{element_b}"
def generate_type_option_name(kernel_types: TypeConfig):
    return ", ".join([
        f"{field.name.replace('b_', 'with_')+'_type'}=" +
        VLLMDataTypeNames[getattr(kernel_types, field.name)]
        for field in fields(TypeConfig)
    ])


def is_power_of_two(n):
@ -263,13 +345,36 @@ def to_cute_constant(value: List[int]):
        return _to_cute_constant(value)


def unique_schedules(impl_configs: List[ImplConfig]):
    return list(
        set(sch for impl_config in impl_configs
            for sch in impl_config.schedules))


def unsigned_type_with_bitwidth(num_bits):
    return {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }[num_bits]


template_globals = {
    "void": DataType.void,
    "DataTypeTag": VLLMDataTypeTag,
    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    "gen_sch_name": generate_terse_schedule_name,
    "gen_sch_sig": generate_terse_sch_sig,
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name
}


@ -284,42 +389,82 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)


def create_sources(impl_config: ImplConfig, num_impl_files=1):
def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
    sources = []

    type_name = generate_type_signature(impl_config.type_config)
    terse_type_name = generate_terse_type_signature(impl_config.type_config)

    sources.append((
        f"machete_mm_{terse_type_name}",
        mm_dispatch_template.render(type_name=type_name,
                                    type_config=impl_config.type_config,
                                    schedules=impl_config.schedule_configs,
                                    heuristic=impl_config.heuristic),
        "machete_mm_dispatch",
        mm_dispatch_template.render(impl_configs=impl_configs),
    ))

    prepack_types = []
    for impl_config in impl_configs:
        convert_type = impl_config.types.a \
            if impl_config.types.b_group_scale == DataType.void \
            else impl_config.types.b_group_scale
        prepack_types.append(
            PrepackTypeConfig(
                a=impl_config.types.a,
                b_num_bits=VLLMDataTypeSize[impl_config.types.b],
                convert=convert_type,
                accumulator=impl_config.types.accumulator,
            ))

    def prepacked_type_key(prepack_type: PrepackTypeConfig):
        # For now we can just use the first accumulator type seen since
        # the tensor core shapes/layouts don't vary based on accumulator
        # type so we can generate less code this way
        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)

    unique_prepack_types = []
    prepack_types_seen = set()
    for prepack_type in prepack_types:
        key = prepacked_type_key(prepack_type)
        if key not in prepack_types_seen:
            unique_prepack_types.append(prepack_type)
            prepack_types_seen.add(key)

    sources.append((
        f"machete_prepack_{terse_type_name}",
        prepack_dispatch_template.render(
            type_name=type_name,
            type_config=impl_config.type_config,
        ),
        "machete_prepack",
        prepack_dispatch_template.render(types=unique_prepack_types, ),
    ))

    num_schedules = len(impl_config.schedule_configs)
    schedules_per_file = math.ceil(num_schedules / num_impl_files)
    for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
        file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
    # Split up impls across files
    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: List[List[ImplConfig]] = [[]]

    curr_num_impls_assigned = 0
    curr_impl_in_file = 0
    curr_impl_configs = deepcopy(list(reversed(impl_configs)))

    while curr_num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - curr_impl_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            curr_impl_in_file = 0

        curr_ic = curr_impl_configs[-1]
        if len(curr_ic.schedules) >= room_left_in_file:
            # Break apart the current impl config
            tmp_ic = deepcopy(curr_ic)
            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(tmp_ic)
        else:
            files_impls[-1].append(curr_ic)
            curr_impl_configs.pop()
        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
        curr_impl_in_file += len(files_impls[-1][-1].schedules)

    for part, file_impls in enumerate(files_impls):
        sources.append((
            f"machete_mm_{terse_type_name}_impl_part{part}",
            mm_impl_template.render(
                type_name=type_name,
                type_config=impl_config.type_config,
                schedules=file_schedules,
                specializations=impl_config.specializations,
            ),
            f"machete_mm_impl_part{part+1}",
            mm_impl_template.render(impl_configs=file_impls),
        ))

    return sources
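To see how the splitting loop above packs kernels into files, a toy run with hypothetical sizes: 3 impl configs of 16 schedules each and num_impl_files=8 gives 48 impls and a per-file budget of 6, so a config's schedule list is sliced across consecutive machete_mm_impl_part{n}.cu files. A sketch (illustrative only; the counts are made up):

import math

schedule_counts = [16, 16, 16]                     # hypothetical schedules per config
num_impl_files = 8
num_impls = sum(schedule_counts)                   # 48
per_file = math.ceil(num_impls / num_impl_files)   # 6 impls per generated file
# create_sources slices each config so every machete_mm_impl_part{n}.cu
# compiles at most `per_file` kernels, breaking a config across file
# boundaries when its remaining schedules exceed the room left in a file.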


@ -328,187 +473,169 @@ def generate():
    # about how this works
    SCRIPT_DIR = os.path.dirname(__file__)

    schedule_common_params = dict(
    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    default_tile_heuristic_config = {
        #### M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        #### M = 257+
        (
            "M > 256 && K <= 16384 && N <= 4096",
            ScheduleConfig(
                tile_shape_mn=(128, 128),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 256",
            ScheduleConfig(
                tile_shape_mn=(128, 256),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        #### M = 129-256
        (
            "M > 128 && K <= 4096 && N <= 4096",
            ScheduleConfig(
                tile_shape_mn=(128, 64),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 128 && K <= 8192 && N <= 8192",
            ScheduleConfig(
                tile_shape_mn=(128, 128),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 128",
            ScheduleConfig(
                tile_shape_mn=(128, 256),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        #### M = 65-128
        (
            "M > 64 && K <= 4069 && N <= 4069",
            ScheduleConfig(
                tile_shape_mn=(128, 32),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 64 && K <= 4069 && N <= 8192",
            ScheduleConfig(
                tile_shape_mn=(128, 64),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 64 && K >= 8192 && N >= 12288",
            ScheduleConfig(
                tile_shape_mn=(256, 128),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 64",
            ScheduleConfig(
                tile_shape_mn=(128, 128),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        #### M = 33-64
        (
            "M > 32 && K <= 6144 && N <= 6144",
            ScheduleConfig(
                tile_shape_mn=(128, 16),
                cluster_shape_mnk=(1, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 32 && K >= 16384 && N >= 12288",
            ScheduleConfig(
                tile_shape_mn=(256, 64),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 32",
            ScheduleConfig(
                tile_shape_mn=(128, 64),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        #### M = 17-32
        (
            "M > 16 && K <= 12288 && N <= 8192",
            ScheduleConfig(
                tile_shape_mn=(128, 32),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            "M > 16",
            ScheduleConfig(
                tile_shape_mn=(256, 32),
                cluster_shape_mnk=(2, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        #### M = 1-16
        (
            "N >= 26624",
            ScheduleConfig(
                tile_shape_mn=(256, 16),
                cluster_shape_mnk=(1, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (
            None,
            ScheduleConfig(
                tile_shape_mn=(128, 16),
                cluster_shape_mnk=(1, 1, 1),
                **schedule_common_params  # type: ignore
            )),
        (cond, ScheduleConfig(*tile_config,
                              **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
    ]

    # Do not use schedules = list(set(...)) because we need to make sure
    # the output list is deterministic; otherwise the generated kernel file
    # will be non-deterministic and cause ccache misses.
    schedules = []
    for _, schedule_config in default_heuristic:
        if schedule_config not in schedules:
            schedules.append(schedule_config)
    def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel file
        # will be non-deterministic and cause ccache misses.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules
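An order-preserving dedup is used because set iteration order can differ between interpreter runs (hash randomization), which would reorder the emitted kernels and defeat ccache. A tiny demonstration of the pattern (the schedule names here are hypothetical):

seen = []
for sch in ["TmaMI_128x64", "TmaMI_128x16", "TmaMI_128x64"]:
    if sch not in seen:
        seen.append(sch)
assert seen == ["TmaMI_128x64", "TmaMI_128x16"]  # stable, first-seen order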
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for element_a in (DataType.f16, DataType.bf16))
|
||||
|
||||
GPTQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
|
||||
]
|
||||
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(GPTQ_kernel_specializations),
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
AWQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
element_a=element_a,
|
||||
element_b=element_b,
|
||||
element_b_scale=element_a,
|
||||
element_b_zeropoint=element_a,
|
||||
element_d=element_a,
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=a,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
) for element_b in (DataType.u4, DataType.u8)
|
||||
for element_a in (DataType.f16, DataType.bf16))
|
||||
) for b in (DataType.u4, DataType.u8)
|
||||
for a in (DataType.f16, DataType.bf16))
|
||||
|
||||
AWQ_kernel_specializations = [
|
||||
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(AWQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# TODO (LucasWilkinson): Further tuning required
|
||||
qqq_tile_heuristic_config = {
|
||||
#### M = 257+
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# "M > 256": ((128, 256), (2, 1, 1)),
|
||||
"M > 256": ((128, 128), (2, 1, 1)),
|
||||
#### M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 128": ((128, 256), (2, 1, 1)),
|
||||
"M > 128": ((128, 128), (2, 1, 1)),
|
||||
#### M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
#### M = 33-64
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
#"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
#### M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
#### M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
qqq_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
QQQ_kernel_types = [
|
||||
*(TypeConfig(
|
||||
a=DataType.s8,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.s32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
*(TypeConfig(
|
||||
a=DataType.e4m3,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.f32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
]
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2], x[3])
|
||||
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
|
||||
itertools.repeat(AWQ_kernel_specializations),
|
||||
itertools.repeat(default_heuristic))
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(QQQ_kernel_types,
|
||||
itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
itertools.repeat(qqq_heuristic))
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
@ -521,12 +648,11 @@ def generate():
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Render each group of configurations into separate files
|
||||
for impl_config in impl_configs:
|
||||
for filename, code in create_sources(impl_config):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
for filename, code in create_sources(impl_configs):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -171,6 +171,10 @@ struct MacheteCollectiveMma {
|
||||
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||
Int<DispatchPolicy::Stages>{})));
|
||||
|
||||
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
|
||||
Int<DispatchPolicy::Stages>{})));
|
||||
|
||||
using SmemLayoutAtomARowMajor =
|
||||
decltype(rs_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})),
|
||||
@ -288,14 +292,7 @@ struct MacheteCollectiveMma {
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
|
||||
"SmemLayoutAtomScale must evenly divide tile k shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutACopy = decltype(tile_to_shape(
|
||||
SmemLayoutAtomARowMajor{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
@ -428,12 +425,12 @@ struct MacheteCollectiveMma {
|
||||
// clang-format on
|
||||
|
||||
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
|
||||
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
|
||||
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(int32_t(0), int32_t(0), int32_t(0)))));
|
||||
|
||||
using ATensor = decltype(make_tensor(
|
||||
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
|
||||
shape(GmemLayoutA::TVbNbKL_to_offset(
|
||||
shape(GmemLayoutA::TVbNbKL_to_offset_copy(
|
||||
make_shape(int32_t(0), int32_t(0), int32_t(0)))),
|
||||
PrepackedStrideA{}));
|
||||
|
||||
@ -450,8 +447,8 @@ struct MacheteCollectiveMma {
|
||||
|
||||
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
|
||||
return make_tma_copy<TmaElementA>(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}),
|
||||
shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
}
|
||||
|
||||
@ -584,7 +581,7 @@ struct MacheteCollectiveMma {
|
||||
typename Params::TMA_Scale tma_load_scale;
|
||||
typename Params::TMA_Zero tma_load_zero;
|
||||
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
|
||||
tma_load_a = make_tma_copy_A(
|
||||
make_logical_tensor(ptr_A, shape(layout), stride(layout)));
|
||||
|
||||
@ -722,7 +719,7 @@ struct MacheteCollectiveMma {
|
||||
// (TILE_V,TILE_B,m,k,l)
|
||||
auto make_gA_mkl = [&]() {
|
||||
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
|
||||
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
|
||||
return local_tile(mA_mkl,
|
||||
make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_prepacked_layout.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
@ -37,27 +39,42 @@ using namespace cute;
|
||||
// W is quantized, in this situation or right-hand operand is quantized so
|
||||
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename ScaleT, typename ZeroT,
|
||||
class KernelSchedule, typename ScheduleConfig, bool with_C,
|
||||
bool with_scales, bool with_zeropoints>
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
|
||||
static constexpr bool with_group_zeropoints =
|
||||
!std::is_same_v<GroupZeroT, void>;
|
||||
static constexpr bool with_channel_scales =
|
||||
!std::is_same_v<ChannelScaleT, void>;
|
||||
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
|
||||
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementC = cute::conditional_t<with_C, ElementD, void>;
|
||||
using ElementZ = ZeroT;
|
||||
using ElementS = ScaleT;
|
||||
|
||||
using ElementAccumulator =
|
||||
AccumulatorT; // Element type for internal accumulation
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementCompute = AccumulatorT; // For Epilogue
|
||||
// Use dummy values when we don't have scales or zeropoints
|
||||
using ElementZGroup =
|
||||
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
|
||||
using ElementSGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementConvertGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementSChannel =
|
||||
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
|
||||
using ElementSToken =
|
||||
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
|
||||
|
||||
using BTypeTuple = cute::conditional_t<
|
||||
with_scales,
|
||||
cute::conditional_t<with_zeropoints,
|
||||
cute::tuple<ElementB, ElementS, ElementZ>,
|
||||
cute::tuple<ElementB, ElementS>>,
|
||||
with_group_scales,
|
||||
cute::conditional_t<with_group_zeropoints,
|
||||
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
|
||||
cute::tuple<ElementB, ElementSGroup>>,
|
||||
ElementB>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
@ -71,8 +88,8 @@ struct MacheteKernelTemplate {
   using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
   using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
   using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
-  using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
-  using StrideZ = StrideS;
+  using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
+  using StrideZGroup = StrideSGroup;
 
   using LayoutA_Transpose =
       typename cutlass::layout::LayoutTranspose<LayoutA>::type;
@ -85,8 +102,8 @@ struct MacheteKernelTemplate {
   using OperatorClass = cutlass::arch::OpClassTensorOp;
 
   using PrepackedLayoutB =
-      PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
-                               LayoutA_Transpose, KernelSchedule>;
+      PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
+                               AccumulatorT, LayoutA_Transpose, KernelSchedule>;
 
   static int constexpr TileShapeK =
       128 * 8 / cutlass::sizeof_bits<MmaType>::value;
@ -103,12 +120,42 @@ struct MacheteKernelTemplate {
   using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
   using TileScheduler = typename ScheduleConfig::TileScheduler;
 
+  static_assert(
+      (!with_channel_scales && !with_token_scales) ||
+          ((with_channel_scales && with_token_scales) &&
+           std::is_same_v<ElementSChannel, ElementSToken>),
+      "Currently token and channel scales (if present) must be the same type");
+
+  using EpilogueDescriptor =
+      cutlass::epilogue::collective::detail::EpilogueDescriptor<
+          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
+          ElementD, EpilogueSchedule>;
+
+  // Currently only supports float scales
+  using ChTokScalesEpilogue =
+      typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
+                                         EpilogueDescriptor>;
+  static_assert((!with_channel_scales && !with_token_scales) ||
+                    (std::is_same_v<ElementSChannel, float> &&
+                     std::is_same_v<ElementSToken, float>),
+                "Currently token and channel scales (if present) must be float "
+                "(and if one is present the other must be too)");
+
+  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
+      cutlass::epilogue::fusion::Sm90AccFetch>;
+
+  using EVTCompute =
+      std::conditional_t<with_channel_scales || with_token_scales,
+                         typename ChTokScalesEpilogue::EVTCompute,
+                         StoreEpilogueCompute>;
+
+  // EVTCompute
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
-          ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
-          AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
-          EpilogueSchedule>::CollectiveOp;
+          ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
+          AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
+          EVTCompute>::CollectiveOp;
 
   using CollectiveMainloop =
       typename cutlass::gemm::collective::VLLMCollectiveBuilder<
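When channel and/or token scales are present, the EVT epilogue applies them to the raw accumulator; otherwise Sm90AccFetch stores the accumulator unmodified. A scalar reference of what the scaled path computes, assuming ScaledEpilogue performs per-row (token) and per-column (channel) scaling (names are illustrative):

    #include <cstdint>

    // Reference semantics for the channel/token-scaled epilogue: every output
    // element is the accumulator scaled by its row (token) and column
    // (channel) scale. Broadcast (numel == 1) scales are handled by the
    // callers' checks shown later in this hunk.
    void scaled_epilogue_ref(float const* accum, float const* s_tok /*len M*/,
                             float const* s_ch /*len N*/, float* d, int64_t M,
                             int64_t N) {
      for (int64_t m = 0; m < M; ++m)
        for (int64_t n = 0; n < N; ++n)
          d[m * N + n] = s_tok[m] * s_ch[n] * accum[m * N + n];
    }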
@ -131,26 +178,44 @@ struct MacheteKernelTemplate {
   using MainloopArguments = typename GemmKernel::MainloopArguments;
   using EpilogueArguments = typename GemmKernel::EpilogueArguments;
 
-  template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
-            typename ShapeZ>
   static Arguments create_arguments(
       cudaStream_t stream,
-      ElementA const* A_ptr,  // A is an MxK matrix
-      Layout<ShapeA, StrideA> const& layout_A,
-      ElementB const* B_ptr,  // B is a KxN prepacked matrix
-      ElementD* D_ptr,        // D is an MxN matrix
-      Layout<ShapeD, StrideD> const& layout_D,
-      ElementC const* C_ptr,  // C is an MxN matrix
-      std::optional<Layout<ShapeC, StrideC>> const& layout_C,
-      ElementS const* S_ptr,  // S is a scale_KxN matrix
-      std::optional<Layout<ShapeS, StrideS>> const& layout_S,
-      ElementZ const* Z_ptr,  // Z is a scale_KxN matrix
-      std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
-      ElementCompute alpha, ElementCompute beta,
-      std::optional<int> maybe_group_size) {
-    static_assert(!with_zeropoints || with_scales);
+      torch::Tensor const& A,  // MxK matrix
+      torch::Tensor const& B,  // KxN prepacked matrix
+      torch::Tensor& D,        // MxN matrix
+      c10::optional<torch::Tensor> const& maybe_g_scales,  // scale_KxN matrix
+      c10::optional<torch::Tensor> const& maybe_g_zeros,   // scale_KxN matrix
+      c10::optional<int64_t> maybe_group_size,
+      c10::optional<torch::Tensor> const& maybe_ch_scales,   // len N vector
+      c10::optional<torch::Tensor> const& maybe_tok_scales)  // len M vector
+  {
+    static_assert(!with_group_zeropoints || with_group_scales);
 
-    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
+    int M = A.size(0), N = B.size(1), K = A.size(1);
+    TORCH_CHECK(D.size(0) == M && D.size(1) == N);
+
+    auto layout_A = make_cute_layout<StrideA>(A, "A");
+    auto layout_D = make_cute_layout<StrideD>(D, "D");
+    auto layout_S_group =
+        maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
+    auto layout_Z_group =
+        maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
+    int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
+    int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
+
+    auto unwrap = [](auto const& t) {
+      return t ? t->const_data_ptr() : nullptr;
+    };
+    auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
+    auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
+    auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
+    auto S_group_ptr =
+        static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
+    auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
+    auto S_channel_ptr =
+        static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
+    auto S_token_ptr =
+        static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
 
     int const group_size =
         maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
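Both nullopt and -1 for maybe_group_size collapse to a single group spanning all of K, i.e. per-output-channel quantization. A standalone restatement of the resolution rule (sketch using std::optional in place of c10::optional):

    #include <optional>

    // Resolve the effective quantization group size: absent or -1 both mean
    // one group of length K (per-output-channel scales).
    int resolve_group_size(std::optional<int64_t> maybe_group_size, int K) {
      if (!maybe_group_size.has_value() || *maybe_group_size == -1) return K;
      return static_cast<int>(*maybe_group_size);
    }
    // resolve_group_size({}, 4096) == 4096
    // resolve_group_size(128, 4096) == 128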
@ -159,26 +224,28 @@ struct MacheteKernelTemplate {
     TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
     TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
 
-    if constexpr (with_C) {
-      TORCH_CHECK(C_ptr && layout_C);
+    if constexpr (with_group_scales) {
+      TORCH_CHECK(S_group_ptr && layout_S_group);
+      TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
+                   size<1>(*layout_S_group) == N));
     } else {
-      TORCH_CHECK(!C_ptr, "C not supported");
+      TORCH_CHECK(!S_group_ptr, "Scales not supported");
     }
 
-    if constexpr (with_scales) {
-      TORCH_CHECK(S_ptr && layout_S);
-      TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
-    } else {
-      TORCH_CHECK(!S_ptr, "Scales not supported");
-    }
-
-    if constexpr (with_zeropoints) {
-      TORCH_CHECK(Z_ptr && layout_Z);
-      TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
-      TORCH_CHECK(layout_S && *layout_Z == *layout_S,
+    if constexpr (with_group_zeropoints) {
+      TORCH_CHECK(Z_group_ptr && layout_Z_group);
+      TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
+                   size<1>(*layout_Z_group) == N));
+      TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
                   "Scales and zeros must have the same layout");
     } else {
-      TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
+      TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
     }
 
+    if constexpr (with_channel_scales || with_token_scales) {
+      TORCH_CHECK(
+          (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
+          (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
+    }
+
     // Transpose A and D
@ -186,24 +253,33 @@ struct MacheteKernelTemplate {
     // for B (which is At)
     auto stride_At = layout_A.stride();
     auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
-    auto stride_Ct = stride_Dt;
-    if (layout_C) {
-      stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
-    }
 
     MainloopArguments mainloop_arguments{};
-    EpilogueArguments epilogue_arguments{
-        {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
+    // {Accum, C, C_layout, D, D_layout}
+    EpilogueArguments epilogue_arguments{};
 
-    if constexpr (with_scales && with_zeropoints) {
-      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
-      mainloop_arguments =
-          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
-                            S_ptr, stride_S, group_size, Z_ptr};
-    } else if constexpr (with_scales) {
-      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
-      mainloop_arguments = MainloopArguments{
-          B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
+    if constexpr (with_channel_scales || with_token_scales) {
+      epilogue_arguments =
+          EpilogueArguments{ChTokScalesEpilogue::prepare_args(
+                                *maybe_ch_scales, *maybe_tok_scales),
+                            nullptr,
+                            {},
+                            D_ptr,
+                            stride_Dt};
+    } else {
+      epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
+    }
+
+    if constexpr (with_group_scales && with_group_zeropoints) {
+      auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
+      mainloop_arguments =
+          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
+                            S_group_ptr, stride_S_group, group_size,
+                            Z_group_ptr};
+    } else if constexpr (with_group_scales) {
+      auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
+      mainloop_arguments =
+          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
+                            S_group_ptr, stride_S_group, group_size};
     } else {
       mainloop_arguments =
           MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
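Only the left-hand operand of the underlying CUTLASS GEMM can carry the prepacked quantized layout, so the kernel solves the transposed problem: Dt = Bt x At, writing D through transposed strides (hence the stride_At / stride_Dt naming above). A tiny reference check of the identity (A*B)^T == B^T * A^T (illustrative only):

    #include <cassert>

    // The kernel exploits this identity so the quantized operand B lands on
    // the left-hand slot of the underlying GEMM; D is written through
    // transposed strides instead of materializing an explicit transpose.
    int main() {
      float A[2][3] = {{1, 2, 3}, {4, 5, 6}};       // 2x3
      float B[3][2] = {{7, 8}, {9, 10}, {11, 12}};  // 3x2
      float D[2][2] = {}, Dt[2][2] = {};            // D = A*B, Dt = B^T*A^T
      for (int m = 0; m < 2; ++m)
        for (int n = 0; n < 2; ++n)
          for (int k = 0; k < 3; ++k) {
            D[m][n] += A[m][k] * B[k][n];
            Dt[n][m] += B[k][n] * A[m][k];  // (B^T)[n][k] * (A^T)[k][m]
          }
      for (int m = 0; m < 2; ++m)
        for (int n = 0; n < 2; ++n) assert(D[m][n] == Dt[n][m]);
    }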
@ -5,73 +5,61 @@
 
 #include "machete_mm_kernel.cuh"
 #include "cutlass_extensions/torch_utils.hpp"
+#include "core/scalar_type.hpp"
 
 namespace machete {
 
-struct PyTorchArguments {
+struct MMArgs {
   torch::Tensor const& A;
   torch::Tensor const& B;
-  c10::optional<torch::Tensor> const& scales;
-  c10::optional<torch::Tensor> const& zeros;
-  c10::optional<int64_t> group_size;
-  c10::optional<torch::Tensor> const& C;
-  c10::optional<double> alpha;
-  c10::optional<double> beta;
-  c10::optional<std::string> schedule;
+  vllm::ScalarType const& b_type;
+  c10::optional<at::ScalarType> const& maybe_out_type;
+  c10::optional<torch::Tensor> const& maybe_group_scales;
+  c10::optional<torch::Tensor> const& maybe_group_zeros;
+  c10::optional<int64_t> maybe_group_size;
+  c10::optional<torch::Tensor> const& maybe_channel_scales;
+  c10::optional<torch::Tensor> const& maybe_token_scales;
+  c10::optional<std::string> maybe_schedule;
 };
 
+struct SupportedSchedulesArgs {
+  at::ScalarType a_type;
+  vllm::ScalarType b_type;
+  c10::optional<at::ScalarType> maybe_group_scales_type;
+  c10::optional<at::ScalarType> maybe_group_zeros_type;
+  c10::optional<at::ScalarType> maybe_channel_scales_type;
+  c10::optional<at::ScalarType> maybe_token_scales_type;
+  c10::optional<at::ScalarType> maybe_out_type;
+};
+
+torch::Tensor mm_dispatch(MMArgs args);
+
+std::vector<std::string> supported_schedules_dispatch(
+    SupportedSchedulesArgs args);
+
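A hedged sketch of a call site filling MMArgs with designated initializers, mirroring the style used later in this diff (the wrapper function and tensors here are placeholders, not part of the PR):

    // Illustrative call site only; the real caller is the torch-binding `mm`
    // function shown later in this diff.
    torch::Tensor example_call(torch::Tensor const& A, torch::Tensor const& B,
                               vllm::ScalarType const& b_type,
                               torch::Tensor const& group_scales) {
      return mm_dispatch({.A = A,
                          .B = B,
                          .b_type = b_type,
                          .maybe_out_type = c10::nullopt,
                          .maybe_group_scales = group_scales,
                          .maybe_group_zeros = c10::nullopt,
                          .maybe_group_size = 128,  // e.g. group-128 quant
                          .maybe_channel_scales = c10::nullopt,
                          .maybe_token_scales = c10::nullopt,
                          .maybe_schedule = c10::nullopt});
    }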
 template <typename MacheteKernel>
-torch::Tensor run_impl(PyTorchArguments args) {
+torch::Tensor run_impl(MMArgs args) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
 
   auto device = args.A.device();
   auto stream = at::cuda::getCurrentCUDAStream(device.index());
 
-  using EleA = typename MacheteKernel::ElementA;
-  using EleB = typename MacheteKernel::ElementB;
-  using EleC = typename MacheteKernel::ElementC;
-  using EleD = typename MacheteKernel::ElementD;
-  using EleScale = typename MacheteKernel::ElementS;
-  using EleZero = typename MacheteKernel::ElementZ;
-
-  using StrideA = typename MacheteKernel::StrideA;
-  using StrideC = typename MacheteKernel::StrideC;
-  using StrideD = typename MacheteKernel::StrideD;
-  using StrideS = typename MacheteKernel::StrideS;
-  using StrideZ = typename MacheteKernel::StrideZ;
-
   int M = args.A.size(0);
   int N = args.B.size(1);
   int K = args.A.size(1);
 
   // Allocate output
-  torch::Tensor D =
-      torch::empty({M, N}, torch::TensorOptions()
-                               .dtype(equivalent_scalar_type_v<EleD>)
-                               .device(device));
-
-  auto const &A = args.A, &B = args.B;
-  auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
-
-  auto layout_A = make_cute_layout<StrideA>(A, "A");
-  auto layout_D = make_cute_layout<StrideD>(D, "D");
-  auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
-  auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
-  auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
-
-  auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
-  auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
-  auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
-  auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
-  auto S_ptr =
-      static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
-  auto Z_ptr =
-      static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
+  torch::Tensor D = torch::empty(
+      {M, N},
+      torch::TensorOptions()
+          .dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
+          .device(device));
 
   auto arguments = MacheteKernel::create_arguments(
-      stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
-      layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
-      args.group_size);
+      stream,  //
+      args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
+      args.maybe_group_size, args.maybe_channel_scales,
+      args.maybe_token_scales);
   TORCH_CHECK(MacheteKernel::can_implement(arguments),
               "Machete kernel cannot be run with these arguments");
 
@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) {
   return D;
 };
 
-template <typename ElementA, typename ElementB, typename ElementD = ElementA,
-          typename AccumulatorT = float, typename ScaleT = ElementA,
-          typename ZeroT = ElementA>
-struct GemmDispatcher {
-  static torch::Tensor dispatch(PyTorchArguments args);
-  static std::vector<std::string> supported_schedules();
-};
-
 };  // namespace machete
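equivalent_scalar_type_v maps a CUTLASS element type to the matching at::ScalarType so the allocated output dtype agrees with the kernel's ElementD. A hypothetical re-implementation of the kind of trait involved (the real helper lives in cutlass_extensions/torch_utils.hpp and may differ):

    #include <torch/all.h>

    #include "cutlass/bfloat16.h"
    #include "cutlass/half.h"

    // Hypothetical sketch of a cutlass-type -> torch-dtype trait.
    template <typename T>
    struct equivalent_scalar_type_sketch;
    template <>
    struct equivalent_scalar_type_sketch<cutlass::half_t> {
      static constexpr at::ScalarType value = at::ScalarType::Half;
    };
    template <>
    struct equivalent_scalar_type_sketch<cutlass::bfloat16_t> {
      static constexpr at::ScalarType value = at::ScalarType::BFloat16;
    };
    template <typename T>
    constexpr at::ScalarType equivalent_scalar_type_v_sketch =
        equivalent_scalar_type_sketch<T>::value;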
@ -6,31 +6,49 @@
 
 namespace machete {
 
-template <typename TileShapeNKL, typename ElementB, typename BInTensor,
-          typename BTiledOutTensor>
-static __global__ void prepack_B_kernel(BInTensor B_in,
-                                        BTiledOutTensor B_tiled_out) {
-  auto tB_in = local_tile(B_in, TileShapeNKL{},
-                          make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
-  auto tB_out = B_tiled_out(make_coord(_, _),
-                            make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
+template <int threads, typename PrepackedLayoutB, typename BInTensor,
+          typename ElementB>
+static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
+  auto constexpr block_size =
+      Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
+  auto constexpr eles_per_thread = Int<block_size / threads>{};
+  static_assert(block_size % threads == 0,
+                "block_size must be divisible by the number of threads");
 
-  auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
-                                    Layout<Shape<_4, _32>, Stride<_32, _1>>{},
-                                    Layout<Shape<_1, _2>>{});
+  // Which pre-packed block are we responsible for
+  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
+  auto tB_in = local_tile(
+      B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
+      blk_coord);
 
-  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
+  // Find the start offset in the output for this pre-packed block
+  auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
 
-  Tensor thr_tile_S = thr_copy.partition_S(tB_in);
-  Tensor thr_tile_D = thr_copy.partition_D(tB_out);
+  // Tensor representing a 1:1 mapping to the output space in 1D
+  auto tB_out_linear =
+      make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
+                  make_layout(make_shape(block_size)));
+  // Mapping from output space (1D) to input space
+  auto tB_in_linear = make_tensor(
+      tB_in.data(),
+      tB_in.layout()
+          .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
+          .with_shape(make_shape(block_size)));
+
+  // Tile for this specific thread (could have used a TiledCopy, but those
+  // work best with 2d layouts; this is a simple 1d layout so local_tile is
+  // enough, and we are not that concerned with performance for this kernel)
+  auto thr_tB_in_linear =
+      local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
+  auto thr_tB_out_linear =
+      local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
 
   // Construct a register-backed Tensor with the same shape as each thread's
   // partition
-  auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
+  auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
 
   // Copy from GMEM to RMEM and from RMEM to GMEM
-  copy(tiled_copy, thr_tile_S, fragment);
-  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
+  copy(thr_tB_in_linear, fragment);
+  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
 }
 
 template <typename PrepackedLayoutB, typename InLayout>
@ -44,18 +62,15 @@ static void prepack_B_template(
 
   TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
   TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
-  TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
 
   auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
   auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
-  auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
+  auto L_tiles = size<2>(B_layout);
 
   auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
-  auto B_tiled_out =
-      make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
 
-  prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
-      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
+  prepack_B_kernel<128, PrepackedLayoutB>
+      <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
 }
 
 };  // namespace machete
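Each CUDA block packs exactly one PPBlockShape_NK tile, so with 128 threads every thread moves block_size / 128 elements through registers. A worked instance of that arithmetic (the 128x64 block shape is an assumed example, not a value fixed by this diff):

    // Assume a pre-packed block of 128x64 weight elements (PPBlockShape_NK).
    constexpr int kBlockN = 128, kBlockK = 64, kThreads = 128;
    constexpr int kBlockSize = kBlockN * kBlockK;          // 8192 elements
    constexpr int kElesPerThread = kBlockSize / kThreads;  // 64 per thread
    static_assert(kBlockSize % kThreads == 0,
                  "mirrors the static_assert in prepack_B_kernel");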
@ -2,9 +2,17 @@
 
 #include "machete_prepack_kernel.cuh"
 #include "cutlass_extensions/torch_utils.hpp"
+#include "core/scalar_type.hpp"
 
 namespace machete {
 
+struct PrepackBArgs {
+  torch::Tensor const& B;
+  at::ScalarType a_type;
+  vllm::ScalarType b_type;
+  c10::optional<at::ScalarType> maybe_group_scales_type;
+};
+
 template <typename PrepackedLayoutB>
 torch::Tensor prepack_impl(torch::Tensor const B) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
   return D;
 };
 
-template <typename ElementA, typename ElementB, typename ElementD,
-          typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
-          typename ZeroT = cutlass::half_t>
-struct PrepackBDispatcher {
-  static torch::Tensor dispatch(torch::Tensor B);
-};
+torch::Tensor prepack_B_dispatch(PrepackBArgs args);
 
 };  // namespace machete
@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
 // The contract here is that the `TiledMma` determined below matches the one
 // ultimately used in the kernel. (this is also why the other element types are
 // required along with the kernel schedule)
-template <typename ElementA_, typename ElementB_, typename ElementD_,
+template <typename ElementA_, typename ElementB_, typename ElementConvert_,
           typename AccumulatorT, class LayoutB, class KernelSchedule,
           typename IlvBlkLayout_ = IlvBlkLayoutAuto>
 // clang-format on
@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate {
   using MmaType = ElementA_;
   using ElementA = ElementA_;
   using ElementB = ElementB_;
-  using ElementD = ElementD_;
-  using ElementAccumulator =
-      AccumulatorT;  // Element type for internal accumulation
+  using ElementAccumulator = AccumulatorT;
   using ElementMma = MmaType;
 
-  // Only use interleaved layouts for subbyte weights; prmt instructions make
-  // non-interleaved layouts for 8bit+ weights efficient enough that we don't
-  // need interleaved layouts
+  // Interleave for 4-bit types when we are not upconverting to fp8 or int8;
+  // in those cases we use a LUT (via prmt instructions) to upconvert, which
+  // is more efficient if the data is not interleaved. For 8bit+ types, prmt
+  // instructions make non-interleaved layouts efficient enough that we don't
+  // need interleaved layouts (and can reuse more of the existing cutlass
+  // converts)
+  static constexpr bool should_interleave =
+      sizeof_bits_v<ElementB> <= 4 &&
+      !std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
+      !std::is_same_v<ElementConvert_, int8_t>;
+
+  // Only use interleaved layouts for subbyte weights
   using IlvdBlkLayout = std::conditional_t<
       std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
-      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
-                         decltype(get_interleaved_blk_layout<
-                                  ElementB, sizeof_bits_v<ElementA>, 32>()),
-                         void>,
+      std::conditional_t<
+          should_interleave,
+          decltype(get_interleaved_blk_layout<
+                   ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
+          void>,
       IlvBlkLayout_>;
 
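The rule above chooses between the interleaved-layout path and the prmt-LUT upconversion path. A self-contained restatement with sanity checks (plain C++ stand-ins for the cutlass types; illustrative only):

    #include <cstdint>
    #include <type_traits>

    // Stand-in for cutlass::float_e4m3_t so the example is self-contained.
    struct fake_e4m3_t {};

    template <int WeightBits, typename ConvertT>
    constexpr bool should_interleave_v =
        WeightBits <= 4 && !std::is_same_v<ConvertT, fake_e4m3_t> &&
        !std::is_same_v<ConvertT, int8_t>;

    static_assert(should_interleave_v<4, uint16_t>);      // w4a16: interleave
    static_assert(!should_interleave_v<4, int8_t>);       // w4a8-int8: LUT
    static_assert(!should_interleave_v<4, fake_e4m3_t>);  // w4a8-fp8: LUT
    static_assert(!should_interleave_v<8, uint16_t>);     // 8-bit: never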
   // TODO (LucasWilkinson): compare the performance for other sizes
@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
   // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
   auto frgV = get<1, 0>(layout_no_interleave);
   auto ilvdBlk = IlvdBlkLayout{};
-  static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
+  static_assert(size(frgV) % size(ilvdBlk) == 0,
+                "FrgV must be divisible by size(ilvdBlk)");
   auto ilvd_FrgV = make_layout(
       make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
       make_stride(stride(ilvdBlk), size(ilvdBlk)));
@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
     return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
   }
 
+  // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
+  template <typename Shape_NKL>
+  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
+      Shape_NKL shape_mkl) {
+    auto layout = TVbNbKL_to_offset(shape_mkl);
+    return make_layout(coalesce(get<0>(layout)), get<1>(layout),
+                       get<2>(layout));
+  }
+
   // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
   template <typename Shape_NKL>
   CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
     return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
   }
 
+  // (BlocksN, BlocksK, L) -> (storage_idx)
+  template <typename Shape_NKL>
+  CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
+    // (BlocksN, BlocksK, L)
+    auto blocks_shape =
+        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
+                        [](auto x, auto y) { return x / y; });
+    auto stride = size(PPBlockShape_NK{});
+
+    // (BlocksN, BlocksK, L) -> (storage_idx)
+    return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
+  }
+
   // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
   template <class Shape_NKL>
   CUTE_HOST_DEVICE static auto TVbNbKL_to_NKL(Shape_NKL shape_mkl) {
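bNbKL_to_offset assigns each pre-packed block a dense slot of size(PPBlockShape_NK) elements, ordered column-major over (BlocksN, BlocksK, L). A worked instance with assumed sizes (not values fixed by this diff):

    #include <cassert>

    // With an assumed 128x64 block over a 256x128x1 (N, K, L) weight tensor:
    // blocks_shape = (2, 2, 1) and each block occupies 128*64 = 8192 elements.
    int bNbKL_to_offset_ref(int n_blk, int k_blk, int l) {
      constexpr int kBlockElems = 128 * 64;
      constexpr int kBlocksN = 2, kBlocksK = 2;
      // Column-major block index scaled by the block's element count.
      return kBlockElems * (n_blk + kBlocksN * (k_blk + kBlocksK * l));
    }

    int main() {
      assert(bNbKL_to_offset_ref(0, 0, 0) == 0);
      assert(bNbKL_to_offset_ref(1, 0, 0) == 8192);   // next block along N
      assert(bNbKL_to_offset_ref(0, 1, 0) == 16384);  // next block along K
    }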
@ -8,89 +8,61 @@ namespace machete {
 
 using namespace vllm;
 
-//
-// Utils (type dispatching)
-//
-
-template <typename Fn>
-static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
-  if (type == vllm::kU4) {
-    return fn(cutlass::uint4b_t{});
-  } else if (type == vllm::kU8) {
-    return fn(cutlass::uint8_t{});
-  } else if (type == vllm::kU4B8) {
-    return fn(cutlass::vllm_uint4b8_t{});
-  } else if (type == vllm::kU8B128) {
-    return fn(cutlass::vllm_uint8b128_t{});
-  } else {
-    TORCH_CHECK(false, "Unsupported type ", type.str());
-  }
-}
-
-#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
-  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
-
-#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
-                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
-
 //
 // Interface
 //
 
-std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
-#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  vllm::ScalarType b_type = ScalarType::from_id(btype_id);
-  return scalar_type_dispatch(b_type, [&](auto BType) {
-    return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
-  });
-#else
-  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
-#endif
-}
+std::vector<std::string> supported_schedules(
+    at::ScalarType a_type, int64_t b_type_id,
+    c10::optional<at::ScalarType> maybe_group_scales_type,
+    c10::optional<at::ScalarType> maybe_group_zeros_type,
+    c10::optional<at::ScalarType> maybe_channel_scales_type,
+    c10::optional<at::ScalarType> maybe_token_scales_type,
+    c10::optional<at::ScalarType> maybe_out_type) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return supported_schedules_dispatch({
+      .a_type = a_type,
+      .b_type = b_type,
+      .maybe_group_scales_type = maybe_group_scales_type,
+      .maybe_group_zeros_type = maybe_group_zeros_type,
+      .maybe_channel_scales_type = maybe_channel_scales_type,
+      .maybe_token_scales_type = maybe_token_scales_type,
+      .maybe_out_type = maybe_out_type,
+  });
+}
 
-torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
-                   ScalarTypeId const btype_id,
-                   c10::optional<torch::Tensor> const& scales,
-                   c10::optional<torch::Tensor> const& zeros,
-                   c10::optional<int64_t> group_size,
-                   c10::optional<torch::Tensor> const& C,
-                   c10::optional<double> alpha, c10::optional<double> beta,
-                   c10::optional<std::string> schedule) {
-#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
-  ScalarType const btype = ScalarType::from_id(btype_id);
-  auto args = PyTorchArguments{.A = A,
-                               .B = B,
-                               .scales = scales,
-                               .zeros = zeros,
-                               .group_size = group_size,
-                               .C = C,
-                               .alpha = alpha,
-                               .beta = beta,
-                               .schedule = schedule};
-
-  return scalar_type_dispatch(btype, [&](auto BType) {
-    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
-        A.scalar_type(), "machete_gemm", [&] {
-          using ComputeType = equivalent_cutlass_type_t<scalar_t>;
-          return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
-        });
-  });
-#else
-  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
-#endif
-}
+torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
+                 int64_t b_type_id,
+                 c10::optional<at::ScalarType> const& maybe_out_type,
+                 c10::optional<torch::Tensor> const& maybe_group_scales,
+                 c10::optional<torch::Tensor> const& maybe_group_zeros,
+                 c10::optional<int64_t> maybe_group_size,
+                 c10::optional<torch::Tensor> const& maybe_channel_scales,
+                 c10::optional<torch::Tensor> const& maybe_token_scales,
+                 c10::optional<std::string> maybe_schedule) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return mm_dispatch({.A = A,
+                      .B = B,
+                      .b_type = b_type,
+                      .maybe_out_type = maybe_out_type,
+                      .maybe_group_scales = maybe_group_scales,
+                      .maybe_group_zeros = maybe_group_zeros,
+                      .maybe_group_size = maybe_group_size,
+                      .maybe_channel_scales = maybe_channel_scales,
+                      .maybe_token_scales = maybe_token_scales,
+                      .maybe_schedule = maybe_schedule});
+}
 
-torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
-  ScalarType const btype = ScalarType::from_id(btype_id);
-  return scalar_type_dispatch(btype, [&](auto BType) {
-    return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
-  });
+torch::Tensor prepack_B(
+    torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
+    c10::optional<at::ScalarType> const& maybe_group_scales_type) {
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
+  return prepack_B_dispatch(
+      {.B = B,
+       .a_type = a_type,
+       .b_type = b_type,
+       .maybe_group_scales_type = maybe_group_scales_type});
 }
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("machete_prepack_B", &prepack_B);
-  m.impl("machete_gemm", &gemm);
+  m.impl("machete_mm", &mm);
 }
 
 // use CatchAll since supported_schedules has no tensor arguments
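supported_schedules takes no Tensor arguments, so there is nothing for the dispatcher to key on and it cannot sit under the CUDA key like machete_mm. A hedged sketch of the registration the comment refers to; the actual lines are truncated from this diff and may differ:

    // Hypothetical continuation of the file (illustrative only):
    TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
      m.impl("machete_supported_schedules", &supported_schedules);
    }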