[CI/Build] Fix VLM test failures when using transformers v4.46 (#9666)

This commit is contained in:
Cyrus Leung
2024-10-25 01:40:40 +08:00
committed by GitHub
parent d27cfbf791
commit c866e0079d
4 changed files with 28 additions and 12 deletions

View File

@@ -232,20 +232,22 @@ def video_assets() -> _VideoAssets:
     return VIDEO_ASSETS
 
 
-_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
 
 
 class HfRunner:
 
-    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
+    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if device is None:
-            return self.wrap_device(
-                input, "cpu" if current_platform.is_cpu() else "cuda")
+            device = "cpu" if current_platform.is_cpu() else "cuda"
 
-        if hasattr(input, "device") and input.device.type == device:
-            return input
+        if isinstance(x, dict):
+            return {k: self.wrap_device(v, device) for k, v in x.items()}
 
-        return input.to(device)
+        if hasattr(x, "device") and x.device.type == device:
+            return x
+
+        return x.to(device)
 
     def __init__(
         self,