[Feat] Flash Comm allgather EP (#3334)
Support Flash Comm V1 (Sequence Parallelism) for Allgather EP. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
@@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_c = hidden_states
|
||||
|
||||
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# Process for shared_expert_dp
|
||||
if need_gather_q_kv:
|
||||
q_c = get_tp_group().all_gather(q_c, 0)
|
||||
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||
# Process for Flash Comm V1
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c, need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split, need_gather_q_kv)
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
@@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_tokens, ...]
|
||||
o_proj_input_shape = (num_actual_tokens,
|
||||
o_proj_input_shape = (get_forward_context().num_tokens,
|
||||
self.num_heads * self.v_head_dim)
|
||||
o_proj_input = torch.empty(o_proj_input_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
@@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
else:
|
||||
o_proj_input[num_decode_tokens:] = output_prefill
|
||||
o_proj_input[
|
||||
num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
# O proj
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
@@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
|
||||
output[...] = self.o_proj(
|
||||
o_proj_input,
|
||||
is_prefill=prefill_preprocess_res is not None,
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
output[...] = self.o_proj(o_proj_input)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
output[...] = self.o_proj(
|
||||
o_proj_input,
|
||||
is_prefill=prefill_preprocess_res is not None,
|
||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||
output[...] = self.o_proj(o_proj_input)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
del o_proj_input
|
||||
|
||||
|
||||
Reference in New Issue
Block a user