@ -17,10 +17,10 @@ class Sampler(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = LogprobsMode.PROCESSED_LOGPROBS,
|
||||
logprobs_mode: LogprobsMode = "processed_logprobs",
|
||||
):
|
||||
super().__init__()
|
||||
assert logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS
|
||||
assert logprobs_mode == "processed_logprobs"
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user