MLA prefill w/o weight absorption (#2349)
@@ -453,7 +453,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -462,6 +462,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )
 
+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -471,6 +480,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -508,7 +574,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
 
-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -835,7 +901,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj
 
 
 EntryClass = DeepseekV2ForCausalLM
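
For intuition, here is a minimal numerical sketch of the equivalence this commit trades on: expanding the latent cache into full per-head keys with the key half of kv_b_proj (the forward_normal / prefill path) gives the same attention logits as folding that weight into the query and attending directly in the compressed latent space (the forward_absorb path, using w_kc). The sizes and names below are made up for illustration; they are not the module's real config.

import torch

# Tiny, hypothetical sizes for a single attention head.
qk_nope_head_dim, kv_lora_rank, seq_len = 16, 32, 4

w_kc = torch.randn(qk_nope_head_dim, kv_lora_rank)   # key half of kv_b_proj for this head
q_nope = torch.randn(seq_len, qk_nope_head_dim)       # non-RoPE part of the query
kv_a = torch.randn(seq_len, kv_lora_rank)             # compressed latent KV cache entries

# "Normal" MHA path: expand the latent cache into full keys, then take dot products.
k_nope = kv_a @ w_kc.T                                # (seq_len, qk_nope_head_dim)
logits_normal = q_nope @ k_nope.T

# "Absorbed" MQA path: fold w_kc into the query and attend in latent space instead.
q_latent = q_nope @ w_kc                              # (seq_len, kv_lora_rank)
logits_absorb = q_latent @ kv_a.T

# Same scores either way; only where the matmul with w_kc happens differs.
assert torch.allclose(logits_normal, logits_absorb, atol=1e-4)

The scores match by associativity of matrix multiplication, which is why the dispatch in forward() can route prefix-free extend (pure prefill) batches to the MHA kernel via forward_normal while decode and prefix-extend batches keep the weight-absorbed MQA kernel, with both writing the same latent KV cache.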