[Misc] Add TPU usage report when using tpu_inference. (#27423)

Signed-off-by: Hongmin Fan <fanhongmin@google.com>
Author:    hfan
Date:      2025-10-23 23:29:37 -04:00
Committed: GitHub
Parent:    5cc6bddb6e
Commit:    8dbe0c527f


@@ -176,6 +176,32 @@ class UsageMessage:
         self._report_usage_once(model_architecture, usage_context, extra_kvs)
         self._report_continuous_usage()
 
+    def _report_tpu_inference_usage(self) -> bool:
+        try:
+            from tpu_inference import tpu_info, utils
+
+            self.gpu_count = tpu_info.get_num_chips()
+            self.gpu_type = tpu_info.get_tpu_type()
+            self.gpu_memory_per_device = utils.get_device_hbm_limit()
+            self.cuda_runtime = "tpu_inference"
+            return True
+        except Exception:
+            return False
+
+    def _report_torch_xla_usage(self) -> bool:
+        try:
+            import torch_xla
+
+            self.gpu_count = torch_xla.runtime.world_size()
+            self.gpu_type = torch_xla.tpu.get_tpu_type()
+            self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
+                "bytes_limit"
+            ]
+            self.cuda_runtime = "torch_xla"
+            return True
+        except Exception:
+            return False
+
     def _report_usage_once(
         self,
         model_architecture: str,
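Note: both new probes fail closed; on a host without the corresponding package they return False rather than raise. A minimal sketch of exercising that contract without TPU hardware follows (the import path vllm.usage.usage_lib and the no-argument UsageMessage() constructor are assumptions based on the surrounding file, not something this diff adds):

# Hedged sketch: constructing UsageMessage() with no arguments is an assumption;
# the gpu_* and cuda_runtime attributes are the ones set by the new probes above.
from vllm.usage.usage_lib import UsageMessage

msg = UsageMessage()

# Each probe returns a bool and should never raise, even when tpu_inference
# or torch_xla is not installed on this machine.
collected = msg._report_tpu_inference_usage() or msg._report_torch_xla_usage()
print("TPU info collected:", collected)
if collected:
    print(msg.gpu_count, msg.gpu_type, msg.gpu_memory_per_device, msg.cuda_runtime)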
@@ -192,16 +218,10 @@ class UsageMessage:
             )
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
-        if current_platform.is_tpu():
-            try:
-                import torch_xla
-
-                self.gpu_count = torch_xla.runtime.world_size()
-                self.gpu_type = torch_xla.tpu.get_tpu_type()
-                self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
-                    "bytes_limit"
-                ]
-            except Exception:
+        if current_platform.is_tpu():  # noqa: SIM102
+            if (not self._report_tpu_inference_usage()) and (
+                not self._report_torch_xla_usage()
+            ):
                 logger.exception("Failed to collect TPU information")
         self.provider = _detect_cloud_provider()
         self.architecture = platform.machine()
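
Note: the control flow of the new dispatch can be reproduced in a standalone sketch that runs on any machine; the helper names probe_tpu_inference and probe_torch_xla are illustrative, not vLLM APIs, and a plain warning stands in for the logger.exception call used in the diff. Short-circuit ordering means torch_xla is only consulted when tpu_inference is unavailable:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tpu_usage_sketch")

info: dict = {}


def probe_tpu_inference() -> bool:
    # Preferred backend introduced by this commit: the tpu_inference package.
    try:
        from tpu_inference import tpu_info, utils

        info.update(
            gpu_count=tpu_info.get_num_chips(),
            gpu_type=tpu_info.get_tpu_type(),
            gpu_memory_per_device=utils.get_device_hbm_limit(),
            cuda_runtime="tpu_inference",
        )
        return True
    except Exception:
        return False


def probe_torch_xla() -> bool:
    # Fallback backend, matching the pre-existing torch_xla code path.
    try:
        import torch_xla

        info.update(
            gpu_count=torch_xla.runtime.world_size(),
            gpu_type=torch_xla.tpu.get_tpu_type(),
            gpu_memory_per_device=torch_xla.core.xla_model.get_memory_info()[
                "bytes_limit"
            ],
            cuda_runtime="torch_xla",
        )
        return True
    except Exception:
        return False


# Try tpu_inference first; fall back to torch_xla; warn only if both fail.
if not (probe_tpu_inference() or probe_torch_xla()):
    logger.warning("Failed to collect TPU information")
else:
    logger.info("Collected TPU info: %s", info)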