Use xla flag to improve the quantized model performance (#19303)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
@ -101,7 +101,10 @@ class TPUWorker:
|
||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||
os.environ.get("LIBTPU_INIT_ARGS", "") +
|
||||
" --xla_tpu_force_1d_allreduce_at_chunk_count=1")
|
||||
" --xla_tpu_force_1d_allreduce_at_chunk_count=1"
|
||||
" --xla_jf_conv_input_fusion=False")
|
||||
# --xla_jf_conv_input_fusion=False is used to improve the perf of
|
||||
# quantized matmul.
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user