Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -69,7 +69,7 @@ class FP32ReplicatedLinear(ReplicatedLinear):
x: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
assert self.params_dtype == torch.float32
return super().forward(x.to(torch.float32))
return super().forward(x)
class Step3p5MLP(nn.Module):
@@ -148,6 +148,7 @@ class Step3p5Attention(nn.Module):
yarn_only_types: list = None,
swa_num_attention_heads: int | None = None,
partial_rotary_factor: float = 1.0,
total_layer_num: int = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -250,7 +251,12 @@ class Step3p5Attention(nn.Module):
prefix=f"{prefix}.attn",
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
)
extra_cache_para={
"total_num_kv_heads":self.total_num_kv_heads,
"total_layer_num":total_layer_num,
"layer_id":self.layer_idx,
},
)
self.max_position_embeddings = max_position
assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
@@ -390,6 +396,7 @@ class FusedMoEBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
fused_shared_output=True,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -413,8 +420,8 @@ class FusedMoEBlock(nn.Module):
assert shared_output is None
if self.share_expert is not None:
assert shared_output is not None
final_hidden_states += shared_output
if shared_output is not None:
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
@@ -484,6 +491,7 @@ class Step3p5DecoderLayer(nn.Module):
if partial_rotary_factors
else 1.0,
prefix=f"{prefix}.self_attn",
total_layer_num=vllm_config.model_config.hf_config.num_hidden_layers,
)
else:
raise ValueError(
@@ -535,25 +543,28 @@ class Step3p5DecoderLayer(nn.Module):
return self.tp_group.all_reduce(in1 + in2)
def forward(
self, positions: torch.Tensor, hidden_states: torch.Tensor
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states += residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.use_moe:
ffn_output = self.moe(hidden_states)
else:
ffn_output = self.mlp(hidden_states)
hidden_states = ffn_output + residual
return hidden_states
return ffn_output, residual
@support_torch_compile
@@ -592,7 +603,7 @@ class Step3p5Model(nn.Module):
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
["hidden_states", "residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -610,20 +621,26 @@ class Step3p5Model(nn.Module):
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
"hidden_states": hidden_states,
"residual": residual
}
)
hidden_states += residual
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -891,7 +908,6 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Maps hidden states to logits by passing them, together with self.lm_head,
# through self.logits_processor.
# NOTE(review): the enclosing diff hunk shrinks by one line, and the
# self.model.norm call directly below looks like the line being dropped.
# Confirm the final norm is still applied somewhere before logits are
# computed; it is not visible in this view.
hidden_states = self.model.norm(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
@@ -934,7 +950,26 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights and report parameters that were not loaded.

    Args:
        weights: ``(name, tensor)`` pairs, forwarded to ``AutoWeightsLoader``
            together with ``self.hf_to_vllm_mapper``.

    Returns:
        The set of parameter names the loader reports as loaded.
    """
    loader = AutoWeightsLoader(self)
    # Load exactly once and keep the result; an early `return` here would
    # make the diagnostics below unreachable.
    loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    # Diagnostics only: nothing below changes what is returned.
    missing = {name for name, _ in self.named_parameters()} - loaded
    # Quantisation scale / zero-point parameters are excluded from the
    # "missing" warning and only counted in the info message.
    kv_cache_params = {
        name
        for name in missing
        if name.endswith(
            (
                ".k_scale", ".k_zero_point",
                ".q_scale", ".q_zero_point",
                ".v_scale", ".v_zero_point",
            )
        )
    }
    real_missing = missing - kv_cache_params

    if real_missing:
        logger.warning(
            "[DIAG-WEIGHTS] Missing %d params: %s",
            len(real_missing), sorted(real_missing),
        )
    else:
        logger.info(
            "[DIAG-WEIGHTS] All params loaded. (%d kv-cache params skipped)",
            len(kv_cache_params),
        )
    return loaded
def get_spec_layer_idx_from_weight_name(