[Refactor][WIP] Refactor mla_v1 by moving all MLA preprocessing ops into mla_v1 attention impl (#2465)
### What this PR does / why we need it?
In order to support fused kernels, multi-stream execution, communication
optimization, etc., it is better to aggregate all operations in the Attention
layer together. This PR refactors mla_v1 by moving all MLA preprocessing ops
into the mla_v1 attention impl.
Note that the new mla_v1 does not take torchair into consideration, so this
PR can only be merged after the torchair-related mla_v1 is isolated into a
new file.
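
To illustrate the shape of the change, here is a minimal, self-contained sketch of the preprocessing that moves out of the model layer and into the attention impl's forward. This is not the vllm-ascend code itself: the class name, constructor arguments, and the plain `nn.Linear`/`nn.LayerNorm` stand-ins are illustrative assumptions only.

```python
# Hedged sketch of the refactor idea: the attention impl, not the model layer,
# now owns the MLA preprocessing (q down-projection + norm, joint kv/k_pe
# projection, split, and norm). Names and module types are illustrative.
import torch
import torch.nn as nn


class MLAPreprocessSketch(nn.Module):
    """Hypothetical stand-in for an attention impl that owns MLA preprocessing."""

    def __init__(self, hidden_size: int, q_lora_rank: int,
                 kv_lora_rank: int, qk_rope_head_dim: int) -> None:
        super().__init__()
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        # Projections and norms that previously ran in the model layer.
        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
        self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
        self.kv_a_proj_with_mqa = nn.Linear(hidden_size,
                                            kv_lora_rank + qk_rope_head_dim,
                                            bias=False)
        self.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)

    def forward(self, hidden_states: torch.Tensor):
        # Preprocessing formerly done in CustomDeepseekV2MLAAttention.forward.
        q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))
        kv_no_split = self.kv_a_proj_with_mqa(hidden_states)
        kv_c, k_pe = kv_no_split.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        # The core attention kernel would consume q_c, kv_c_normed, k_pe here.
        return q_c, kv_c_normed, k_pe
```

With the preprocessing owned by the impl, the model layer only needs to hand `hidden_states` (plus the kv cache and attention metadata) to `self.mla_attn.impl.forward(...)`, as the diff below shows.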
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
### Feature Tests
<img width="506" height="141" alt="image"
src="https://github.com/user-attachments/assets/f1ab2906-a1ac-4450-8433-94811cd89466"
/>
### Performance After Refactor
<img width="648" height="486" alt="image"
src="https://github.com/user-attachments/assets/e33e038c-c5d9-4ba7-a8e9-1ac22f9833eb"
/>
### Performance Before Refactor
<img width="618" height="494" alt="image"
src="https://github.com/user-attachments/assets/83861dc2-dc51-4af3-9310-90ab10c43bb1"
/>
- vLLM version: v0.10.1.1
- vLLM main: e03940762b
---------
Signed-off-by: lwq <liwenquan5@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
```diff
@@ -37,7 +37,6 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_dp_group, get_ep_group
@@ -73,7 +72,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, npu_prefetch
from vllm_ascend.utils import dispose_tensor


class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -471,9 +470,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.debug_layer_idx = int(self.prefix.split(".")[-2])

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

if self.q_lora_rank is not None:
@@ -515,8 +511,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe
or self.enable_shared_expert_dp)):
and self.enable_shared_expert_dp):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
@@ -568,6 +563,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
@@ -582,55 +580,29 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
enable_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and not forward_context.with_prefill
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
forward_kwargs['ckq'] = ckq
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:
if kv_cache is None:
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
output = torch.empty(output_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
forward_kwargs['output'] = output
output = self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata,
**forward_kwargs)
output = output.view(-1, output_shape[-1])
return output
else:
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
hidden_states_or_q_c = get_tp_group().all_gather(
hidden_states_or_q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)

kv_c, k_pe = kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=output_shape)
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
forward_context.attn_metadata,
need_gather_q_kv, output)
output = output.view(-1, output_shape[-1])
return output


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -688,8 +660,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
and model_config.use_mla and self.tp_size > 1
else:
self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
@@ -698,7 +668,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mla_moe_communication = False
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -718,10 +687,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
replace_allreduce: bool = False,
) -> torch.Tensor:
# Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -733,9 +698,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)

hidden_states = self.self_attn(
positions=positions,
@@ -744,13 +706,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
attn_metadata=attn_metadata,
)

if mla_moe_communication and residual.shape[0] != hidden_states.shape[
0]:
chunk_hidden_states = torch.tensor_split(residual,
self.tp_size,
dim=0)
residual = chunk_hidden_states[self.tp_rank]

if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# We scale both hidden_states and residual before
@@ -778,9 +733,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
hidden_states, residual)

if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states,
attn_metadata,
replace_allreduce=mla_moe_communication)
hidden_states = self.mlp(hidden_states, attn_metadata)
else:
hidden_states = self.mlp(hidden_states)

@@ -793,10 +746,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)

# for last layer of main model and mtp layer.
if self.enable_shared_expert_dp and self.layer_idx >= (
```