[Feat] Flash comm allgher ep (#3334)

Support flash comm v1(Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
realliujiaxu
2025-10-15 19:36:32 +08:00
committed by GitHub
parent 8abe517870
commit f69a83b7ba
15 changed files with 283 additions and 78 deletions

View File

@@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
@@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for shared_expert_dp
if need_gather_q_kv:
q_c = get_tp_group().all_gather(q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_tokens, ...]
o_proj_input_shape = (num_actual_tokens,
o_proj_input_shape = (get_forward_context().num_tokens,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype,
@@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[num_decode_tokens:] = output_prefill
o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
output[...] = self.o_proj(o_proj_input)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
output[...] = self.o_proj(o_proj_input)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input