[Feature] Support DSA-CP for Hybrid scenario (#5702)
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR:#4702;
RFC:https://github.com/vllm-project/vllm/issues/30055
### Support FULL_DECODE_ONLY Mode under PD-Mixed Scenario:
Extends DSA-CP to handle the FULL_DECODE_ONLY execution mode when
running in a prefill-decode mixed (PD-mixed) serving environment,
improving throughput and resource utilization for decode-intensive
workloads.
**In pure prefill nodes:**
- Both q_proj and o_proj are sharded across world ranks, using
**broadcast** for weights distribution.
**In PD-mixed nodes (supporting both prefill and decode):**
- q_proj is fully replicated (not sharded) to avoid communication
overhead during decoding.
- o_proj Using the original TP `RowParallelLinear` method to store
weights
**During prefill execution:**
- o_proj forwards through all_gather to collect weights, reconstructing
the complete o_proj weights on each card.
**During decode (graph replay phase):**
- Additional all_to_all (before o_proj) and reduce_scatter (after
o_proj) are introduced to enable sequence-parallel output aggregation
while maintaining correctness under SFA CP.
### benchmark:
- TTFT increased by **527%**
- TPOT increased by **180%**
<img width="1550" height="938" alt="image"
src="https://github.com/user-attachments/assets/9b7a03d8-a3db-4a99-8923-6e5bfcfecf72"
/>
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: clrs97 <524936896@qq.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import torch_npu
|
|||||||
import vllm.envs as envs_vllm
|
import vllm.envs as envs_vllm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
|
||||||
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -34,7 +35,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
|
|||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
||||||
enable_dsa_cp, maybe_trans_nz, vllm_version_is)
|
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz, vllm_version_is)
|
||||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
@@ -79,7 +80,7 @@ class AscendSFABackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SfaCpContext:
|
class DSACPContext:
|
||||||
num_tokens: int
|
num_tokens: int
|
||||||
num_tokens_pad: int
|
num_tokens_pad: int
|
||||||
local_start: int
|
local_start: int
|
||||||
@@ -119,7 +120,7 @@ class AscendSFAMetadata:
|
|||||||
attn_mask: torch.Tensor = None
|
attn_mask: torch.Tensor = None
|
||||||
# chunked prefill by default if no attn_states passed
|
# chunked prefill by default if no attn_states passed
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
sfa_cp_context: Optional[SfaCpContext] = None
|
dsa_cp_context: Optional[DSACPContext] = None
|
||||||
reshape_cache_event: torch.npu.Event = None
|
reshape_cache_event: torch.npu.Event = None
|
||||||
|
|
||||||
|
|
||||||
@@ -159,15 +160,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||||
got {self.decode_threshold}"
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
|
||||||
self.enable_sfa_cp = enable_dsa_cp()
|
|
||||||
|
|
||||||
assert not (
|
|
||||||
self.enable_sfa_cp
|
|
||||||
and self.vllm_config.compilation_config.cudagraph_mode
|
|
||||||
== CUDAGraphMode.FULL_DECODE_ONLY
|
|
||||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
|
||||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||||
|
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
|
self.enable_dsa_cp = enable_dsa_cp()
|
||||||
|
|
||||||
|
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||||
|
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
self.actual_seq_lengths_key = torch.empty_like(
|
||||||
|
self.actual_seq_lengths_query)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def determine_chunked_prefill_workspace_size(
|
def determine_chunked_prefill_workspace_size(
|
||||||
@@ -210,8 +212,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
|
|
||||||
cos, sin = get_cos_and_sin_mla(input_positions, True)
|
cos, sin = get_cos_and_sin_mla(input_positions, True)
|
||||||
|
|
||||||
sfa_cp_context = None
|
dsa_cp_context = None
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
global_tp_size = get_tp_group().world_size
|
global_tp_size = get_tp_group().world_size
|
||||||
num_tokens = num_input_tokens
|
num_tokens = num_input_tokens
|
||||||
num_tokens_pad = _round_up(num_tokens, global_tp_size)
|
num_tokens_pad = _round_up(num_tokens, global_tp_size)
|
||||||
@@ -235,13 +237,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
value=-1)
|
value=-1)
|
||||||
else:
|
else:
|
||||||
slot_mapping = slot_mapping[:num_tokens_pad]
|
slot_mapping = slot_mapping[:num_tokens_pad]
|
||||||
|
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
||||||
|
|
||||||
cos = cos[local_start:local_end_with_pad]
|
cos = cos[local_start:local_end_with_pad]
|
||||||
sin = sin[local_start:local_end_with_pad]
|
sin = sin[local_start:local_end_with_pad]
|
||||||
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
|
|
||||||
fill_value=-1,
|
|
||||||
dtype=slot_mapping.dtype,
|
|
||||||
device=slot_mapping.device)
|
|
||||||
assert cos.shape[0] == num_tokens_per_device, \
|
assert cos.shape[0] == num_tokens_per_device, \
|
||||||
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
||||||
got {cos.shape[0]} and {num_tokens_per_device}"
|
got {cos.shape[0]} and {num_tokens_per_device}"
|
||||||
@@ -252,8 +252,9 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
|
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
|
||||||
got {slot_mapping.shape[0]} and {num_tokens_pad}"
|
got {slot_mapping.shape[0]} and {num_tokens_pad}"
|
||||||
|
|
||||||
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
|
actual_seq_lengths_query = self.actual_seq_lengths_query
|
||||||
actual_seq_lengths_key = torch.empty_like(seq_lens)
|
actual_seq_lengths_key = self.actual_seq_lengths_key
|
||||||
|
|
||||||
num_segs = cum_query_lens.shape[0]
|
num_segs = cum_query_lens.shape[0]
|
||||||
last_token = 0
|
last_token = 0
|
||||||
cum = 0
|
cum = 0
|
||||||
@@ -262,21 +263,24 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
global_end = cum_query_lens[i].item()
|
global_end = cum_query_lens[i].item()
|
||||||
last_token = global_end
|
last_token = global_end
|
||||||
|
|
||||||
local_start = max(global_start, local_start)
|
req_local_start = max(global_start, local_start)
|
||||||
local_end = min(global_end, local_end_with_pad)
|
req_local_end = min(global_end, local_end_with_pad)
|
||||||
num_local_tokens = local_end - local_start
|
num_local_tokens = req_local_end - req_local_start
|
||||||
|
|
||||||
if num_local_tokens > 0:
|
if num_local_tokens > 0:
|
||||||
cum += num_local_tokens
|
cum += num_local_tokens
|
||||||
actual_seq_lengths_query[i] = cum
|
actual_seq_lengths_query[i] = cum
|
||||||
|
|
||||||
offset = global_end - local_end
|
offset = global_end - req_local_end
|
||||||
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
|
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
|
||||||
else:
|
else:
|
||||||
actual_seq_lengths_query[i] = cum
|
actual_seq_lengths_query[i] = cum
|
||||||
actual_seq_lengths_key[i] = 0
|
actual_seq_lengths_key[i] = 0
|
||||||
|
|
||||||
sfa_cp_context = SfaCpContext(
|
actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
|
||||||
|
actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
|
||||||
|
|
||||||
|
dsa_cp_context = DSACPContext(
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
num_tokens_pad=num_tokens_pad,
|
num_tokens_pad=num_tokens_pad,
|
||||||
local_start=local_start,
|
local_start=local_start,
|
||||||
@@ -300,7 +304,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
sin=sin[:num_input_tokens],
|
sin=sin[:num_input_tokens],
|
||||||
cos=cos[:num_input_tokens],
|
cos=cos[:num_input_tokens],
|
||||||
sfa_cp_context=sfa_cp_context)
|
dsa_cp_context=dsa_cp_context)
|
||||||
|
|
||||||
def build_for_graph_capture(
|
def build_for_graph_capture(
|
||||||
self,
|
self,
|
||||||
@@ -329,6 +333,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
"""
|
"""
|
||||||
|
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
|
||||||
|
o_proj_full_pool: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -382,22 +388,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
assert self.indexer is not None, "Indexer is required for DSA."
|
assert self.indexer is not None, "Indexer is required for DSA."
|
||||||
|
|
||||||
self.enable_sfa_cp = enable_dsa_cp()
|
|
||||||
self.local_num_heads = self.num_heads
|
self.local_num_heads = self.num_heads
|
||||||
self.vllm_config = get_current_vllm_config()
|
self.vllm_config = get_current_vllm_config()
|
||||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
if self.enable_sfa_cp:
|
|
||||||
self.local_num_heads = self.num_heads * self.tp_size
|
|
||||||
self.layer_sharding_kwargs = []
|
|
||||||
for layer_name in (get_ascend_config().layer_sharding or []):
|
|
||||||
if layer_name in kwargs:
|
|
||||||
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
|
||||||
else:
|
|
||||||
logger.warning_once(
|
|
||||||
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
|
||||||
)
|
|
||||||
register_all_layers_to_shard_weight_series(
|
|
||||||
self.layer_sharding_kwargs)
|
|
||||||
|
|
||||||
# indexer param
|
# indexer param
|
||||||
self.n_head: int = self.indexer.n_head # 64
|
self.n_head: int = self.indexer.n_head # 64
|
||||||
@@ -406,9 +399,24 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.wk = self.indexer.wk
|
self.wk = self.indexer.wk
|
||||||
self.weights_proj = self.indexer.weights_proj
|
self.weights_proj = self.indexer.weights_proj
|
||||||
self.k_norm = self.indexer.k_norm
|
self.k_norm = self.indexer.k_norm
|
||||||
|
|
||||||
self.cp_size = 1
|
self.cp_size = 1
|
||||||
|
|
||||||
|
self.enable_dsa_cp = enable_dsa_cp()
|
||||||
|
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
|
||||||
|
if self.enable_dsa_cp:
|
||||||
|
self.local_num_heads = self.num_heads * self.tp_size
|
||||||
|
if self.enable_dsa_cp_prefill_only:
|
||||||
|
self.layer_sharding_kwargs = []
|
||||||
|
for layer_name in (get_ascend_config().layer_sharding or []):
|
||||||
|
if layer_name in kwargs:
|
||||||
|
self.layer_sharding_kwargs.append(kwargs[layer_name])
|
||||||
|
else:
|
||||||
|
logger.warning_once(
|
||||||
|
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
|
||||||
|
)
|
||||||
|
register_all_layers_to_shard_weight_series(
|
||||||
|
self.layer_sharding_kwargs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
# NOTE: We currently do not support quant kv_b_proj.
|
# NOTE: We currently do not support quant kv_b_proj.
|
||||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||||
@@ -442,10 +450,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
|
||||||
dispose_layer(self.kv_b_proj)
|
dispose_layer(self.kv_b_proj)
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
for layer in (self.layer_sharding_kwargs or []):
|
if self.enable_dsa_cp_prefill_only:
|
||||||
if is_hidden_layer(layer):
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
post_process_after_loading_for_shard_weight_series(layer)
|
if is_hidden_layer(layer):
|
||||||
|
post_process_after_loading_for_shard_weight_series(
|
||||||
|
layer)
|
||||||
|
else:
|
||||||
|
self._init_o_proj_tp_full_params()
|
||||||
|
|
||||||
if self.enable_mlapo:
|
if self.enable_mlapo:
|
||||||
quant_method = getattr(
|
quant_method = getattr(
|
||||||
@@ -460,7 +472,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
"Currently mlapo only supports W8A8 quantization in SFA scenario."
|
"Currently mlapo only supports W8A8 quantization in SFA scenario."
|
||||||
"Some layers in your model are not quantized with W8A8,"
|
"Some layers in your model are not quantized with W8A8,"
|
||||||
"thus mlapo is disabled for these layers.")
|
"thus mlapo is disabled for these layers.")
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
reasons.append("Currently mlapo does not support SFA with CP,"
|
reasons.append("Currently mlapo does not support SFA with CP,"
|
||||||
"thus mlapo is disabled for these layers.")
|
"thus mlapo is disabled for these layers.")
|
||||||
if reasons:
|
if reasons:
|
||||||
@@ -525,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
cache_mode = "PA"
|
cache_mode = "PA"
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
kv_no_split,
|
kv_no_split,
|
||||||
self.kv_a_layernorm.weight,
|
self.kv_a_layernorm.weight,
|
||||||
@@ -738,7 +750,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
if self.enable_sfa_cp and not forward_context.in_profile_run:
|
if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
|
||||||
for layer in (self.layer_sharding_kwargs or []):
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
if is_hidden_layer(layer):
|
if is_hidden_layer(layer):
|
||||||
reach_layer_for_shard_weight_series(layer)
|
reach_layer_for_shard_weight_series(layer)
|
||||||
@@ -748,12 +760,20 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
sin = attn_metadata.sin
|
sin = attn_metadata.sin
|
||||||
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
actual_seq_lengths_query = attn_metadata.cum_query_lens
|
||||||
actual_seq_lengths_key = attn_metadata.seq_lens
|
actual_seq_lengths_key = attn_metadata.seq_lens
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
need_gather_q_kv = False
|
need_gather_q_kv = False
|
||||||
# Inputs and outputs may be padded for CUDA graphs
|
# Inputs and outputs may be padded for CUDA graphs
|
||||||
num_input_tokens = attn_metadata.num_input_tokens
|
num_input_tokens = attn_metadata.num_input_tokens
|
||||||
output_padded = output
|
output_padded = output
|
||||||
|
|
||||||
|
# all-gather o_proj weight for prefill stage of PD mix node
|
||||||
|
o_proj_full_handle = None
|
||||||
|
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage.
|
||||||
|
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
|
||||||
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||||
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
|
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -796,16 +816,16 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
wait_for_kv_layer_from_connector(layer_name)
|
wait_for_kv_layer_from_connector(layer_name)
|
||||||
|
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
assert attn_metadata.sfa_cp_context is not None
|
assert attn_metadata.dsa_cp_context is not None
|
||||||
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
|
slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp
|
||||||
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
|
||||||
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
|
||||||
|
|
||||||
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
|
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
|
||||||
slot_mapping)
|
slot_mapping)
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
assert k_pe is not None
|
assert k_pe is not None
|
||||||
assert k_nope is not None
|
assert k_nope is not None
|
||||||
# support all_gather kv async for communication calculation overlap
|
# support all_gather kv async for communication calculation overlap
|
||||||
@@ -815,17 +835,26 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
k_nope.view(-1, k_nope.shape[-1]),
|
k_nope.view(-1, k_nope.shape[-1]),
|
||||||
k.view(-1, k.shape[-1])
|
k.view(-1, k.shape[-1])
|
||||||
],
|
],
|
||||||
dim=1), get_tp_group())
|
dim=1),
|
||||||
|
get_tp_group(),
|
||||||
|
async_op=should_shard_weight)
|
||||||
|
|
||||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||||
q_pe = self.rope_single(q_pe, cos, sin)
|
q_pe = self.rope_single(q_pe, cos, sin)
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_dsa_cp:
|
||||||
if kv_ag_handle is not None:
|
if kv_ag_handle is not None:
|
||||||
kv_ag_handle.wait()
|
kv_ag_handle.wait()
|
||||||
for layer in (self.layer_sharding_kwargs or []):
|
|
||||||
if is_hidden_layer(layer):
|
if self.enable_dsa_cp_prefill_only:
|
||||||
reach_layer_for_shard_weight_series(layer)
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
|
if is_hidden_layer(layer):
|
||||||
|
reach_layer_for_shard_weight_series(layer)
|
||||||
|
elif should_shard_weight:
|
||||||
|
_, o_proj_full_handle = all_gather_async(
|
||||||
|
self.o_proj_tp_weight,
|
||||||
|
get_tp_group(),
|
||||||
|
output=AscendSFAImpl.o_proj_full_pool)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
assert fused_kv_no_split is not None
|
assert fused_kv_no_split is not None
|
||||||
@@ -841,6 +870,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
|
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
|
||||||
k_pe)
|
k_pe)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
torch_npu.npu_scatter_nd_update_(
|
||||||
|
kv_cache[2].view(-1, k.shape[-1]),
|
||||||
|
attn_metadata.slot_mapping.view(-1, 1),
|
||||||
|
k.view(-1, k.shape[-1])) # b, s, n, d
|
||||||
|
|
||||||
topk_indices = self.indexer_select_post_process(
|
topk_indices = self.indexer_select_post_process(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
qr=q_c,
|
qr=q_c,
|
||||||
@@ -876,6 +911,20 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
dependency=attn_output,
|
dependency=attn_output,
|
||||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||||
enabled=self.enable_prefetch)
|
enabled=self.enable_prefetch)
|
||||||
|
|
||||||
|
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
|
||||||
|
# When using SFA-CP with pd mixed, o_proj has two cases:
|
||||||
|
# 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1.
|
||||||
|
# 2. decode: all-to-all the hidden_state before the o_proj forward.
|
||||||
|
result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
|
||||||
|
attn_output=attn_output,
|
||||||
|
output=output,
|
||||||
|
o_proj_full_handle=o_proj_full_handle,
|
||||||
|
should_shard_weight=should_shard_weight)
|
||||||
|
if not require_o_proj_forward:
|
||||||
|
return result
|
||||||
|
attn_output = result
|
||||||
|
|
||||||
output[...] = self.o_proj(attn_output)[0]
|
output[...] = self.o_proj(attn_output)[0]
|
||||||
|
|
||||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||||
@@ -912,7 +961,10 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
k_pe, k_nope = torch.split(
|
k_pe, k_nope = torch.split(
|
||||||
k,
|
k,
|
||||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||||
dim=-1) # [b,s,64+64]
|
dim=-1)
|
||||||
|
|
||||||
|
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
|
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
|
|
||||||
k_pe = k_pe.unsqueeze(2)
|
k_pe = k_pe.unsqueeze(2)
|
||||||
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||||
@@ -940,10 +992,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if q is None:
|
if q is None:
|
||||||
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||||
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||||
|
|
||||||
cos_q, sin_q = cos, sin
|
cos_q, sin_q = cos, sin
|
||||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
|
||||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
|
||||||
|
|
||||||
q_pe, q_nope = torch.split(
|
q_pe, q_nope = torch.split(
|
||||||
q,
|
q,
|
||||||
@@ -984,3 +1033,92 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
sparse_count=2048,
|
sparse_count=2048,
|
||||||
sparse_mode=3)
|
sparse_mode=3)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
def _init_o_proj_tp_full_params(self):
|
||||||
|
"""
|
||||||
|
Initialize TP-mode and Full-mode parameters for o_proj weight,
|
||||||
|
preparing for weight switching in PD mix stage.
|
||||||
|
|
||||||
|
For PD mix stage:
|
||||||
|
- Use original TP o_proj weight for decode phase
|
||||||
|
- Need full-gather o_proj weight from all TP ranks for prefill phase
|
||||||
|
"""
|
||||||
|
if AscendSFAImpl.o_proj_full_pool is None:
|
||||||
|
sample = self.o_proj.weight
|
||||||
|
AscendSFAImpl.o_proj_full_pool = torch.empty(
|
||||||
|
(sample.shape[0] * self.tp_size, sample.shape[1]),
|
||||||
|
dtype=sample.dtype,
|
||||||
|
device=sample.device)
|
||||||
|
|
||||||
|
# Save TP-mode parameters (original sharded weights)
|
||||||
|
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
|
||||||
|
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
|
||||||
|
).detach()
|
||||||
|
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
|
||||||
|
).detach()
|
||||||
|
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
# Initially switch to TP mode for graph capture
|
||||||
|
self.o_proj.weight.set_(self.o_proj_tp_weight)
|
||||||
|
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
|
||||||
|
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||||
|
self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||||
|
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
|
||||||
|
|
||||||
|
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
|
||||||
|
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
|
||||||
|
self.tp_size)
|
||||||
|
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
|
||||||
|
self.tp_size)
|
||||||
|
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
|
||||||
|
self.tp_size)
|
||||||
|
|
||||||
|
def _handle_o_proj_weight_switch_and_forward(
|
||||||
|
self, attn_output: torch.Tensor, output: torch.Tensor,
|
||||||
|
o_proj_full_handle: Optional[torch.distributed.Work],
|
||||||
|
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
|
||||||
|
"""
|
||||||
|
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
|
||||||
|
"""
|
||||||
|
# Gather o_proj weight from all TP ranks for Full-mode computation
|
||||||
|
if should_shard_weight:
|
||||||
|
# Wait for the completion of o_proj weight all-gather operation
|
||||||
|
if o_proj_full_handle is not None:
|
||||||
|
o_proj_full_handle.wait()
|
||||||
|
|
||||||
|
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
|
||||||
|
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
|
||||||
|
self.o_proj.aclnn_input_scale.set_(
|
||||||
|
self.o_proj_full_aclnn_input_scale)
|
||||||
|
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||||
|
self.o_proj_full_aclnn_input_scale_reciprocal)
|
||||||
|
self.o_proj.aclnn_input_offset.set_(
|
||||||
|
self.o_proj_full_aclnn_input_offset)
|
||||||
|
|
||||||
|
# Apply quantization method and execute forward computation
|
||||||
|
output[...] = self.o_proj.quant_method.quant_method.apply(
|
||||||
|
self.o_proj, attn_output)
|
||||||
|
|
||||||
|
# Switch o_proj back to TP-mode for subsequent decode operations
|
||||||
|
self.o_proj.weight.set_(self.o_proj_tp_weight)
|
||||||
|
self.o_proj.aclnn_input_scale.set_(
|
||||||
|
self.o_proj_tp_aclnn_input_scale)
|
||||||
|
self.o_proj.aclnn_input_scale_reciprocal.set_(
|
||||||
|
self.o_proj_tp_aclnn_input_scale_reciprocal)
|
||||||
|
self.o_proj.aclnn_input_offset.set_(
|
||||||
|
self.o_proj_tp_aclnn_input_offset)
|
||||||
|
|
||||||
|
return output, False
|
||||||
|
else:
|
||||||
|
# For decode scenario: perform all-to-all communication on o_proj input activations
|
||||||
|
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
|
||||||
|
send = attn_output.view(-1, self.tp_size, self.num_heads *
|
||||||
|
self.v_head_dim).permute(1, 0, 2).reshape(
|
||||||
|
-1, self.num_heads * self.v_head_dim)
|
||||||
|
|
||||||
|
attn_output = torch.empty_like(send)
|
||||||
|
torch.distributed.all_to_all_single(
|
||||||
|
attn_output, send, group=get_tp_group().device_group)
|
||||||
|
|
||||||
|
return attn_output, True
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
|
|||||||
init_model_parallel_group)
|
init_model_parallel_group)
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
|
from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable
|
||||||
|
|
||||||
# Currently, mc2 op need their own group coordinator.
|
# Currently, mc2 op need their own group coordinator.
|
||||||
_MC2: Optional[GroupCoordinator] = None
|
_MC2: Optional[GroupCoordinator] = None
|
||||||
@@ -238,7 +238,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
FC2_group_ranks = torch.tensor(
|
FC2_group_ranks = torch.tensor(
|
||||||
flashcomm2_otp_group_ranks).squeeze(0)
|
flashcomm2_otp_group_ranks).squeeze(0)
|
||||||
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
|
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
|
||||||
elif enable_dsa_cp():
|
elif enable_dsa_cp_with_layer_shard():
|
||||||
# For dsa_cp, all shard layers are replicated.
|
# For dsa_cp, all shard layers are replicated.
|
||||||
_SHARD_WEIGHT = create_shard_weight_group(None)
|
_SHARD_WEIGHT = create_shard_weight_group(None)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -52,4 +52,4 @@ def all_gather_async(input: torch.Tensor,
|
|||||||
return output, dist.all_gather_into_tensor(output,
|
return output, dist.all_gather_into_tensor(output,
|
||||||
input,
|
input,
|
||||||
group=group.device_group,
|
group=group.device_group,
|
||||||
async_op=async_op)
|
async_op=async_op)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
|
|||||||
get_mlp_tp_group,
|
get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||||
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
|
from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
|
||||||
get_flashcomm2_reorgnized_batch_ids,
|
get_flashcomm2_reorgnized_batch_ids,
|
||||||
matmul_allreduce_enable, mlp_tp_enable,
|
matmul_allreduce_enable, mlp_tp_enable,
|
||||||
oproj_tp_enable, shared_expert_dp_enabled)
|
oproj_tp_enable, shared_expert_dp_enabled)
|
||||||
@@ -575,7 +575,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
|||||||
return tensor_model_parallel_all_reduce(output_parallel)
|
return tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
|
||||||
pad_size = forward_context.pad_size
|
pad_size = forward_context.pad_size
|
||||||
if pad_size > 0:
|
if pad_size > 0 and not (enable_dsa_cp()
|
||||||
|
and "o_proj" in self.layer.prefix):
|
||||||
x = F.pad(x, (0, 0, 0, pad_size))
|
x = F.pad(x, (0, 0, 0, pad_size))
|
||||||
|
|
||||||
world_size = self.layer.tp_size
|
world_size = self.layer.tp_size
|
||||||
@@ -728,7 +729,7 @@ def _get_row_parallel_op(
|
|||||||
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
|
||||||
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
|
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
|
||||||
SequenceRowParallelOp, ShardedCPRowParallelOp]]:
|
SequenceRowParallelOp, ShardedCPRowParallelOp]]:
|
||||||
if enable_dsa_cp() and "o_proj" in prefix:
|
if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
|
||||||
return ShardedCPRowParallelOp(layer)
|
return ShardedCPRowParallelOp(layer)
|
||||||
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
|
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
|
||||||
return MLPRowParallelOp(layer)
|
return MLPRowParallelOp(layer)
|
||||||
|
|||||||
Reference in New Issue
Block a user