diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 0f0cc34602..bb43aeee2e 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -187,7 +187,8 @@ template <> struct hash { size_t operator()( const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { - return hash()(val.b_n_size) ^ hash()(val.b_k_size); + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.b_type)); } }; @@ -216,7 +217,8 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { - return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.b_type == r.b_type; } bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, @@ -493,8 +495,10 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( const MSizeCacheKey& key) { if (m_size_cache_.get() == nullptr) { - ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; - m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + ClassMatmulCacheKey class_key = { + .b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_}; + m_size_cache_ = + get_matul_class_primitive_cache(class_key, primitive_cache_size_); } return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index f0cb197d81..58ffe7a19b 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -199,6 +199,7 @@ class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { struct ClassMatmulCacheKey { dnnl_dim_t b_n_size; dnnl_dim_t b_k_size; + dnnl::memory::data_type b_type; friend bool operator==(const ClassMatmulCacheKey& l, const ClassMatmulCacheKey& r);