@ -324,12 +324,14 @@ class FlashAttentionMetadataBuilder:
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build_slice(self, req_slice: slice,
|
||||
token_slice: slice,
|
||||
max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> FlashAttentionMetadata:
|
||||
def build_slice(
|
||||
self,
|
||||
req_slice: slice,
|
||||
token_slice: slice,
|
||||
max_query_len: int,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> FlashAttentionMetadata:
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
num_tokens = token_slice.stop - token_slice.start
|
||||
|
||||
@ -482,7 +484,7 @@ class FlashAttentionMetadataBuilder:
|
||||
)
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False #use_cascade_attention(*args, **kwargs)
|
||||
return False #use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
Reference in New Issue
Block a user