[Misc] Respect no_use_tqdm_on_load flag while capturing CUDA graph (#20834)
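
Both call sites below wrap the CUDA graph capture loop in a tqdm progress bar on rank 0; the fix threads the existing load_config.use_tqdm_on_load setting into tqdm's disable argument so the bar also stays silent when the user has opted out of progress bars. A minimal, runnable sketch of that pattern, with placeholder batch sizes and a hard-coded flag standing in for load_config.use_tqdm_on_load (illustration only, not vLLM code):

    from tqdm import tqdm

    cudagraph_batch_sizes = [1, 2, 4, 8]   # placeholder capture shapes
    use_tqdm_on_load = False               # stands in for load_config.use_tqdm_on_load

    compilation_cases = tqdm(
        list(reversed(cudagraph_batch_sizes)),
        # disable=True turns tqdm into a silent pass-through iterator
        disable=not use_tqdm_on_load,
        desc="Capturing CUDA graph shapes")

    for num_tokens in compilation_cases:
        pass  # graph capture for each shape would happen here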

Signed-off-by: Linkun <github@lkchen.net>
Author: lkchen
Date: 2025-07-11 23:04:45 -07:00
Committed by: GitHub
Parent: 147afb448b
Commit: f56d2996ca
2 changed files with 5 additions and 2 deletions


@@ -2270,8 +2270,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Only rank 0 should print progress bar during capture
         compilation_cases = reversed(self.cudagraph_batch_sizes)
         if is_global_first_rank():
-            compilation_cases = tqdm(list(compilation_cases),
-                                      desc="Capturing CUDA graph shapes")
+            compilation_cases = tqdm(
+                list(compilation_cases),
+                disable=not self.load_config.use_tqdm_on_load,
+                desc="Capturing CUDA graph shapes")
         for num_tokens in compilation_cases:
             # We skip EPLB here since we don't want to record dummy metrics
             for _ in range(


@@ -1587,6 +1587,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         if get_tensor_model_parallel_rank() == 0:
             compilation_cases = tqdm(
                 list(compilation_cases),
+                disable=not self.load_config.use_tqdm_on_load,
                 desc="Capturing CUDA graph shapes")
         for batch_size, use_inputs_embeds in compilation_cases:
             attn_metadata = (
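
For reference, a hedged user-side sketch of turning the flag off: the field name use_tqdm_on_load is taken directly from the diff's self.load_config, but the vllm.config.LoadConfig import path and constructor signature are assumptions about the surrounding codebase, not something this commit shows.

    # Assumed API: construct a LoadConfig with progress bars disabled,
    # which now also silences the CUDA graph capture bar patched above.
    from vllm.config import LoadConfig

    load_config = LoadConfig(use_tqdm_on_load=False)
    assert load_config.use_tqdm_on_load is False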