[Feature] Support DSA-CP for Hybrid scenario (#5702)

Signed-off-by: zzhx1 <zzh_201018@outlook.com>

### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR:#4702;
RFC:https://github.com/vllm-project/vllm/issues/30055

### Support FULL_DECODE_ONLY Mode under PD-Mixed Scenario:
Extends DSA-CP to handle the FULL_DECODE_ONLY execution mode when
running in a prefill-decode mixed (PD-mixed) serving environment,
improving throughput and resource utilization for decode-intensive
workloads.
**In pure prefill nodes:**
- Both q_proj and o_proj are sharded across world ranks, using
**broadcast** for weights distribution.

**In PD-mixed nodes (supporting both prefill and decode):**

- q_proj is fully replicated (not sharded) to avoid communication
overhead during decoding.
- o_proj Using the original TP `RowParallelLinear` method to store
weights

**During prefill execution:**
- o_proj forwards through all_gather to collect weights, reconstructing
the complete o_proj weights on each card.

**During decode (graph replay phase):**
- Additional all_to_all (before o_proj) and reduce_scatter (after
o_proj) are introduced to enable sequence-parallel output aggregation
while maintaining correctness under SFA CP.

### benchmark:
- TTFT increased by **527%**
- TPOT increased by **180%**

<img width="1550" height="938" alt="image"
src="https://github.com/user-attachments/assets/9b7a03d8-a3db-4a99-8923-6e5bfcfecf72"
/>


### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: clrs97 <524936896@qq.com>
This commit is contained in:
zzhxxx
2026-01-22 10:12:09 +08:00
committed by GitHub
parent 69740039b7
commit dd8571860d
4 changed files with 207 additions and 68 deletions

View File

@@ -6,6 +6,7 @@ import torch_npu
import vllm.envs as envs_vllm import vllm.envs as envs_vllm
from torch import nn from torch import nn
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
@@ -34,7 +35,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_dsa_cp, maybe_trans_nz, vllm_version_is) enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
# isort: off # isort: off
@@ -79,7 +80,7 @@ class AscendSFABackend(AttentionBackend):
@dataclass @dataclass
class SfaCpContext: class DSACPContext:
num_tokens: int num_tokens: int
num_tokens_pad: int num_tokens_pad: int
local_start: int local_start: int
@@ -119,7 +120,7 @@ class AscendSFAMetadata:
attn_mask: torch.Tensor = None attn_mask: torch.Tensor = None
# chunked prefill by default if no attn_states passed # chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
sfa_cp_context: Optional[SfaCpContext] = None dsa_cp_context: Optional[DSACPContext] = None
reshape_cache_event: torch.npu.Event = None reshape_cache_event: torch.npu.Event = None
@@ -159,15 +160,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
npu_fused_infer_attention_score TND layout's limit of 16, \ npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}" got {self.decode_threshold}"
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_sfa_cp = enable_dsa_cp()
assert not (
self.enable_sfa_cp
and self.vllm_config.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device) self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_dsa_cp = enable_dsa_cp()
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.actual_seq_lengths_key = torch.empty_like(
self.actual_seq_lengths_query)
@staticmethod @staticmethod
def determine_chunked_prefill_workspace_size( def determine_chunked_prefill_workspace_size(
@@ -210,8 +212,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cos, sin = get_cos_and_sin_mla(input_positions, True) cos, sin = get_cos_and_sin_mla(input_positions, True)
sfa_cp_context = None dsa_cp_context = None
if self.enable_sfa_cp: if self.enable_dsa_cp:
global_tp_size = get_tp_group().world_size global_tp_size = get_tp_group().world_size
num_tokens = num_input_tokens num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size) num_tokens_pad = _round_up(num_tokens, global_tp_size)
@@ -235,13 +237,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
value=-1) value=-1)
else: else:
slot_mapping = slot_mapping[:num_tokens_pad] slot_mapping = slot_mapping[:num_tokens_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
cos = cos[local_start:local_end_with_pad] cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad] sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
fill_value=-1,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
assert cos.shape[0] == num_tokens_per_device, \ assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \ f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}" got {cos.shape[0]} and {num_tokens_per_device}"
@@ -252,8 +252,9 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \ f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}" got {slot_mapping.shape[0]} and {num_tokens_pad}"
actual_seq_lengths_query = torch.empty_like(cum_query_lens) actual_seq_lengths_query = self.actual_seq_lengths_query
actual_seq_lengths_key = torch.empty_like(seq_lens) actual_seq_lengths_key = self.actual_seq_lengths_key
num_segs = cum_query_lens.shape[0] num_segs = cum_query_lens.shape[0]
last_token = 0 last_token = 0
cum = 0 cum = 0
@@ -262,21 +263,24 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
global_end = cum_query_lens[i].item() global_end = cum_query_lens[i].item()
last_token = global_end last_token = global_end
local_start = max(global_start, local_start) req_local_start = max(global_start, local_start)
local_end = min(global_end, local_end_with_pad) req_local_end = min(global_end, local_end_with_pad)
num_local_tokens = local_end - local_start num_local_tokens = req_local_end - req_local_start
if num_local_tokens > 0: if num_local_tokens > 0:
cum += num_local_tokens cum += num_local_tokens
actual_seq_lengths_query[i] = cum actual_seq_lengths_query[i] = cum
offset = global_end - local_end offset = global_end - req_local_end
actual_seq_lengths_key[i] = seq_lens[i].item() - offset actual_seq_lengths_key[i] = seq_lens[i].item() - offset
else: else:
actual_seq_lengths_query[i] = cum actual_seq_lengths_query[i] = cum
actual_seq_lengths_key[i] = 0 actual_seq_lengths_key[i] = 0
sfa_cp_context = SfaCpContext( actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
dsa_cp_context = DSACPContext(
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_pad=num_tokens_pad, num_tokens_pad=num_tokens_pad,
local_start=local_start, local_start=local_start,
@@ -300,7 +304,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
block_tables=block_table, block_tables=block_table,
sin=sin[:num_input_tokens], sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens], cos=cos[:num_input_tokens],
sfa_cp_context=sfa_cp_context) dsa_cp_context=dsa_cp_context)
def build_for_graph_capture( def build_for_graph_capture(
self, self,
@@ -329,6 +333,8 @@ class AscendSFAImpl(MLAAttentionImpl):
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: Optional[torch.Tensor] = None
def __init__( def __init__(
self, self,
@@ -382,22 +388,9 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.indexer is not None, "Indexer is required for DSA." assert self.indexer is not None, "Indexer is required for DSA."
self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
# indexer param # indexer param
self.n_head: int = self.indexer.n_head # 64 self.n_head: int = self.indexer.n_head # 64
@@ -406,9 +399,24 @@ class AscendSFAImpl(MLAAttentionImpl):
self.wk = self.indexer.wk self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm self.k_norm = self.indexer.k_norm
self.cp_size = 1 self.cp_size = 1
self.enable_dsa_cp = enable_dsa_cp()
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
if self.enable_dsa_cp:
self.local_num_heads = self.num_heads * self.tp_size
if self.enable_dsa_cp_prefill_only:
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj. # NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -442,10 +450,14 @@ class AscendSFAImpl(MLAAttentionImpl):
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj) dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp: if self.enable_dsa_cp:
for layer in (self.layer_sharding_kwargs or []): if self.enable_dsa_cp_prefill_only:
if is_hidden_layer(layer): for layer in (self.layer_sharding_kwargs or []):
post_process_after_loading_for_shard_weight_series(layer) if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(
layer)
else:
self._init_o_proj_tp_full_params()
if self.enable_mlapo: if self.enable_mlapo:
quant_method = getattr( quant_method = getattr(
@@ -460,7 +472,7 @@ class AscendSFAImpl(MLAAttentionImpl):
"Currently mlapo only supports W8A8 quantization in SFA scenario." "Currently mlapo only supports W8A8 quantization in SFA scenario."
"Some layers in your model are not quantized with W8A8," "Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.") "thus mlapo is disabled for these layers.")
if self.enable_sfa_cp: if self.enable_dsa_cp:
reasons.append("Currently mlapo does not support SFA with CP," reasons.append("Currently mlapo does not support SFA with CP,"
"thus mlapo is disabled for these layers.") "thus mlapo is disabled for these layers.")
if reasons: if reasons:
@@ -525,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA" cache_mode = "PA"
if self.enable_sfa_cp: if self.enable_dsa_cp:
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split, kv_no_split,
self.kv_a_layernorm.weight, self.kv_a_layernorm.weight,
@@ -738,7 +750,7 @@ class AscendSFAImpl(MLAAttentionImpl):
forward_context = get_forward_context() forward_context = get_forward_context()
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run: if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
for layer in (self.layer_sharding_kwargs or []): for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer): if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer) reach_layer_for_shard_weight_series(layer)
@@ -748,12 +760,20 @@ class AscendSFAImpl(MLAAttentionImpl):
sin = attn_metadata.sin sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_sfa_cp: if self.enable_dsa_cp:
need_gather_q_kv = False need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
num_input_tokens = attn_metadata.num_input_tokens num_input_tokens = attn_metadata.num_input_tokens
output_padded = output output_padded = output
# all-gather o_proj weight for prefill stage of PD mix node
o_proj_full_handle = None
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage.
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
}
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -796,16 +816,16 @@ class AscendSFAImpl(MLAAttentionImpl):
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
if self.enable_sfa_cp: if self.enable_dsa_cp:
assert attn_metadata.sfa_cp_context is not None assert attn_metadata.dsa_cp_context is not None
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping) slot_mapping)
if self.enable_sfa_cp: if self.enable_dsa_cp:
assert k_pe is not None assert k_pe is not None
assert k_nope is not None assert k_nope is not None
# support all_gather kv async for communication calculation overlap # support all_gather kv async for communication calculation overlap
@@ -815,17 +835,26 @@ class AscendSFAImpl(MLAAttentionImpl):
k_nope.view(-1, k_nope.shape[-1]), k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1]) k.view(-1, k.shape[-1])
], ],
dim=1), get_tp_group()) dim=1),
get_tp_group(),
async_op=should_shard_weight)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin) q_pe = self.rope_single(q_pe, cos, sin)
if self.enable_sfa_cp: if self.enable_dsa_cp:
if kv_ag_handle is not None: if kv_ag_handle is not None:
kv_ag_handle.wait() kv_ag_handle.wait()
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer): if self.enable_dsa_cp_prefill_only:
reach_layer_for_shard_weight_series(layer) for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
elif should_shard_weight:
_, o_proj_full_handle = all_gather_async(
self.o_proj_tp_weight,
get_tp_group(),
output=AscendSFAImpl.o_proj_full_pool)
if kv_cache is not None: if kv_cache is not None:
assert fused_kv_no_split is not None assert fused_kv_no_split is not None
@@ -841,6 +870,12 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe) k_pe)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(-1, 1),
k.view(-1, k.shape[-1])) # b, s, n, d
topk_indices = self.indexer_select_post_process( topk_indices = self.indexer_select_post_process(
x=hidden_states, x=hidden_states,
qr=q_c, qr=q_c,
@@ -876,6 +911,20 @@ class AscendSFAImpl(MLAAttentionImpl):
dependency=attn_output, dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE, max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch) enabled=self.enable_prefetch)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
# When using SFA-CP with pd mixed, o_proj has two cases:
# 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1.
# 2. decode: all-to-all the hidden_state before the o_proj forward.
result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
attn_output=attn_output,
output=output,
o_proj_full_handle=o_proj_full_handle,
should_shard_weight=should_shard_weight)
if not require_o_proj_forward:
return result
attn_output = result
output[...] = self.o_proj(attn_output)[0] output[...] = self.o_proj(attn_output)[0]
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
@@ -912,7 +961,10 @@ class AscendSFAImpl(MLAAttentionImpl):
k_pe, k_nope = torch.split( k_pe, k_nope = torch.split(
k, k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64] dim=-1)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
k_pe = k_pe.unsqueeze(2) k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
@@ -940,10 +992,7 @@ class AscendSFAImpl(MLAAttentionImpl):
if q is None: if q is None:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
cos_q, sin_q = cos, sin cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe, q_nope = torch.split( q_pe, q_nope = torch.split(
q, q,
@@ -984,3 +1033,92 @@ class AscendSFAImpl(MLAAttentionImpl):
sparse_count=2048, sparse_count=2048,
sparse_mode=3) sparse_mode=3)
return topk_indices return topk_indices
def _init_o_proj_tp_full_params(self):
"""
Initialize TP-mode and Full-mode parameters for o_proj weight,
preparing for weight switching in PD mix stage.
For PD mix stage:
- Use original TP o_proj weight for decode phase
- Need full-gather o_proj weight from all TP ranks for prefill phase
"""
if AscendSFAImpl.o_proj_full_pool is None:
sample = self.o_proj.weight
AscendSFAImpl.o_proj_full_pool = torch.empty(
(sample.shape[0] * self.tp_size, sample.shape[1]),
dtype=sample.dtype,
device=sample.device)
# Save TP-mode parameters (original sharded weights)
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
).detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
).detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
).detach()
# Initially switch to TP mode for graph capture
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
self.tp_size)
def _handle_o_proj_weight_switch_and_forward(
self, attn_output: torch.Tensor, output: torch.Tensor,
o_proj_full_handle: Optional[torch.distributed.Work],
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
"""
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
"""
# Gather o_proj weight from all TP ranks for Full-mode computation
if should_shard_weight:
# Wait for the completion of o_proj weight all-gather operation
if o_proj_full_handle is not None:
o_proj_full_handle.wait()
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_full_aclnn_input_offset)
# Apply quantization method and execute forward computation
output[...] = self.o_proj.quant_method.quant_method.apply(
self.o_proj, attn_output)
# Switch o_proj back to TP-mode for subsequent decode operations
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_tp_aclnn_input_offset)
return output, False
else:
# For decode scenario: perform all-to-all communication on o_proj input activations
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
send = attn_output.view(-1, self.tp_size, self.num_heads *
self.v_head_dim).permute(1, 0, 2).reshape(
-1, self.num_heads * self.v_head_dim)
attn_output = torch.empty_like(send)
torch.distributed.all_to_all_single(
attn_output, send, group=get_tp_group().device_group)
return attn_output, True

