From 2080b05099b4bf74940a43319babf85c0e0dde57 Mon Sep 17 00:00:00 2001 From: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:57:48 +0100 Subject: [PATCH] [cpu][fix] Fix onednn_mm crash on consecutive matmuls with same M,K,N and different dtype (#27472) Signed-off-by: Fadi Arafeh --- csrc/cpu/dnnl_helper.cpp | 12 ++++++++---- csrc/cpu/dnnl_helper.h | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 0f0cc34602..bb43aeee2e 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -187,7 +187,8 @@ template <> struct hash { size_t operator()( const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { - return hash()(val.b_n_size) ^ hash()(val.b_k_size); + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.b_type)); } }; @@ -216,7 +217,8 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { - return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.b_type == r.b_type; } bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, @@ -493,8 +495,10 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( const MSizeCacheKey& key) { if (m_size_cache_.get() == nullptr) { - ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; - m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + ClassMatmulCacheKey class_key = { + .b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_}; + m_size_cache_ = + get_matul_class_primitive_cache(class_key, primitive_cache_size_); } return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index f0cb197d81..58ffe7a19b 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -199,6 +199,7 @@ class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { struct ClassMatmulCacheKey { dnnl_dim_t b_n_size; dnnl_dim_t b_k_size; + dnnl::memory::data_type b_type; friend bool operator==(const ClassMatmulCacheKey& l, const ClassMatmulCacheKey& r);