Fuse q_a_proj and kv_a_proj (#5619)
This commit is contained in:
@@ -443,12 +443,12 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
# For tensor parallel attention
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
@@ -470,6 +470,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
)
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
@@ -490,14 +498,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
|
||||
if rope_scaling:
|
||||
@@ -656,15 +656,18 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
|
||||
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = latent_cache.unsqueeze(1)
|
||||
kv_a = self.kv_a_layernorm(kv_a.contiguous())
|
||||
@@ -699,13 +702,16 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
zero_allocator: BumpAllocator,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
@@ -744,7 +750,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
|
||||
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
||||
@@ -819,13 +824,16 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
)
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
@@ -846,8 +854,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
else:
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
v_input = latent_cache[..., : self.kv_lora_rank]
|
||||
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
|
||||
k_input = latent_cache.unsqueeze(1)
|
||||
@@ -1018,15 +1024,17 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
# First do normal mha forward to get output for extended part
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
latent_cache = latent_cache.unsqueeze(1)
|
||||
kv_a = self.kv_a_layernorm(kv_a.contiguous())
|
||||
@@ -1668,6 +1676,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
||||
)
|
||||
|
||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||
self.config.q_lora_rank is not None
|
||||
)
|
||||
cached_a_proj = {} if fuse_qkv_a_proj else None
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
# TODO(HandH1998): Modify it when nextn is supported.
|
||||
@@ -1723,11 +1737,50 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if fuse_qkv_a_proj and (
|
||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||
):
|
||||
cached_a_proj[name] = loaded_weight
|
||||
q_a_proj_name = (
|
||||
name
|
||||
if "q_a_proj" in name
|
||||
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
|
||||
)
|
||||
kv_a_proj_name = (
|
||||
name
|
||||
if "kv_a_proj_with_mqa" in name
|
||||
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
|
||||
)
|
||||
|
||||
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
|
||||
if (
|
||||
q_a_proj_name in cached_a_proj
|
||||
and kv_a_proj_name in cached_a_proj
|
||||
):
|
||||
|
||||
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||
fused_weight = torch.cat(
|
||||
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
||||
)
|
||||
|
||||
param_name = name.replace(
|
||||
"q_a_proj", "fused_qkv_a_proj_with_mqa"
|
||||
)
|
||||
param = params_dict[param_name]
|
||||
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, fused_weight)
|
||||
cached_a_proj.pop(q_a_proj_name)
|
||||
cached_a_proj.pop(kv_a_proj_name)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
self.post_load_weights()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user