Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -88,7 +88,7 @@ class MiniMaxM2MoE(nn.Module):
|
||||
self.use_routing_bias = getattr(config, "use_routing_bias", False)
|
||||
if self.use_routing_bias:
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.num_local_experts, dtype=torch.float32)
|
||||
torch.empty(config.num_local_experts, dtype=torch.get_default_dtype())
|
||||
)
|
||||
self.e_score_correction_bias.weight_loader = (
|
||||
MiniMaxM2MoE.ebias_weight_loader
|
||||
@@ -107,13 +107,14 @@ class MiniMaxM2MoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
router_logits_dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=False,
|
||||
# params_dtype=torch.float32,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
@@ -121,7 +122,6 @@ class MiniMaxM2MoE(nn.Module):
|
||||
@staticmethod
|
||||
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
# param.data.copy_(loaded_weight.to(torch.float32))
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -129,10 +129,9 @@ class MiniMaxM2MoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
# router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
hidden_states=hidden_states, router_logits=router_logits.to(hidden_states.dtype)
|
||||
)
|
||||
final_hidden_states = final_hidden_states
|
||||
if self.tp_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user