[BugFix] FA2 MLA Accuracy Issue (#18807)
Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
|
||||
"output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
|
||||
"prefix_output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
|
||||
"suffix_output heads must be contiguous in memory");
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
|
||||
Reference in New Issue
Block a user