llama4_vision_rope: add HIP override to accept (q, k) and avoid (positions, q, k) mismatch (#26790)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
@ -78,3 +78,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(query, key)
|
||||
|
||||
def forward_hip( # type: ignore[override]
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(query, key)
|
||||
|
||||
Reference in New Issue
Block a user