View File

@@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
init_model_parallel_group) init_model_parallel_group)
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable
# Currently, mc2 op need their own group coordinator. # Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None _MC2: Optional[GroupCoordinator] = None
@@ -238,7 +238,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
FC2_group_ranks = torch.tensor( FC2_group_ranks = torch.tensor(
flashcomm2_otp_group_ranks).squeeze(0) flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks) _SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp(): elif enable_dsa_cp_with_layer_shard():
# For dsa_cp, all shard layers are replicated. # For dsa_cp, all shard layers are replicated.
_SHARD_WEIGHT = create_shard_weight_group(None) _SHARD_WEIGHT = create_shard_weight_group(None)
else: else:

View File

@@ -52,4 +52,4 @@ def all_gather_async(input: torch.Tensor,
return output, dist.all_gather_into_tensor(output, return output, dist.all_gather_into_tensor(output,
input, input,
group=group.device_group, group=group.device_group,
async_op=async_op) async_op=async_op)

View File

@@ -62,7 +62,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
get_mlp_tp_group, get_mlp_tp_group,
get_otp_group) get_otp_group)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable, from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids, get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable, matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled) oproj_tp_enable, shared_expert_dp_enabled)
@@ -575,7 +575,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
return tensor_model_parallel_all_reduce(output_parallel) return tensor_model_parallel_all_reduce(output_parallel)
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0 and not (enable_dsa_cp()
and "o_proj" in self.layer.prefix):
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))
world_size = self.layer.tp_size world_size = self.layer.tp_size
@@ -728,7 +729,7 @@ def _get_row_parallel_op(
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp, ShardedCPRowParallelOp]]: SequenceRowParallelOp, ShardedCPRowParallelOp]]:
if enable_dsa_cp() and "o_proj" in prefix: if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
return ShardedCPRowParallelOp(layer) return ShardedCPRowParallelOp(layer)
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
return MLPRowParallelOp(layer) return MLPRowParallelOp(layer)