Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -33,7 +32,7 @@ class MockLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
|
||||
Reference in New Issue
Block a user