Overlap qk norm with two streams (#5977)
@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
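The constructor hunks above only thread an optional secondary stream through to the attention module and stash it on the instance. A minimal standalone sketch of that contract, with illustrative names and plain LayerNorm standing in for the model's RMSNorm, might look like this:

from typing import Optional

import torch
import torch.nn as nn


class MLAAttentionSketch(nn.Module):
    """Illustrative only: an attention-like module that accepts an optional
    secondary CUDA stream at construction time."""

    def __init__(self, hidden_size: int, alt_stream: Optional[torch.cuda.Stream] = None):
        super().__init__()
        self.q_a_layernorm = nn.LayerNorm(hidden_size)
        self.kv_a_layernorm = nn.LayerNorm(hidden_size)
        # Defaulting to None keeps existing call sites working; when no stream
        # is supplied, the forward pass simply runs both norms sequentially.
        self.alt_stream = alt_stream

The stream itself is created by the caller (once per model, see the last hunks below), so the module only keeps a reference.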
@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
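The hunk above is the heart of the change: while a CUDA graph is being captured, the q-side layernorm runs on the current stream and the kv-side layernorm runs on the secondary stream, with wait_stream calls fencing the overlap on both ends. A minimal sketch of the same pattern, using generic tensor and module names rather than sglang's internals:

from typing import Optional

import torch
import torch.nn as nn


def overlapped_qk_norm(
    q: torch.Tensor,
    k: torch.Tensor,
    q_norm: nn.Module,
    k_norm: nn.Module,
    alt_stream: Optional[torch.cuda.Stream] = None,
):
    if alt_stream is not None and torch.cuda.is_current_stream_capturing():
        current_stream = torch.cuda.current_stream()
        # The side stream must not start until the inputs produced on the
        # current stream (here: q and k from the fused projection) are ready.
        alt_stream.wait_stream(current_stream)
        q = q_norm(q)
        with torch.cuda.stream(alt_stream):
            k = k_norm(k)
        # Make the current stream wait for the k-side kernels before any
        # later op consumes the normalized k.
        current_stream.wait_stream(alt_stream)
    else:
        # Without a second stream, or outside graph capture, both norms run
        # back to back on the default stream.
        q = q_norm(q)
        k = k_norm(k)
    return q, k

Gating on torch.cuda.is_current_stream_capturing() keeps eager-mode execution on a single stream; the cross-stream ordering is recorded into the captured graph and reproduced at replay.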
@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         if self.attention_backend == "fa3":
@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1402,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
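At the model level, a single torch.cuda.Stream is created once and the same object is handed to every decoder layer; layers execute sequentially, so sharing one secondary stream is enough. A toy sketch of that wiring (class names are illustrative, not sglang's):

import torch
import torch.nn as nn


class TinyDecoderLayer(nn.Module):
    def __init__(self, alt_stream: torch.cuda.Stream):
        super().__init__()
        # Stored here and used inside the attention forward pass.
        self.alt_stream = alt_stream


class TinyModel(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        # One secondary stream for the whole model (requires a CUDA device).
        self.alt_stream = torch.cuda.Stream()
        self.layers = nn.ModuleList(
            TinyDecoderLayer(alt_stream=self.alt_stream) for _ in range(num_layers)
        )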