@ -131,6 +131,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
# TODO(wentao): optimize this when it is supported by Flashinfer upstream.
|
||||
# execute per-request to eliminate batch-shape-dependent kernel paths.
|
||||
num = q.shape[0]
|
||||
outs = []
|
||||
|
||||
Reference in New Issue
Block a user