Upgrade to vLLM 0.17.0 (Corex v4.1 overlay)

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -110,7 +110,12 @@ class Glm4MoeMLP(nn.Module):
def forward(self, x):
    """MLP forward: gate/up projection, activation, then down projection.

    When the down projection is quantized, its weight's input dimension may
    be padded/aligned to a multiple required by the quantized kernel, so the
    activation output can be narrower than the weight expects. In that case
    zero-pad the activation on the last dim to match before projecting.
    (Zero padding is safe: the padded columns multiply weight rows that
    correspond to padding and contribute nothing meaningful.)

    NOTE(review): reconstructed from a diff rendering that had lost its
    +/- markers — the stale pre-change `x, _ = self.down_proj(x)` line
    (which would have invoked down_proj twice) is removed here.
    """
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    # Only quantized linear methods align/pad the weight's input dim; the
    # unquantized path never needs padding.
    if (self.down_proj.quant_method.__class__.__name__ != "UnquantizedLinearMethod"
            and x.shape[-1] != self.down_proj.weight.shape[0]):
        padding = self.down_proj.weight.shape[0] - x.shape[-1]
        x_align = torch.nn.functional.pad(x, (0, padding), mode='constant', value=0)
    else:
        x_align = x
    x, _ = self.down_proj(x_align)
    return x
@@ -144,11 +149,10 @@ class Glm4MoE(nn.Module):
config.hidden_size,
config.n_routed_experts,
bias=False,
# dtype=torch.float32,
dtype=torch.bfloat16,
)
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts)
)
torch.empty(config.n_routed_experts, dtype=torch.bfloat16))
# Load balancing settings.
vllm_config = get_current_vllm_config()
@@ -205,8 +209,7 @@ class Glm4MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
# router_logits = self.gate(hidden_states.to(dtype=torch.float32))
router_logits = self.gate(hidden_states)
router_logits = self.gate(hidden_states.to(dtype=torch.bfloat16))
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
@@ -312,6 +315,9 @@ class Glm4MoeAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
q.shape