Upgrade to vllm 0.17.0 (corex v4.1 overlay)
@@ -69,7 +69,7 @@ class FP32ReplicatedLinear(ReplicatedLinear):
         x: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
         assert self.params_dtype == torch.float32
-        return super().forward(x.to(torch.float32))
+        return super().forward(x)


 class Step3p5MLP(nn.Module):
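The dropped cast is safe only because the assert above pins the parameter dtype and callers already feed this layer float32 activations; note the assert guards params_dtype, not the input, so after this change a bf16 input would fail inside the matmul rather than be silently up-cast. A minimal stand-alone sketch of the pattern (hypothetical class, not the vLLM one):

    import torch
    import torch.nn as nn

    class FP32Linear(nn.Linear):
        """Sketch of a linear layer pinned to float32 parameters."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Guard the parameter dtype only. The input is no longer
            # up-cast, so the caller must already supply float32.
            assert self.weight.dtype == torch.float32
            return super().forward(x)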
@@ -148,6 +148,7 @@ class Step3p5Attention(nn.Module):
         yarn_only_types: list = None,
         swa_num_attention_heads: int | None = None,
         partial_rotary_factor: float = 1.0,
+        total_layer_num: int = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -250,7 +251,12 @@ class Step3p5Attention(nn.Module):
             prefix=f"{prefix}.attn",
             per_layer_sliding_window=sliding_window,
             attn_type=attn_type,
-        )
+            extra_cache_para={
+                "total_num_kv_heads": self.total_num_kv_heads,
+                "total_layer_num": total_layer_num,
+                "layer_id": self.layer_idx,
+            },
+        )

         self.max_position_embeddings = max_position
         assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
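extra_cache_para threads per-layer KV-cache metadata (KV-head count, total layer count, layer index) into the attention layer's cache plumbing. The consumer is not shown in this diff; a hypothetical sketch of how such metadata might be validated on the receiving side (names assumed):

    from dataclasses import dataclass

    @dataclass
    class LayerCacheInfo:
        """Hypothetical holder for the extra_cache_para keys above."""
        total_num_kv_heads: int
        total_layer_num: int
        layer_id: int

    def parse_extra_cache_para(extra_cache_para: dict) -> LayerCacheInfo:
        info = LayerCacheInfo(**extra_cache_para)
        # The layer index must be addressable within the shared cache.
        assert 0 <= info.layer_id < info.total_layer_num
        return info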
@@ -390,6 +396,7 @@ class FusedMoEBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             router_logits_dtype=torch.float32,
+            fused_shared_output=True,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -413,8 +420,8 @@ class FusedMoEBlock(nn.Module):
             assert shared_output is None

-        if self.share_expert is not None:
-            assert shared_output is not None
-            final_hidden_states += shared_output
+        if shared_output is not None:
+            final_hidden_states += shared_output

         if self.tp_size > 1:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
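With fused_shared_output=True passed to the experts module in the previous hunk, the shared-expert contribution now arrives through the same path as the routed-expert output, so the old self.share_expert branch collapses into a plain None check: add shared_output when the kernel returns it separately, skip it when it is already fused in. A condensed stand-alone sketch of that combining step (names assumed):

    import torch

    def combine_expert_outputs(
        final_hidden_states: torch.Tensor,
        shared_output: torch.Tensor | None,
    ) -> torch.Tensor:
        # None means the fused kernel already folded the shared expert in.
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        return final_hidden_states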
@@ -484,6 +491,7 @@ class Step3p5DecoderLayer(nn.Module):
                 if partial_rotary_factors
                 else 1.0,
                 prefix=f"{prefix}.self_attn",
+                total_layer_num=vllm_config.model_config.hf_config.num_hidden_layers,
             )
         else:
             raise ValueError(
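total_layer_num is simply threaded from hf_config.num_hidden_layers through the decoder layer into Step3p5Attention, where it lands in extra_cache_para. (Strictly, the new parameter's annotation should be int | None given its None default.) A minimal sketch of the threading, with hypothetical simplified classes:

    class AttentionSketch:
        """Hypothetical stand-in for Step3p5Attention's constructor wiring."""

        def __init__(self, layer_idx: int, total_layer_num: int | None = None):
            self.layer_idx = layer_idx
            self.total_layer_num = total_layer_num

    class DecoderLayerSketch:
        def __init__(self, layer_idx: int, num_hidden_layers: int):
            # Every layer learns the model depth so per-layer caches can
            # be sized and indexed consistently.
            self.self_attn = AttentionSketch(
                layer_idx, total_layer_num=num_hidden_layers)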
@@ -535,25 +543,28 @@ class Step3p5DecoderLayer(nn.Module):
         return self.tp_group.all_reduce(in1 + in2)

     def forward(
-        self, positions: torch.Tensor, hidden_states: torch.Tensor
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
     ) -> torch.Tensor:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states.clone()
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
         )
-        hidden_states += residual
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

         if self.use_moe:
             ffn_output = self.moe(hidden_states)
         else:
             ffn_output = self.mlp(hidden_states)
-        hidden_states = ffn_output + residual
-        return hidden_states
+        return ffn_output, residual


 @support_torch_compile
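The rewritten forward adopts vLLM's usual residual-passing convention: called with one argument the norm just normalizes, called with (x, residual) it first adds the residual, then normalizes, and returns both the normed activations and the updated residual. That lets the residual add be fused into the norm kernel and lets each layer hand (hidden_states, residual) to the next. A minimal pure-PyTorch sketch of that dual-signature RMSNorm (the fused vLLM kernel is assumed to behave equivalently):

    import torch
    import torch.nn as nn

    class RMSNormWithResidual(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def _norm(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.pow(2).mean(-1, keepdim=True)
            return x * torch.rsqrt(variance + self.eps) * self.weight

        def forward(self, x, residual=None):
            if residual is None:
                return self._norm(x)
            # Fused path: the residual add happens here, and the post-add
            # tensor is returned as the new running residual.
            residual = x + residual
            return self._norm(residual), residual

Note also that the layer now returns (ffn_output, residual) without performing the final add itself; the single hidden_states += residual happens once, at the end of the model's forward in a later hunk.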
@@ -592,7 +603,7 @@ class Step3p5Model(nn.Module):
             self.norm = PPMissingLayer()

         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
-            ["hidden_states"], config.hidden_size
+            ["hidden_states", "residual"], config.hidden_size
         )

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -610,20 +621,26 @@ class Step3p5Model(nn.Module):
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.embed_input_ids(input_ids)
+            residual = None
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states = layer(positions, hidden_states)
+            hidden_states, residual = layer(positions, hidden_states, residual)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {
                     "hidden_states": hidden_states,
+                    "residual": residual
                 }
             )

+        hidden_states += residual
+
         return hidden_states

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
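Because each layer now returns (hidden_states, residual) and the two are only summed after the last layer, the residual has to cross pipeline-parallel boundaries alongside the hidden states; that is why "residual" joins the IntermediateTensors keys here and in the factory in the previous hunk. A condensed stand-alone sketch of the per-stage loop (rank flags and tensor transport assumed):

    import torch

    def run_pp_stage(layers, positions, hidden_states, inbound,
                     is_first_rank: bool, is_last_rank: bool):
        if is_first_rank:
            residual = None  # created by the first layer's norm path
        else:
            hidden_states = inbound["hidden_states"]
            residual = inbound["residual"]

        for layer in layers:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not is_last_rank:
            # Both tensors travel to the next pipeline rank.
            return {"hidden_states": hidden_states, "residual": residual}
        return hidden_states + residual  # final residual add, done once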
@@ -891,7 +908,6 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.model.norm(hidden_states)
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

@@ -934,7 +950,26 @@ class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+        all_params = set(dict(self.named_parameters()).keys())
+        missing = all_params - loaded
+        kv_cache_params = {p for p in missing
+                           if p.endswith((".k_scale", ".k_zero_point",
+                                          ".q_scale", ".q_zero_point",
+                                          ".v_scale", ".v_zero_point"))}
+        real_missing = missing - kv_cache_params
+        if real_missing:
+            logger.warning(
+                "[DIAG-WEIGHTS] Missing %d params: %s",
+                len(real_missing), sorted(real_missing),
+            )
+        else:
+            logger.info(
+                "[DIAG-WEIGHTS] All params loaded. (%d kv-cache params skipped)",
+                len(kv_cache_params),
+            )
+        return loaded


 def get_spec_layer_idx_from_weight_name(
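The diagnostic block compares the loader's returned set of loaded parameter names against named_parameters() and excludes KV-cache scale/zero-point entries, which are created by the quantization and attention layers at init time and are not expected to appear in the checkpoint. The same check can be expressed as a stand-alone helper (hypothetical, mirroring the logic above):

    import torch.nn as nn

    _KV_CACHE_SUFFIXES = (".k_scale", ".k_zero_point",
                          ".q_scale", ".q_zero_point",
                          ".v_scale", ".v_zero_point")

    def find_real_missing(model: nn.Module, loaded: set[str]) -> set[str]:
        """Return parameters the loader skipped, ignoring kv-cache scales."""
        all_params = {name for name, _ in model.named_parameters()}
        missing = all_params - loaded
        return {p for p in missing if not p.endswith(_KV_CACHE_SUFFIXES)}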