Deprecate disable-mla (#5481)

This commit is contained in:
Baizhou Zhang
2025-04-17 01:43:14 -07:00
committed by GitHub
parent 81c891111f
commit 4fb05583ef
9 changed files with 188 additions and 575 deletions

View File

@@ -93,158 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn):
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
class MiniCPM3Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
# TODO support head_size 96
self.attn = RadixAttention(
self.num_local_heads,
128,
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
original_shapes = [q_pe.shape, k_pe.shape]
q_pe, k_pe = self.rotary_emb(
positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
)
q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class MiniCPM3AttentionMLA(nn.Module):
def __init__(
@@ -434,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
if not global_server_args_dict["disable_mla"]:
self.self_attn = MiniCPM3AttentionMLA(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=self.hidden_size // config.num_attention_heads,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
prefix=add_prefix("self_attn", prefix),
)
else:
self.self_attn = MiniCPM3Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=self.hidden_size // config.num_attention_heads,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
prefix=add_prefix("self_attn", prefix),
)
self.self_attn = MiniCPM3AttentionMLA(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=self.hidden_size // config.num_attention_heads,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
prefix=add_prefix("self_attn", prefix),
)
self.mlp = MiniCPM3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
@@ -674,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module):
)
weight_loader(param, loaded_weight)
if not global_server_args_dict["disable_mla"]:
for layer_id in range(self.config.num_hidden_layers):
self_attn = self.model.layers[layer_id].self_attn
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
del self_attn.kv_b_proj
for layer_id in range(self.config.num_hidden_layers):
self_attn = self.model.layers[layer_id].self_attn
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
del self_attn.kv_b_proj
EntryClass = MiniCPM3ForCausalLM