@@ -6,6 +6,7 @@ import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
@@ -34,7 +35,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_dsa_cp, maybe_trans_nz, vllm_version_is)
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
# isort: off
@@ -79,7 +80,7 @@ class AscendSFABackend(AttentionBackend):
@dataclass
class SfaCpContext:
class DSACPContext:
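# Per-step context-parallel metadata: the global token counts (unpadded and padded) plus this rank's slice of the token range.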
num_tokens: int
num_tokens_pad: int
local_start: int
@@ -119,7 +120,7 @@ class AscendSFAMetadata:
attn_mask: torch.Tensor = None
# chunked prefill by default if no attn_state is passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
sfa_cp_context: Optional[SfaCpContext] = None
dsa_cp_context: Optional[DSACPContext] = None
reshape_cache_event: torch.npu.Event = None
@@ -159,15 +160,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_sfa_cp = enable_dsa_cp()
assert not (
self.enable_sfa_cp
and self.vllm_config.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_dsa_cp = enable_dsa_cp()
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.actual_seq_lengths_key = torch.empty_like(
self.actual_seq_lengths_query)
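# Preallocated per-request buffers, refilled in place each scheduling step when DSA CP is enabled.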
@staticmethod
def determine_chunked_prefill_workspace_size(
@@ -210,8 +212,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cos, sin = get_cos_and_sin_mla(input_positions, True)
sfa_cp_context = None
if self.enable_sfa_cp:
dsa_cp_context = None
if self.enable_dsa_cp:
global_tp_size = get_tp_group().world_size
num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size)
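# Pad the global token count to a multiple of the TP world size so every rank gets an equally sized token shard.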
@@ -235,13 +237,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
fill_value=-1,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
@@ -252,8 +252,9 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}"
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
actual_seq_lengths_key = torch.empty_like(seq_lens)
actual_seq_lengths_query = self.actual_seq_lengths_query
actual_seq_lengths_key = self.actual_seq_lengths_key
num_segs = cum_query_lens.shape[0]
last_token = 0
cum = 0
@@ -262,21 +263,24 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
global_end = cum_query_lens[i].item()
last_token = global_end
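# Intersect this request's global token range with the rank-local window [local_start, local_end_with_pad) to get per-rank query and key lengths.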
local_start = max(global_start, local_start)
local_end = min(global_end, local_end_with_pad)
num_local_tokens = local_end - local_start
req_local_start = max(global_start, local_start)
req_local_end = min(global_end, local_end_with_pad)
num_local_tokens = req_local_end - req_local_start
if num_local_tokens > 0:
cum += num_local_tokens
actual_seq_lengths_query[i] = cum
offset = global_end - local_end
offset = global_end - req_local_end
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
else:
actual_seq_lengths_query[i] = cum
actual_seq_lengths_key[i] = 0
sfa_cp_context = SfaCpContext(
actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
dsa_cp_context = DSACPContext(
num_tokens=num_tokens,
num_tokens_pad=num_tokens_pad,
local_start=local_start,
@@ -300,7 +304,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
block_tables=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
sfa_cp_context=sfa_cp_context)
dsa_cp_context=dsa_cp_context)
def build_for_graph_capture(
self,
@@ -329,6 +333,8 @@ class AscendSFAImpl(MLAAttentionImpl):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# Supports running o_proj with the all-gathered weight for prefill requests when sharded CP is enabled.
o_proj_full_pool: Optional[torch.Tensor] = None
def __init__(
self,
@@ -382,22 +388,9 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.indexer is not None, "Indexer is required for DSA."
self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -406,9 +399,24 @@ class AscendSFAImpl(MLAAttentionImpl):
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.cp_size = 1
self.enable_dsa_cp = enable_dsa_cp()
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
if self.enable_dsa_cp:
self.local_num_heads = self.num_heads * self.tp_size
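# Under DSA CP tokens, rather than attention heads, are sharded across ranks, so each rank runs with the full head count (hence the tp_size multiplier).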
if self.enable_dsa_cp_prefill_only:
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -442,10 +450,14 @@ class AscendSFAImpl(MLAAttentionImpl):
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
if self.enable_dsa_cp:
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(
layer)
else:
self._init_o_proj_tp_full_params()
if self.enable_mlapo:
quant_method = getattr(
@@ -460,7 +472,7 @@ class AscendSFAImpl(MLAAttentionImpl):
"Currently mlapo only supports W8A8 quantization in SFA scenario."
|
|
|
|
|
"Some layers in your model are not quantized with W8A8,"
|
|
|
|
|
"thus mlapo is disabled for these layers.")
if self.enable_sfa_cp:
if self.enable_dsa_cp:
reasons.append("Currently mlapo does not support SFA with CP,"
|
|
|
|
|
"thus mlapo is disabled for these layers.")
if reasons:
@@ -525,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
if self.enable_sfa_cp:
if self.enable_dsa_cp:
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
@@ -738,7 +750,7 @@ class AscendSFAImpl(MLAAttentionImpl):
forward_context = get_forward_context()
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
@@ -748,12 +760,20 @@ class AscendSFAImpl(MLAAttentionImpl):
sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_sfa_cp:
if self.enable_dsa_cp:
need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs
num_input_tokens = attn_metadata.num_input_tokens
output_padded = output
# All-gather the o_proj weight for the prefill stage on a PD-mixed node.
o_proj_full_handle = None
# In the PD-mixed stage, decode keeps the original TP o_proj weight, while prefill additionally needs the o_proj weight all-gathered across TP ranks.
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
}
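# Prefill (and the prefill-only layer-shard mode) takes the weight-sharding / all-gather path; pure decode and spec decode keep the TP-sharded o_proj weight.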
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states,
@@ -796,16 +816,16 @@ class AscendSFAImpl(MLAAttentionImpl):
wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
if self.enable_dsa_cp:
assert attn_metadata.dsa_cp_context is not None
slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping)
if self.enable_sfa_cp:
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
# All-gather the KV asynchronously so the communication can overlap with the following computation.
@@ -815,17 +835,26 @@ class AscendSFAImpl(MLAAttentionImpl):
k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1])
],
dim=1), get_tp_group())
dim=1),
get_tp_group(),
async_op=should_shard_weight)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
if self.enable_sfa_cp:
if self.enable_dsa_cp:
if kv_ag_handle is not None:
kv_ag_handle.wait()
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
elif should_shard_weight:
_, o_proj_full_handle = all_gather_async(
self.o_proj_tp_weight,
get_tp_group(),
output=AscendSFAImpl.o_proj_full_pool)
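# The weight all-gather is asynchronous; the returned handle is awaited later in _handle_o_proj_weight_switch_and_forward before switching to the full weight.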
if kv_cache is not None:
assert fused_kv_no_split is not None
@@ -841,6 +870,12 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(-1, 1),
k.view(-1, k.shape[-1])) # b, s, n, d
topk_indices = self.indexer_select_post_process(
x=hidden_states,
qr=q_c,
@@ -876,6 +911,20 @@ class AscendSFAImpl(MLAAttentionImpl):
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
# When SFA-CP is used in a PD-mixed deployment, o_proj has two cases:
# 1. prefill: o_proj holds a TP-sharded weight, so the weight is all-gathered to effectively run with TP=1.
# 2. decode: the hidden states are exchanged via all-to-all before the o_proj forward.
result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
attn_output=attn_output,
output=output,
o_proj_full_handle=o_proj_full_handle,
should_shard_weight=should_shard_weight)
if not require_o_proj_forward:
return result
attn_output = result
output[...] = self.o_proj(attn_output)[0]
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
@@ -912,7 +961,10 @@ class AscendSFAImpl(MLAAttentionImpl):
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
dim=-1)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
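# Rotary embedding is applied only to the k_pe (rope) part of the key; k_nope is left unchanged.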
@@ -940,10 +992,7 @@ class AscendSFAImpl(MLAAttentionImpl):
if q is None:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe, q_nope = torch.split(
q,
@@ -984,3 +1033,92 @@ class AscendSFAImpl(MLAAttentionImpl):
sparse_count=2048,
sparse_mode=3)
return topk_indices
def _init_o_proj_tp_full_params(self):
"""
Initialize TP-mode and Full-mode parameters for the o_proj weight,
preparing for weight switching in the PD-mixed stage.

For the PD-mixed stage:
- Use the original TP o_proj weight for the decode phase
- All-gather the o_proj weight from all TP ranks for the prefill phase
"""
if AscendSFAImpl.o_proj_full_pool is None:
sample = self.o_proj.weight
AscendSFAImpl.o_proj_full_pool = torch.empty(
(sample.shape[0] * self.tp_size, sample.shape[1]),
dtype=sample.dtype,
device=sample.device)
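# Class-level buffer shared across layers, sized to hold the o_proj weight gathered from all TP ranks.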
# Save TP-mode parameters (original sharded weights)
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
).detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
).detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
).detach()
# Initially switch to TP mode for graph capture
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
self.tp_size)
def _handle_o_proj_weight_switch_and_forward(
self, attn_output: torch.Tensor, output: torch.Tensor,
o_proj_full_handle: Optional[torch.distributed.Work],
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
"""
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
"""
# Gather o_proj weight from all TP ranks for Full-mode computation
if should_shard_weight:
# Wait for the completion of o_proj weight all-gather operation
if o_proj_full_handle is not None:
o_proj_full_handle.wait()
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_full_aclnn_input_offset)
# Apply quantization method and execute forward computation
output[...] = self.o_proj.quant_method.quant_method.apply(
self.o_proj, attn_output)
# Switch o_proj back to TP-mode for subsequent decode operations
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_tp_aclnn_input_offset)
return output, False
else:
# For the decode scenario: perform all-to-all communication on the o_proj input activations.
# Reshape for all-to-all: [num_tokens, tp_size, num_heads * v_head_dim] -> [tp_size * num_tokens, num_heads * v_head_dim]
send = attn_output.view(-1, self.tp_size, self.num_heads *
self.v_head_dim).permute(1, 0, 2).reshape(
-1, self.num_heads * self.v_head_dim)
attn_output = torch.empty_like(send)
torch.distributed.all_to_all_single(
attn_output, send, group=get_tp_group().device_group)
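# After the all-to-all each rank holds its own TP shard of heads for every token, matching the layout expected by the TP-sharded o_proj.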
return attn_output, True