Compare commits
1 Commits
use-uv-pyt
...
memory-lea
| Author | SHA1 | Date | |
|---|---|---|---|
| 2686925630 |
@ -418,35 +418,39 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
print("hi --- wait_for_kv_layer_from_connector")
|
||||||
return
|
pass
|
||||||
|
# if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||||
|
# return
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
# connector = get_kv_transfer_group()
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
# forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
# attn_metadata = forward_context.attn_metadata
|
||||||
if attn_metadata is None:
|
# if attn_metadata is None:
|
||||||
return
|
# return
|
||||||
assert isinstance(attn_metadata, dict)
|
# assert isinstance(attn_metadata, dict)
|
||||||
connector.wait_for_layer_load(layer_name)
|
# connector.wait_for_layer_load(layer_name)
|
||||||
|
|
||||||
|
|
||||||
def maybe_save_kv_layer_to_connector(
|
def maybe_save_kv_layer_to_connector(
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
kv_cache_layer: List[torch.Tensor],
|
kv_cache_layer: List[torch.Tensor],
|
||||||
):
|
):
|
||||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
print("hi --- maybe_save_kv_layer_to_connector")
|
||||||
return
|
pass
|
||||||
|
# if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||||
|
# return
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
# connector = get_kv_transfer_group()
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
# forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
# attn_metadata = forward_context.attn_metadata
|
||||||
if attn_metadata is None:
|
# if attn_metadata is None:
|
||||||
return
|
# return
|
||||||
assert isinstance(attn_metadata, dict)
|
# assert isinstance(attn_metadata, dict)
|
||||||
connector.save_kv_layer(layer_name, kv_cache_layer,
|
# connector.save_kv_layer(layer_name, kv_cache_layer,
|
||||||
attn_metadata[layer_name])
|
# attn_metadata[layer_name])
|
||||||
|
|
||||||
|
|
||||||
def unified_attention(
|
def unified_attention(
|
||||||
@ -497,7 +501,7 @@ def unified_attention_with_output(
|
|||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
output_block_scale: Optional[torch.Tensor] = None,
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
wait_for_kv_layer_from_connector(layer_name)
|
# wait_for_kv_layer_from_connector(layer_name)
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
if isinstance(attn_metadata, dict):
|
if isinstance(attn_metadata, dict):
|
||||||
@ -514,7 +518,7 @@ def unified_attention_with_output(
|
|||||||
output_scale=output_scale,
|
output_scale=output_scale,
|
||||||
output_block_scale=output_block_scale)
|
output_block_scale=output_block_scale)
|
||||||
|
|
||||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
# maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||||
|
|
||||||
|
|
||||||
def unified_attention_with_output_fake(
|
def unified_attention_with_output_fake(
|
||||||
|
|||||||
Reference in New Issue
Block a user