Compare commits

...

1 Commit

Author        SHA1         Message    Date
Robert Shaw   2686925630   updated    2025-09-08 21:06:29 +00:00
                           Signed-off-by: Robert Shaw <robshaw@redhat.com>


@@ -418,35 +418,39 @@ class MultiHeadAttention(nn.Module):

 def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-    connector = get_kv_transfer_group()
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.wait_for_layer_load(layer_name)
+    print("hi --- wait_for_kv_layer_from_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
+    # connector = get_kv_transfer_group()
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.wait_for_layer_load(layer_name)


 def maybe_save_kv_layer_to_connector(
     layer_name: str,
     kv_cache_layer: List[torch.Tensor],
 ):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-    connector = get_kv_transfer_group()
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.save_kv_layer(layer_name, kv_cache_layer,
-                            attn_metadata[layer_name])
+    print("hi --- maybe_save_kv_layer_to_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
+    # connector = get_kv_transfer_group()
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.save_kv_layer(layer_name, kv_cache_layer,
+    #                         attn_metadata[layer_name])


 def unified_attention(
@@ -497,7 +501,7 @@ def unified_attention_with_output(
     output_scale: Optional[torch.Tensor] = None,
     output_block_scale: Optional[torch.Tensor] = None,
 ) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
+    # wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     if isinstance(attn_metadata, dict):
@@ -514,7 +518,7 @@ def unified_attention_with_output(
                       output_scale=output_scale,
                       output_block_scale=output_block_scale)
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    # maybe_save_kv_layer_to_connector(layer_name, kv_cache)


 def unified_attention_with_output_fake(
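For context, the two hooks stubbed out above normally bracket each layer's attention call: wait_for_kv_layer_from_connector blocks until the KV transfer connector has finished loading that layer's KV cache, and maybe_save_kv_layer_to_connector hands the freshly written cache back to the connector via the per-layer attn_metadata lookup. Below is a small, self-contained sketch of that guard-and-lookup pattern; FakeConnector and FakeForwardContext are hypothetical stand-ins for the vLLM globals (has_kv_transfer_group, get_kv_transfer_group, get_forward_context), not part of vLLM.

# Illustrative sketch only -- not vLLM's implementation.
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


class FakeConnector:
    """Hypothetical stand-in for the KV transfer connector."""

    def wait_for_layer_load(self, layer_name: str) -> None:
        print(f"KV load finished for {layer_name}")

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: List[torch.Tensor],
                      layer_metadata: object) -> None:
        print(f"Saving KV for {layer_name} ({len(kv_cache_layer)} tensors)")


@dataclass
class FakeForwardContext:
    # Per-layer metadata dict keyed by layer name, mirroring the
    # attn_metadata[layer_name] lookup in the original hook.
    attn_metadata: Optional[Dict[str, object]] = None


_connector: Optional[FakeConnector] = FakeConnector()
_ctx = FakeForwardContext(attn_metadata={"model.layers.0.self_attn": {}})


def wait_for_kv_layer_from_connector(layer_name: str) -> None:
    if _connector is None:          # mirrors the has_kv_transfer_group() guard
        return
    if _ctx.attn_metadata is None:  # e.g. a profiling run with no metadata
        return
    _connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(layer_name: str,
                                     kv_cache_layer: List[torch.Tensor]) -> None:
    if _connector is None or _ctx.attn_metadata is None:
        return
    _connector.save_kv_layer(layer_name, kv_cache_layer,
                             _ctx.attn_metadata[layer_name])


if __name__ == "__main__":
    kv = [torch.zeros(2, 4)]
    wait_for_kv_layer_from_connector("model.layers.0.self_attn")
    # ... the attention kernel for this layer would run here ...
    maybe_save_kv_layer_to_connector("model.layers.0.self_attn", kv)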