[Misc] Respect no_use_tqdm_on_load flag while capturing CUDA graph (#20834)
Signed-off-by: Linkun <github@lkchen.net>
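
Example of the gating pattern this change applies (a minimal, self-contained sketch; the helper name, batch sizes, and rank check are illustrative placeholders, while `use_tqdm_on_load` mirrors the `LoadConfig.use_tqdm_on_load` field referenced in the diff below):

    from tqdm import tqdm

    def iter_capture_cases(batch_sizes, use_tqdm_on_load, is_first_rank):
        # Iterate capture shapes in reverse, as the runner does.
        cases = reversed(batch_sizes)
        if is_first_rank:
            # Only the first rank shows a progress bar, and only when the
            # load-time tqdm flag allows it.
            cases = tqdm(
                list(cases),
                disable=not use_tqdm_on_load,
                desc="Capturing CUDA graph shapes")
        return cases

    for num_tokens in iter_capture_cases([1, 2, 4, 8],
                                          use_tqdm_on_load=False,
                                          is_first_rank=True):
        pass  # capture a CUDA graph for `num_tokens` here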
@@ -2270,8 +2270,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
             if is_global_first_rank():
-                compilation_cases = tqdm(list(compilation_cases),
-                                         desc="Capturing CUDA graph shapes")
+                compilation_cases = tqdm(
+                    list(compilation_cases),
+                    disable=not self.load_config.use_tqdm_on_load,
+                    desc="Capturing CUDA graph shapes")
             for num_tokens in compilation_cases:
                 # We skip EPLB here since we don't want to record dummy metrics
                 for _ in range(
@@ -1587,6 +1587,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             if get_tensor_model_parallel_rank() == 0:
                 compilation_cases = tqdm(
                     list(compilation_cases),
+                    disable=not self.load_config.use_tqdm_on_load,
                     desc="Capturing CUDA graph shapes")
             for batch_size, use_inputs_embeds in compilation_cases:
                 attn_metadata = (