This commit is contained in:
Woosuk Kwon
2023-02-23 20:23:47 +00:00
parent 7f985166f7
commit fdd0f2f472
3 changed files with 4 additions and 3 deletions

View File

@ -13,7 +13,7 @@ class Sampler(nn.Module):
embedding: torch.Tensor,
) -> None:
super().__init__()
self.embedding = embedding.t() # [hidden_size, vocab_size]
self.embedding = embedding # [vocab_size, hidden_size]
def forward(
self,
@ -31,7 +31,7 @@ class Sampler(nn.Module):
hidden_states = hidden_states[last_token_indicies]
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, self.embedding)
logits = torch.matmul(hidden_states, self.embedding.t())
# Sample the next tokens.
# TODO(woosuk): Implement other sampling methods.