diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py index 8124bd6fd2..cdd562afcd 100644 --- a/vllm/model_executor/models/jax/ops/write_to_cache.py +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -1,3 +1,5 @@ +from typing import Tuple + import chex import jax import jax.numpy as jnp