Fix per file ruff ignores related to simplification (#26259)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -101,40 +101,6 @@ include = ["vllm*"]
 "vllm/v1/engine/utils.py" = ["E501"]
 "vllm/v1/utils.py" = ["E501"]
 "vllm/v1/worker/gpu_model_runner.py" = ["E501"]
-## Simplification rules
-"tests/distributed/test_expert_placement.py" = ["SIM108"]
-"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
-"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
-"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
-"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
-"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
-"tests/kernels/test_onednn.py" = ["SIM108"]
-"tests/kernels/utils.py" = ["SIM108"]
-"tests/multimodal/test_processing.py" = ["SIM108"]
-"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
-"vllm/distributed/parallel_state.py" = ["SIM108"]
-"vllm/entrypoints/chat_utils.py" = ["SIM108"]
-"vllm/entrypoints/llm.py" = ["SIM108"]
-"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
-"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
-"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
-"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
-"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
-"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
-"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
-"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
-"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
-"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
-"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
-"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
-"vllm/utils/__init__.py" = ["SIM108"]
-"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
-"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
-"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
-"vllm/_custom_ops.py" = ["SIM108"]
-"tools/profiler/print_layerwise_table.py" = ["SIM118"]
-## Loop variable binding issues
-"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
 # End of temporary ignores

 [tool.ruff.lint]
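The rules dropped from the ignore list above all call for mechanical rewrites, which the hunks below apply file by file. For orientation, here is a minimal, self-contained Python sketch of what each rule asks for; this is illustrative, hypothetical code, not code taken from the vLLM sources touched by this commit.

# Illustrative examples of the ruff rules being enforced in this commit
# (hypothetical code, not from the diff below).

def pick_init_dtype(use_fp8: bool) -> str:
    # SIM108: collapse an if/else that only assigns one name into a ternary.
    return "bfloat16" if use_fp8 else "float32"


def is_supported(symmetric: bool, static: bool) -> bool:
    # SIM103: return the boolean expression directly instead of
    # `if not (...): return False` followed by `return True`.
    return symmetric and static


def matches_any(name: str, targets: list[str]) -> bool:
    # SIM110: replace a for-loop that returns True on the first match with any().
    return any(name == target for target in targets)


def prefill_columns(row: dict[str, float]) -> list[str]:
    # SIM118: iterate over the dict itself rather than dict.keys().
    return [key for key in row if "prefill" in key]


# SIM112: environment variables should be referenced with uppercase names;
# lookups of lowercase or mixed-case keys keep an explicit `# noqa: SIM112`,
# as the ray executor hunk below does for RAY_CGRAPH_get_timeout.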
@@ -12,10 +12,7 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts)
     base_experts = global_num_experts // ep_size
     remainder = global_num_experts % ep_size

-    if ep_rank < remainder:
-        local_num_experts = base_experts + 1
-    else:
-        local_num_experts = base_experts
+    local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts

     # Expected expert IDs for this rank in round_robin pattern
     # For non-divisible cases, ranks with extra experts start earlier
@@ -66,10 +66,7 @@ def test_cutlass_mla_decode(
     b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
 ):
     device = torch.device("cuda:0")
-    if torch_dtype == torch.float8_e4m3fn:
-        init_dtype = torch.bfloat16
-    else:
-        init_dtype = torch_dtype
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
     torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
@@ -52,10 +52,7 @@ def test_flash_mla(
     b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
 ):
     device = torch.device("cuda:0")
-    if torch_dtype == torch.float8_e4m3fn:
-        init_dtype = torch.bfloat16
-    else:
-        init_dtype = torch_dtype
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
     torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):

     # More efficient implementation
     # Convert decay factors to matrix form
-    if ed.dim() == 1:
-        decay = torch.exp(-ed).view(1, -1, 1, 1)
-    else:
-        decay = torch.exp(-ed)
+    decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)

     for b in range(B):
         for step in range(S):
@@ -705,10 +705,7 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

-        if shared_experts is not None:
-            shared_output = shared_experts(a)
-        else:
-            shared_output = None
+        shared_output = shared_experts(a) if shared_experts is not None else None

         torch_output = torch_experts(
             a,
@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
     # make scales K-major for blockwise quant, doesn't affect 1D scales
     scale_b = scale_b.t().contiguous().t()

-    if use_bias:
-        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
-    else:
-        bias = None
+    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None

     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
     scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
     scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

-    if use_bias:
-        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
-    else:
-        bias = None
+    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None

     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
@@ -84,10 +84,7 @@ def onednn_int8_gemm_test_helper(
     azp = None
     azp_adj = None

-    if use_bias:
-        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
-    else:
-        bias = None
+    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None

     handler = ops.create_onednn_scaled_mm(
         b,
@@ -963,13 +963,9 @@ def make_test_metadata(
         None if encoder_seq_lens is None else (sum(encoder_seq_lens))
     )

-    if cross_test_params is None:
-        cross_kv_mmap = None
-    else:
-        # Encoder/decoder or encoder-only models only:
-        # * Extract *cross-attention* slot_mapping and block table
-        # (kv_mmap)
-        cross_kv_mmap = cross_test_params.kv_mmap
+    # For encoder/decoder or encoder-only models only, extract *cross-attention*
+    # slot_mapping and block table (kv_mmap)
+    cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap

     attn_backend_obj = make_backend(attn_backend.name)

@@ -941,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):

     profiler = MultiModalProfiler(processor)

-    if is_valid:
-        exc_ctx = nullcontext()
-    else:
-        exc_ctx = pytest.raises(ValueError, match="At most")
+    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

     with exc_ctx:
         profiler.get_decoder_dummy_data(
@@ -985,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
     else:
         mm_data = {"image": [image] * num_images}

-    if is_valid:
-        exc_ctx = nullcontext()
-    else:
-        exc_ctx = pytest.raises(ValueError, match="At most")
+    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

     with exc_ctx:
         processor.apply(
@@ -58,7 +58,7 @@ if __name__ == "__main__":

     assert args.phase in profile_data, (
         f"Cannot find phase {args.phase} in profile data. Choose one among"
-        f"{[x for x in profile_data.keys() if 'prefill' in x or 'decode' in x]}"
+        f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}"
     )  # noqa

     if args.table == "summary":
@@ -2370,10 +2370,7 @@ class CPUDNNLGEMMHandler:
             torch.ops._C.release_dnnl_matmul_handler(self.handler)


-if hasattr(torch.ops._C, "create_onednn_mm_handler"):
-    _supports_onednn = True
-else:
-    _supports_onednn = False
+_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))


 def is_onednn_acl_supported():
@@ -52,12 +52,9 @@ def reshape_and_cache_kernel_flash(
         key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
     )
     if FP8_KV_CACHE:
-        if key_load.dtype.is_fp8():
-            key_tile = key_load
-        else:
-            # tl.store will do the correct implicit cast to fp8,
-            # based on the key_cache_ptr.dtype.element_ty
-            key_tile = key_load / tl.load(k_scale)
+        # tl.store will do the correct implicit cast to fp8,
+        # based on the key_cache_ptr.dtype.element_ty
+        key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
     else:
         key_tile = key_load

@@ -1097,10 +1097,7 @@ def init_distributed_environment(
     if local_rank == -1:
         # local rank not set, this usually happens in single-node
         # setting, where we can use rank as local rank
-        if distributed_init_method == "env://":
-            local_rank = envs.LOCAL_RANK
-        else:
-            local_rank = rank
+        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
     global _WORLD, _NODE_COUNT
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part(

     modality = None
     if part_type == "image_pil":
-        if content is not None:
-            image_content = cast(Image.Image, content)
-        else:
-            image_content = None
+        image_content = cast(Image.Image, content) if content is not None else None
         mm_parser.parse_image_pil(image_content, uuid)
         modality = "image"
     elif part_type in ("image_url", "input_image"):
@@ -1018,10 +1018,7 @@ class LLM:
             pooling_task = "encode"

         if pooling_task is None:
-            if "embed" in self.supported_tasks:
-                pooling_task = "embed"
-            else:
-                pooling_task = "encode"
+            pooling_task = "embed" if "embed" in self.supported_tasks else "encode"

         logger.warning_once(
             "`LLM.encode` is currently using `pooling_task = %s`.\n"
@@ -458,10 +458,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         else:
             serialized_data = self.input_encoder.encode(execute_model_req)
         outputs = ray.get(self.forward_dag.execute(serialized_data))
-        if self.use_v1:
-            output = outputs[0]
-        else:
-            output = self.output_decoder.decode(outputs[0])
+        output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
         return output

     def _run_workers(
@@ -482,10 +479,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
-        if isinstance(method, str):
-            sent_method = method
-        else:
-            sent_method = cloudpickle.dumps(method)
+        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
         del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
@@ -573,8 +567,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
         from ray.dag import InputNode, MultiOutputNode

         logger.info(
-            "RAY_CGRAPH_get_timeout is set to %s", os.environ["RAY_CGRAPH_get_timeout"]
-        )  # noqa: SIM112
+            "RAY_CGRAPH_get_timeout is set to %s",
+            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
+        )
         logger.info(
             "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
             envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
@@ -439,10 +439,7 @@ def mean_dim(
     output = torch.empty(output_shape, dtype=dtype, device=input.device)

     # Reshape output for kernel
-    if keepdim:
-        output_2d = output.reshape(M, 1, K).squeeze(1)
-    else:
-        output_2d = output.reshape(M, K)
+    output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)

     # Launch kernel
     grid = (M * K,)
@@ -151,10 +151,7 @@ def chunk_fwd_o(
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]
     H = v.shape[-2]
-    if FLA_GDN_FIX_BT:
-        BT = 64
-    else:
-        BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
     chunk_indices = (
         prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     )
@@ -1746,10 +1746,7 @@ def fused_experts_impl(
     else:
         raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

-    if inplace:
-        out_hidden_states = hidden_states
-    else:
-        out_hidden_states = torch.empty_like(hidden_states)
+    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

     if use_mxfp4_w4a4:
         # Weight has to be dequantized for mxfp4 emulation.
@@ -886,10 +886,7 @@ def determine_expert_map(
     # Distribute experts as evenly as possible to each rank.
     base_experts = global_num_experts // ep_size
     remainder = global_num_experts % ep_size
-    if ep_rank < remainder:
-        local_num_experts = base_experts + 1
-    else:
-        local_num_experts = base_experts
+    local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts

     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
@@ -948,10 +948,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         """

         a1 = hidden_states
-        if inplace and self.shared_experts is None:
-            output = a1
-        else:
-            output = torch.zeros_like(a1)
+        output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1)

         local_num_experts = w1.size(0)
         if global_num_experts == -1:
@@ -355,10 +355,7 @@ def rocm_aiter_fused_experts(
     topk_weights = topk_weights.to(torch.float32)
     topk_ids = topk_ids.to(torch.int32)

-    if expert_map is not None:
-        expert_mask = (expert_map > -1).to(torch.int32)
-    else:
-        expert_mask = None
+    expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None

     # w8a8 per-channel quantization
     if (
@@ -318,10 +318,7 @@ class GemmaRMSNorm(CustomOp):
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         if residual is not None:
-            if orig_dtype == torch.float16:
-                x = x + residual.float()
-            else:
-                x = x + residual
+            x = x + residual.float() if orig_dtype == torch.float16 else x + residual
             residual = x

         x = x.float()
@@ -207,10 +207,7 @@ def _fwd_kv_parallel(
     kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)

     # Handle the last block which might be smaller than BLOCK
-    if off_block == NUM_BLOCK - 1:
-        split_n = n - (NUM_BLOCK - 1) * BLOCK
-    else:
-        split_n = BLOCK
+    split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK
     left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
     num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
     k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
@@ -502,15 +502,11 @@ class CompressedTensorsConfig(QuantizationConfig):
             QuantizationStrategy.CHANNEL,
             QuantizationStrategy.BLOCK,
         ]
-        if not (
-            is_symmetric_weight
-            and is_static_weight  # noqa: SIM103
-            and is_tensor_or_channel_or_block_weight
-        ):
-            return False
-
-        # All conditions satisfied.
-        return True
+        return (
+            is_symmetric_weight
+            and is_static_weight
+            and is_tensor_or_channel_or_block_weight
+        )

     def _is_wNa16_group_channel(
         self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
@@ -80,10 +80,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
     Checks whether a layer_name is exactly equal or a regex match for
     if target starts with 're:' to any target in list.
     """
-    for target in targets:
-        if _is_equal_or_regex_match(layer_name, target):
-            return True
-    return False
+    return any(_is_equal_or_regex_match(layer_name, target) for target in targets)


 def find_matched_target(
@@ -81,10 +81,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
     Checks whether a layer_name is exactly equal or a regex match for
     if target starts with 're:' to any target in list.
     """
-    for target in targets:
-        if _is_equal_or_regex_match(layer_name, target):
-            return True
-    return False
+    return any(_is_equal_or_regex_match(layer_name, target) for target in targets)


 def _is_equal_or_regex_match(
@@ -3052,10 +3052,7 @@ def make_zmq_socket(
     # - Set a large 0.5GB buffer to improve throughput
     # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
-    if total_mem > 32 and available_mem > 16:
-        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
-    else:
-        buf_size = -1  # Use system default buffer size
+    buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1

     if bind is None:
         bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
@@ -17,10 +17,7 @@ def _apply_bad_words_single_batch(

         prefix_length = len(bad_word_ids) - 1
         last_token_id = bad_word_ids[-1]
-        if prefix_length > 0:
-            actual_prefix = past_tokens_ids[-prefix_length:]
-        else:
-            actual_prefix = []
+        actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
         expected_prefix = bad_word_ids[:prefix_length]

         assert len(actual_prefix) == len(expected_prefix)
@@ -444,18 +444,12 @@ def rejection_greedy_sample_kernel(
     req_idx = tl.program_id(0)
     # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
     # re-compilation may happen during runtime when is_greedy_ptr is None.
-    if is_greedy_ptr is None:
-        is_greedy = True
-    else:
-        is_greedy = tl.load(is_greedy_ptr + req_idx)
+    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
     if not is_greedy:
         # Early exit for non-greedy sampling requests.
         return

-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
     end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
     num_draft_tokens = end_idx - start_idx

@@ -503,10 +497,7 @@ def rejection_random_sample_kernel(
         # Early exit for greedy sampling requests.
         return

-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
     end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
     num_draft_tokens = end_idx - start_idx

@@ -583,10 +574,7 @@ def sample_recovered_tokens_kernel(
     NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
     end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
     num_draft_tokens = end_idx - start_idx

@@ -507,12 +507,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         removed_req_indices = sorted(removed_req_indices, reverse=True)
         for req_id in req_ids_to_add:
             req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
+            # Fill the empty index or append to the end
+            req_index = removed_req_indices.pop() if removed_req_indices else None
             self.input_batch.add_request(req_state, req_index)

         # Condense the batched states if there are empty indices.