[CI/Build] Fix VLM test failures when using transformers v4.46 (#9666)
This commit is contained in:
@ -232,20 +232,22 @@ def video_assets() -> _VideoAssets:
|
||||
return VIDEO_ASSETS
|
||||
|
||||
|
||||
# Types that HfRunner.wrap_device knows how to move between devices.
# ``dict`` is included because newer transformers processors (per the
# commit title, v4.46) can return plain dicts of tensors that must be
# moved recursively — TODO confirm against the transformers changelog.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
    """Move *x* to *device*, returning it unchanged if already there.

    Args:
        x: A module, tensor, tokenizer/processor output, or a plain dict
            of such values (dicts are converted recursively, value by
            value, with keys preserved).
        device: Target device type string. When ``None``, resolves to
            ``"cpu"`` on CPU-only platforms and ``"cuda"`` otherwise.

    Returns:
        The input moved to the target device (same type as *x*; for a
        dict input, a new dict with each value moved).
    """
    if device is None:
        # Resolve the default once up front so the recursive dict case
        # below passes a concrete device string rather than re-resolving.
        device = "cpu" if current_platform.is_cpu() else "cuda"

    if isinstance(x, dict):
        # Newer transformers processors may return plain dicts of
        # tensors; move each value recursively.
        return {k: self.wrap_device(v, device) for k, v in x.items()}

    if hasattr(x, "device") and x.device.type == device:
        # Already on the requested device; skip the redundant copy.
        return x

    return x.to(device)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user