Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -47,7 +47,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerBackend,
|
||||
)
|
||||
@@ -89,6 +91,7 @@ from .utils import (
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DeepSeekV2Gate(ReplicatedLinear):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
n_experts: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
assert quant_config is None
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
n_experts,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
# Unquantized only, will be called "weight".
|
||||
assert hasattr(self, "weight")
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
(9, 0)
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
self.allow_dsv3_router_gemm = (
|
||||
current_platform.is_cuda()
|
||||
and is_hopper_or_blackwell
|
||||
and n_experts in SUPPORTED_NUM_EXPERTS
|
||||
and hidden_size in SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
self._out_dtype: torch.dtype | None = None
|
||||
|
||||
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
|
||||
"""
|
||||
Set out dtype for the router logits. This is needed after
|
||||
__init__, b/c we need to check if the trtllm kernel is
|
||||
selected before we decide between bf16 and fp32.
|
||||
"""
|
||||
|
||||
if self._out_dtype is not None:
|
||||
raise ValueError("out_dtype has already been set")
|
||||
else:
|
||||
self._out_dtype = out_dtype
|
||||
|
||||
@property
|
||||
def out_dtype(self) -> torch.dtype:
|
||||
if self._out_dtype is None:
|
||||
raise ValueError("out_dtype has not been set yet")
|
||||
return self._out_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""
|
||||
Use specialized GEMM for low batch size for DSV3 and KIMI.
|
||||
"""
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
return ops.dsv3_router_gemm(
|
||||
hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
|
||||
), None
|
||||
else:
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
# self.gate = DeepSeekV2Gate(
|
||||
# config.hidden_size,
|
||||
# config.n_routed_experts,
|
||||
# quant_config=None,
|
||||
# prefix=f"{prefix}.gate",
|
||||
# )
|
||||
self.gate = ReplicatedLinear(
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
if getattr(config, "topk_method", None) == "noaux_tc":
|
||||
# self.gate.e_score_correction_bias = nn.Parameter(
|
||||
# torch.empty(config.n_routed_experts, dtype=torch.float32)
|
||||
# )
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(config.n_routed_experts)
|
||||
)
|
||||
@@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
else None,
|
||||
)
|
||||
|
||||
# # NOTE(rob): this is a hack until we finish off the PR for
|
||||
# # merging TRTLLM kernels into the MK framework. Then we can
|
||||
# # query the MonolithicMK for the expected router logits.
|
||||
# self.gate.set_out_dtype(
|
||||
# torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
|
||||
# )
|
||||
# NOTE(rob): this is a hack until we finish off the PR for
|
||||
# merging TRTLLM kernels into the MK framework. Then we can
|
||||
# query the MonolithicMK for the expected router logits.
|
||||
self.gate.set_out_dtype(
|
||||
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
@@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
elif self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
shared_output *= 1.0 / self.routed_scaling_factor
|
||||
|
||||
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
final_hidden_states += shared_output
|
||||
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0
|
||||
@@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = latent_cache.unsqueeze(1)
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
|
||||
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
kv = self.kv_b_proj(kv_a)[0]
|
||||
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = latent_cache[:, :, self.kv_lora_rank :]
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q[..., self.qk_nope_head_dim :] = q_pe
|
||||
k = torch.empty_like(q)
|
||||
@@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
def get_attn_backend(self) -> AttentionBackend:
|
||||
return DeepseekV32IndexerBackend
|
||||
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
def __init__(
|
||||
@@ -727,8 +650,8 @@ class Indexer(nn.Module):
|
||||
# where we store value in fp8 and scale in fp32
|
||||
# per self.quant_block_size element
|
||||
self.k_cache = DeepseekV32IndexerCache(
|
||||
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
|
||||
dtype=torch.uint8,
|
||||
head_dim=self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
prefix=f"{prefix}.k_cache",
|
||||
cache_config=cache_config,
|
||||
)
|
||||
@@ -776,23 +699,61 @@ class Indexer(nn.Module):
|
||||
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
|
||||
|
||||
# we only quant q here since k quant is fused with cache insertion
|
||||
q = q.view(-1, self.head_dim)
|
||||
q_fp8, q_scale = per_token_group_quant_fp8(
|
||||
q,
|
||||
self.quant_block_size,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=self.scale_fmt is not None,
|
||||
)
|
||||
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
|
||||
q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
# q = q.view(-1, self.head_dim)
|
||||
# q_fp8, q_scale = per_token_group_quant_fp8(
|
||||
# q,
|
||||
# self.quant_block_size,
|
||||
# column_major_scales=False,
|
||||
# use_ue8m0=self.scale_fmt is not None,
|
||||
# )
|
||||
# q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
|
||||
# q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
|
||||
weights, _ = self.weights_proj(hidden_states)
|
||||
weights = (
|
||||
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
|
||||
weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5
|
||||
)
|
||||
weights = weights.squeeze(-1)
|
||||
|
||||
return self.indexer_op(hidden_states, q_fp8, k, weights)
|
||||
return self.indexer_op(hidden_states, q, k, weights)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_impl(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Dynamically run min-latency gemm if num_tokens <= 16.
|
||||
This must be wrapped in a custom op because our torch.compile integration
|
||||
does not support runtime dispatching on num_tokens.
|
||||
"""
|
||||
num_tokens = input_.shape[0]
|
||||
if 0 < num_tokens <= 16:
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
weight.shape[0],
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(output, input_, weight.T)
|
||||
return output
|
||||
else:
|
||||
return torch.nn.functional.linear(input_, weight)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_fake(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return input_.new_empty(input_.shape[0], weight.shape[0])
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="min_latency_fused_qkv_a_proj",
|
||||
op_func=_min_latency_fused_qkv_a_proj_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_min_latency_fused_qkv_a_proj_fake,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
@@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
self,
|
||||
input_,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
|
||||
num_tokens = input_.shape[0]
|
||||
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
2112,
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(
|
||||
output,
|
||||
input_,
|
||||
self.weight.T,
|
||||
)
|
||||
if self._use_min_latency_gemm:
|
||||
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
@@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
# self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
|
||||
# self.hidden_size,
|
||||
# [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.fused_qkv_a_proj",
|
||||
# )
|
||||
self.fused_qkv_a_proj = MergedColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fused_qkv_a_proj",
|
||||
disable_tp=True,
|
||||
)
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj")
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj")
|
||||
else:
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa",
|
||||
)
|
||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
self.q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
@@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
rotary_emb=self.rotary_emb,
|
||||
o_proj=self.o_proj,
|
||||
fused_qkv_a_proj=self.fused_qkv_a_proj
|
||||
if self.q_lora_rank is not None
|
||||
else None,
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
|
||||
if self.q_lora_rank is None
|
||||
else None,
|
||||
@@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM(
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
# quant_method for relevant layers during initialization.
|
||||
self.fuse_qkv_a_proj = (
|
||||
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
)
|
||||
if self.fuse_qkv_a_proj:
|
||||
self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
"q_a_proj",
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
# self.fuse_qkv_a_proj = (
|
||||
# hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||
# )
|
||||
# if self.fuse_qkv_a_proj:
|
||||
# self.packed_modules_mapping["fused_qkv_a_proj"] = [
|
||||
# "q_a_proj",
|
||||
# "kv_a_proj_with_mqa",
|
||||
# ]
|
||||
|
||||
self.model = self.model_cls(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
@@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM(
|
||||
self.moe_layers = []
|
||||
self.moe_mlp_layers = []
|
||||
example_moe = None
|
||||
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, DeepseekV2DecoderLayer)
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
@@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM(
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
mla_params_mapping = [
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
]
|
||||
# mla_params_mapping = [
|
||||
# ("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
# ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
# ]
|
||||
mha_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
@@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM(
|
||||
]
|
||||
if self.use_mha:
|
||||
stacked_params_mapping.extend(mha_params_mapping)
|
||||
else:
|
||||
stacked_params_mapping.extend(mla_params_mapping)
|
||||
# else:
|
||||
# stacked_params_mapping.extend(mla_params_mapping)
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
@@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM(
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
# if go with fusion option, then update name
|
||||
if (
|
||||
param_name == "fused_qkv_a_proj"
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
else:
|
||||
name = name_mapped
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
try:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = (
|
||||
1
|
||||
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
|
||||
else 0
|
||||
)
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
is_fusion_moe_shared_experts_layer = (
|
||||
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
|
||||
if loaded_weight.ndim == 1:
|
||||
weight_to_load = loaded_weight[chunk_slice]
|
||||
elif split_dim == 0:
|
||||
weight_to_load = loaded_weight[chunk_slice, :]
|
||||
else:
|
||||
weight_to_load = loaded_weight[:, chunk_slice]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
# if go with fusion option, then update name
|
||||
if (
|
||||
param_name == "fused_qkv_a_proj"
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
name = name_mapped
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
# Special handling: when AITER fusion_shared_experts is enabled,
|
||||
# checkpoints may provide a single widened shared_experts tensor
|
||||
# without explicit expert indices
|
||||
# (e.g. ...mlp.shared_experts.gate_proj.weight).
|
||||
# For models with multiple shared experts, split that tensor
|
||||
# evenly into per-shared-expert slices and load them into
|
||||
# appended expert slots mlp.experts.{n_routed_experts + j}.*
|
||||
# accordingly.
|
||||
num_chunks = 1
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
|
||||
# Determine split axis based on op type
|
||||
# gate/up: ColumnParallel → split along dim 0
|
||||
# down: RowParallel → split along dim 1
|
||||
split_dim = (
|
||||
1
|
||||
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
|
||||
else 0
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name is not None and not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
total = loaded_weight.shape[split_dim]
|
||||
assert total % num_chunks == 0, (
|
||||
f"Shared expert weight dim {total} "
|
||||
f"not divisible by num_chunks {num_chunks}"
|
||||
)
|
||||
chunk_size = total // num_chunks
|
||||
|
||||
for j in range(num_chunks):
|
||||
chunk_name = name
|
||||
weight_to_load = loaded_weight
|
||||
|
||||
if is_fusion_moe_shared_experts_layer:
|
||||
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
|
||||
if loaded_weight.ndim == 1:
|
||||
weight_to_load = loaded_weight[chunk_slice]
|
||||
elif split_dim == 0:
|
||||
weight_to_load = loaded_weight[chunk_slice, :]
|
||||
else:
|
||||
weight_to_load = loaded_weight[:, chunk_slice]
|
||||
# Synthesize an expert-style name so expert mapping
|
||||
# can route it
|
||||
chunk_name = name.replace(
|
||||
"mlp.shared_experts",
|
||||
f"mlp.experts.{self.config.n_routed_experts + j}",
|
||||
)
|
||||
|
||||
# Use expert_params_mapping to locate the destination
|
||||
# param and delegate to its expert-aware weight_loader
|
||||
# with expert_id.
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in chunk_name:
|
||||
continue
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or
|
||||
# not here since otherwise we may skip experts with
|
||||
# other available replicas.
|
||||
weight_loader = typing.cast(
|
||||
Callable[..., bool], param.weight_loader
|
||||
)
|
||||
success = weight_loader(
|
||||
param,
|
||||
weight_to_load,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
if not is_fusion_moe_shared_experts_layer:
|
||||
name = name_mapped
|
||||
else:
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name is not None and not is_fusion_moe_shared_experts_layer:
|
||||
loaded_params.add(name)
|
||||
except:
|
||||
pass
|
||||
|
||||
opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"]
|
||||
# add your opt here..
|
||||
def inject_layer(layer, quant_method, is_mla):
|
||||
q_lora_rank = getattr(layer, "q_lora_rank", None)
|
||||
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
|
||||
if q_lora_rank is not None:
|
||||
layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_a_proj, "weight_scale"):
|
||||
layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
|
||||
if hasattr(layer.q_proj, "weight_scale"):
|
||||
layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
|
||||
del layer.kv_a_proj_with_mqa.weight_scale
|
||||
else:
|
||||
return
|
||||
del layer.kv_a_proj_with_mqa.weight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
elif quant_method == "GGUFLinearMethod":
|
||||
pass
|
||||
elif quant_method == "AWQMarlinLinearMethod":
|
||||
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
|
||||
assert dtype == torch.int32
|
||||
if layer.q_lora_rank is not None:
|
||||
layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
|
||||
layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
elif not is_mla:
|
||||
layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1)
|
||||
layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.scales
|
||||
layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
|
||||
del layer.kv_a_proj_with_mqa.qzeros
|
||||
else:
|
||||
return
|
||||
|
||||
del layer.kv_a_proj_with_mqa.qweight
|
||||
del layer.kv_a_proj_with_mqa
|
||||
if is_mla:
|
||||
layer.mla_attn.forward = layer.mla_attn.forward_opt
|
||||
else:
|
||||
layer.forward = layer.forward_opt
|
||||
else:
|
||||
pass
|
||||
|
||||
for _, layer in self.model.named_modules():
|
||||
if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]:
|
||||
if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
|
||||
quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
|
||||
else:
|
||||
quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
|
||||
if quant_method not in opt_support_quant_method:
|
||||
break
|
||||
|
||||
inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention")
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user