Overlap qk norm with two streams (#5977)
This commit is contained in:
@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
layer_id: int = None,
|
layer_id: int = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
prefix=add_prefix("attn_mha", prefix),
|
prefix=add_prefix("attn_mha", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.alt_stream = alt_stream
|
||||||
|
|
||||||
self.w_kc = None
|
self.w_kc = None
|
||||||
self.w_vc = None
|
self.w_vc = None
|
||||||
self.w_scale = None
|
self.w_scale = None
|
||||||
@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
||||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||||
)
|
)
|
||||||
q = self.q_a_layernorm(q)
|
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||||
|
|
||||||
|
# overlap qk norm
|
||||||
|
if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
|
||||||
|
current_stream = torch.cuda.current_stream()
|
||||||
|
self.alt_stream.wait_stream(current_stream)
|
||||||
|
q = self.q_a_layernorm(q)
|
||||||
|
with torch.cuda.stream(self.alt_stream):
|
||||||
|
k_nope = self.kv_a_layernorm(k_nope)
|
||||||
|
current_stream.wait_stream(self.alt_stream)
|
||||||
|
else:
|
||||||
|
q = self.q_a_layernorm(q)
|
||||||
|
k_nope = self.kv_a_layernorm(k_nope)
|
||||||
|
|
||||||
|
k_nope = k_nope.unsqueeze(1)
|
||||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||||
else:
|
else:
|
||||||
q = self.q_proj(hidden_states)[0].view(
|
q = self.q_proj(hidden_states)[0].view(
|
||||||
-1, self.num_local_heads, self.qk_head_dim
|
-1, self.num_local_heads, self.qk_head_dim
|
||||||
)
|
)
|
||||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
|
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||||
|
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
|
||||||
|
|
||||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
||||||
|
|
||||||
if self.use_deep_gemm_bmm:
|
if self.use_deep_gemm_bmm:
|
||||||
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||||
@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||||
|
|
||||||
q_nope_out = q_nope_out.transpose(0, 1)
|
q_nope_out = q_nope_out.transpose(0, 1)
|
||||||
|
|
||||||
k_nope = latent_cache[..., : self.kv_lora_rank]
|
|
||||||
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
|
|
||||||
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
|
||||||
|
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
|
||||||
if self.attention_backend == "fa3":
|
if self.attention_backend == "fa3":
|
||||||
@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
is_nextn: bool = False,
|
is_nextn: bool = False,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
|
alt_stream=alt_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
|
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
|
||||||
@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||||
)
|
)
|
||||||
|
self.alt_stream = torch.cuda.Stream()
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DeepseekV2DecoderLayer(
|
DeepseekV2DecoderLayer(
|
||||||
@@ -1383,6 +1402,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
layer_id,
|
layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||||
|
alt_stream=self.alt_stream,
|
||||||
)
|
)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user