@ -402,6 +402,11 @@ class GPUModelRunner:
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
needs_prompt_logprobs = np.any(
|
||||
self.req_states.needs_prompt_logprobs[idx_mapping_np])
|
||||
if needs_prompt_logprobs:
|
||||
pass
|
||||
|
||||
if use_dp_sampler:
|
||||
# All-gather the outputs.
|
||||
sampler_output = all_gather_sampler_output(
|
||||
|
||||
Reference in New Issue
Block a user