[fix]: remove data type hardcoding from gptoss model implementation (#23807)

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
This commit is contained in:
Nikhil Gupta
2025-09-18 19:15:23 +01:00
committed by GitHub
parent e19bce40a1
commit 064cac7bb7

View File

@ -76,7 +76,6 @@ class OAIAttention(nn.Module):
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=torch.bfloat16,
requires_grad=False))
self.q_size = self.num_attention_heads * self.head_dim // tp_size
@ -145,8 +144,7 @@ class MLPBlock(torch.nn.Module):
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size,
config.num_local_experts,
dtype=torch.bfloat16)
config.num_local_experts)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,