xc-llm-ascend/vllm_ascend/attention/sfa_v1.py
rjg-lyh 7ed9e9de69 [Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR adds W8A8C8 support for dsv3.2/glm5, mainly via the
lightning_indexer_quant ops in the PD-mix stage.

Because the code for the PD-disaggregated scenario is still being
refactored and cleaned up, this PR prioritizes C8 functionality in the
PD-mix scenario.
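
For reference, the C8 path quantizes the lightning-indexer keys (and queries) with a Hadamard rotation followed by per-token dynamic int8 quantization, with the scales kept in fp16 (see `indexer_select_pre_process` below). Here is a minimal sketch in plain torch standing in for the NPU kernel `torch_npu.npu_dynamic_quant`; `hadamard_matrix` and `c8_quantize` are illustrative names, not APIs added by this PR:

```python
import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction; n must be a power of two (128 for the indexer head_dim).
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

def c8_quantize(k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Rotate with the orthonormal (scaled) Hadamard matrix to flatten outliers,
    # then quantize each token row to int8 with its own fp16 scale.
    d = k.shape[-1]
    h = hadamard_matrix(d).to(k.dtype) / d**0.5
    k_rot = k @ h
    scale = (k_rot.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-6)
    k_int8 = torch.round(k_rot / scale).clamp(-128, 127).to(torch.int8)
    return k_int8, scale.to(torch.float16)  # scales are cached separately as fp16

k = torch.randn(4, 128, dtype=torch.bfloat16)
k_int8, k_scale = c8_quantize(k)
assert k_int8.dtype == torch.int8 and k_scale.shape == (4, 1)
```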

The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
2026-03-13 14:47:42 +08:00

from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar
import scipy.linalg # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backend import (
AttentionBackend, # type: ignore
AttentionCGSupport,
MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
enable_cp,
maybe_save_kv_layer_to_connector,
trans_rope_weight,
transdata,
wait_for_kv_layer_from_connector,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer,
post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
_round_up,
dispose_layer,
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
enable_dsa_cp_with_o_proj_tp,
get_weight_prefetch_method,
maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
# Token count limit for the bmm_transpose operator
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024
class AscendSFABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
# HACK(Ronald1995): the vllm `initialize_kv_cache` method in model runner v2 asserts
# on the attention backend name; we set the name to FLASH_ATTN to avoid the assertion
# error. Rectify this once vllm drops the assertion.
return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_builder_cls():
if enable_cp():
from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPMetadataBuilder
return AscendSFACPMetadataBuilder
return AscendSFAMetadataBuilder
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_impl_cls() -> type["AscendSFAImpl"]:
if enable_cp():
from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPImpl
return AscendSFACPImpl
return AscendSFAImpl
@staticmethod
def get_supported_kernel_block_sizes() -> list[int]:
return [128]
@dataclass
class DSACPContext:
num_tokens: int
num_tokens_pad: int
local_start: int
local_end: int
local_end_with_pad: int
slot_mapping_cp: torch.Tensor
actual_seq_lengths_query: torch.Tensor
actual_seq_lengths_key: torch.Tensor
@dataclass
class AscendSFAMetadata:
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
seq_lens: torch.Tensor
cum_query_lens: torch.Tensor
block_table: torch.Tensor
sin: torch.Tensor
cos: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads
head_dim: int | None = None
attn_mask: torch.Tensor | None = None
# chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
dsa_cp_context: DSACPContext | None = None
reshape_cache_event: torch.npu.Event | None = None
sfa_cp_metadata: AscendPCPMetadata | None = None
num_decodes: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
kv_cache_spec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[AscendSFAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
super().__init__(
kv_cache_spec,
layer_names,
vllm_config,
device,
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
supports_dcp_with_varlen,
)
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, (
f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
)
self.reorder_batch_threshold = self.decode_threshold
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_dsa_cp = enable_dsa_cp()
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of a mypy limitation with the type variable.
return AttentionCGSupport.UNIFORM_BATCH
def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA
return False
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
fast_build: bool = False,
) -> AscendSFAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_input_tokens = common_attn_metadata.num_input_tokens
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
input_positions = common_attn_metadata.positions[:num_input_tokens].long()
cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos, sin = get_cos_and_sin_mla(input_positions, True)
dsa_cp_context = None
if self.enable_dsa_cp:
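# Under DSA-CP, tokens are padded up to a multiple of the TP world size and
# split evenly across ranks; this rank processes the slice
# [local_start, local_end_with_pad), with local_end clipped to the number of
# real (unpadded) tokens.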
global_tp_size = get_tp_group().world_size
num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size)
num_tokens_per_device = num_tokens_pad // global_tp_size
local_start = get_tp_group().rank_in_group * num_tokens_per_device
local_end_with_pad = local_start + num_tokens_per_device
local_end = min(local_end_with_pad, num_actual_tokens)
pad_size = num_tokens_pad - cos.shape[0]
assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
if pad_size > 0:
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
if pad_size_slot > 0:
slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
assert cos.shape[0] == num_tokens_per_device, (
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
)
assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
)
assert slot_mapping.shape[0] == num_tokens_pad, (
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}"
)
actual_seq_lengths_query = self.actual_seq_lengths_query
actual_seq_lengths_key = self.actual_seq_lengths_key
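# For each request, count how many of its query tokens fall into this rank's
# token slice (accumulated into actual_seq_lengths_query) and shrink its key
# length by the number of its query tokens that lie beyond this rank's slice.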
num_segs = cum_query_lens.shape[0]
last_token = 0
cum = 0
for i in range(0, num_segs):
global_start = last_token
global_end = cum_query_lens[i].item()
last_token = global_end
req_local_start = max(global_start, local_start)
req_local_end = min(global_end, local_end_with_pad)
num_local_tokens = req_local_end - req_local_start
if num_local_tokens > 0:
cum += num_local_tokens
actual_seq_lengths_query[i] = cum
offset = global_end - req_local_end
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
else:
actual_seq_lengths_query[i] = cum
actual_seq_lengths_key[i] = 0
actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
dsa_cp_context = DSACPContext(
num_tokens=num_tokens,
num_tokens_pad=num_tokens_pad,
local_start=local_start,
local_end=local_end,
local_end_with_pad=local_end_with_pad,
slot_mapping_cp=slot_mapping_cp,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
)
return self.metadata_cls( # type: ignore
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
cum_query_lens=cum_query_lens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
attn_state=common_attn_metadata.attn_state,
block_table=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
dsa_cp_context=dsa_cp_context,
)
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError("Currently we only support building dummy metadata for the DecodeOnly and SpecDecoding states")
attn_metadata.attn_state = attn_state
return attn_metadata
class AscendSFAImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# Pool for the all-gathered full o_proj weight, used for prefill requests when DSA CP with o_proj TP is enabled.
o_proj_full_pool: torch.Tensor | None = None
# qk_hadamard tensor shared across layers when DSA C8 is enabled
qk_hadamard: torch.Tensor | None = None
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
# MLA Args
self.q_lora_rank = kwargs["q_lora_rank"]
self.kv_lora_rank = kwargs["kv_lora_rank"]
self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
self.qk_head_dim = kwargs["qk_head_dim"]
self.v_head_dim = kwargs["v_head_dim"]
self.rotary_emb = kwargs["rotary_emb"]
self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
self.kv_b_proj = kwargs["kv_b_proj"]
self.o_proj = kwargs["o_proj"]
self.indexer = kwargs["indexer"]
self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
self.q_a_layernorm = kwargs.get("q_a_layernorm")
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
self.q_b_proj = kwargs["q_b_proj"]
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
# The MLAPO operator fuses the pre-processing steps on Q/K/V in MLA into a single operator
# NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
assert self.indexer is not None, "Indexer is required for DSA."
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
# indexer param
self.n_head: int = self.indexer.n_head # 64
self.head_dim: int = self.indexer.head_dim # 128
self.wq_b = self.indexer.wq_b
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.cp_size = 1
self.is_rope_neox_style = True
self.use_torch_npu_lightning_indexer = False
if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
# DSA C8: int8-quantized indexer K cache with per-token fp16 scales
self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
if self.use_sparse_c8_indexer:
self.c8_k_cache_dtype = torch.int8
self.c8_k_scale_cache_dtype = torch.float16
# Effective in SFA when FlashComm is enabled.
self.enable_dsa_cp = enable_dsa_cp()
# Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup.
self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard()
# In the PD-mix stage, keep the original TP o_proj weight for decode and
# all-gather the full o_proj weight for the prefill stage.
self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp()
if self.enable_dsa_cp:
self.local_num_heads = self.num_heads * self.tp_size
if self.enable_dsa_cp_with_layer_shard:
self.layer_sharding_kwargs = []
for layer_name in get_ascend_config().layer_sharding or []:
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, "
"skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.local_num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.local_num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
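# Absorbing W_UK into the query projection lets attention run directly in the
# compressed kv_lora_rank latent space; W_UV is applied after attention in _v_up_proj.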
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_dsa_cp:
if self.enable_dsa_cp_with_layer_shard:
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
else:
self._init_o_proj_tp_full_params()
if self.enable_mlapo:
quant_method = getattr(
getattr(self.fused_qkv_a_proj, "quant_method", None),
"quant_method",
None,
)
reasons = []
if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod):
reasons.append(
"Currently mlapo only supports W8A8 quantization in the SFA scenario. "
"Some layers in your model are not quantized with W8A8, "
"thus mlapo is disabled for these layers."
)
if self.enable_dsa_cp:
reasons.append("Currently mlapo does not support SFA with CP, thus mlapo is disabled for these layers.")
if reasons:
self.enable_mlapo = False
for msg in reasons:
logger.warning_once(msg)
else:
self._process_weights_for_fused_mlapo(act_dtype)
if not self.enable_mlapo:
# When mlapo is enabled, W_UK_T cannot be transposed to NZ format
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
128**0.5
)
# Process the input parameters for MLAPO by reordering and transposing the
# QKV (and Q up-projection) weights, applying RoPE-related dimension
# transformations, and handling quantization parameters.
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
assert self.kv_a_proj_with_mqa is None
assert self.fused_qkv_a_proj is not None
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
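# transdata plus npu_format_cast(_, 29) lay the fused QKV weight out in the
# FRACTAL_NZ format (ACL format 29) expected by the mla_preprocess kernel.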
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous()
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
self.beta1 = self.q_a_layernorm.bias.data # type: ignore[union-attr]
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
# On KV consumers (decode-only) MLAPO uses the transformed weights built above;
# the original fused_qkv_a_proj/q_proj weights and quant params are no longer
# referenced, so drop them to save memory.
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.is_kv_consumer
and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
):
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()
def forward_mha(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
k_scale: torch.Tensor,
output: torch.Tensor,
) -> None:
raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.")
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
layer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.")
def rope_single(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)
def _init_o_proj_tp_full_params(self):
"""
Initialize TP-mode and Full-mode parameters for o_proj weight,
preparing for weight switching in PD mix stage.
For PD mix stage:
- Use original TP o_proj weight for decode phase
- Need full-gather o_proj weight from all TP ranks for prefill phase
"""
if AscendSFAImpl.o_proj_full_pool is None:
sample = self.o_proj.weight
AscendSFAImpl.o_proj_full_pool = torch.empty(
(sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device
)
# Save TP-mode parameters (original sharded weights)
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach()
# Initially switch to TP mode for graph capture
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size)
def _handle_o_proj_weight_switch_and_forward(
self,
attn_output: torch.Tensor,
output: torch.Tensor,
o_proj_full_handle: torch.distributed.Work | None,
should_shard_weight: bool,
) -> tuple[torch.Tensor, bool]:
"""
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
"""
# Gather o_proj weight from all TP ranks for Full-mode computation
if should_shard_weight:
# Wait for the completion of o_proj weight all-gather operation
if o_proj_full_handle is not None:
o_proj_full_handle.wait()
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset)
# Apply quantization method and execute forward computation
output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output)
# Switch o_proj back to TP-mode for subsequent decode operations
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
return output, False
else:
# For decode scenario: perform all-to-all communication on o_proj input activations
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
send = (
attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
.permute(1, 0, 2)
.reshape(-1, self.num_heads * self.v_head_dim)
)
attn_output = torch.empty_like(send)
torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group)
return attn_output, True
def _get_full_kv(self, k, attn_metadata):
return k
def exec_kv(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: tuple,
slots: torch.Tensor,
attn_metadata: M,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
if self.enable_dsa_cp:
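# Each rank only caches its own token slice here, so also request the
# normalized/roped K back (is_output_kv=True) for the later all-gather across TP ranks.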
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
is_output_kv=True,
)
return k_pe, k_nope
else:
torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight, # type: ignore[union-attr]
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr]
cache_mode=cache_mode,
)
return None, None
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = (
self.q_proj(x)[0]
.view(-1, self.local_num_heads, self.qk_head_dim)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def _v_up_proj(self, x):
num_input_tokens, _, _ = x.shape
if (
x.dtype in [torch.float16, torch.bfloat16]
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS
):
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
else:
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
return x
def _sfa_preprocess_with_mlapo(
self,
hidden_states: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
cos: torch.Tensor,
sin: torch.Tensor,
slot_mapping: torch.Tensor,
num_input_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
k_nope, k_pe = kv_cache[0], kv_cache[1]
ql_nope = torch.empty(
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_pe = torch.empty(
(num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_c = torch.empty(
(num_input_tokens, self.q_lora_rank),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
cos,
sin,
self.W_UK_T,
k_nope,
k_pe,
slot_mapping,
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=ql_nope,
kv_cache_out0=k_nope,
q_out1=q_pe,
kv_cache_out1=k_pe,
inner_out=q_c,
)
return hidden_states, ql_nope, q_pe, q_c
def indexer_select_pre_process(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_li = self.k_norm(k_li).unsqueeze(1)
k_li = k_li.view(-1, 1, self.head_dim)
if HAS_TRITON:
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
k_li = rope_forward_triton_siso(
k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
)
else:
k_li_pe, k_li_nope = torch.split(
k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
k_li_pe = k_li_pe.unsqueeze(2)
k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin)
k_li_pe = k_li_pe.squeeze(2)
k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128]
if self.use_sparse_c8_indexer:
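# C8: rotate the indexer key with the scaled Hadamard matrix, then quantize
# per token to int8; the fp16 scales are kept for dequant inside the indexer kernel.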
k_li = k_li @ AscendSFAImpl.qk_hadamard
k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype) # [b*s,]
k_li_scale = k_li_scale.unsqueeze(-1) # [b*s,1]
else:
k_li_scale = None
return k_li, k_li_scale
def indexer_select_post_process(
self,
x: torch.Tensor,
q_c: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
):
weights, _ = self.weights_proj(x)
q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
if HAS_TRITON:
q_li = rope_forward_triton_siso(
q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
)
else:
q_li_pe, q_li_nope = torch.split(
q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
) # [b,s,64,64+64]
q_li_pe = q_li_pe.unsqueeze(2)
q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin)
q_li_pe = q_li_pe.squeeze(2)
q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128]
if self.use_sparse_c8_indexer:
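# Apply the same Hadamard rotation and per-token int8 quantization to the
# indexer query so the quantized Q·K product stays consistent with the cached keys.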
q_li_shape_ori = q_li.shape
q_li = q_li @ AscendSFAImpl.qk_hadamard
q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
# DSV3.2 currently has graph compilation issues with torch_npu.npu_lightning_indexer,
# so two branches are maintained temporarily.
# TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
weights = weights.to(torch.float16)
topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
query=q_li.view(q_li_shape_ori),
key=kv_cache[2],
weights=weights,
query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
key_dequant_scale=kv_cache[3].squeeze(2), # B S N D -> B S D
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
query_quant_mode=0,
key_quant_mode=0,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
elif self.use_torch_npu_lightning_indexer:
topk_indices, _ = torch_npu.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
else:
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
query=q_li,
key=kv_cache[2],
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=attn_metadata.block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3,
)
return topk_indices
def _execute_sparse_flash_attention_process(
self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
):
block_table = attn_metadata.block_table
kv = kv_cache[0]
key_rope = kv_cache[1]
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
key=kv,
value=kv,
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=block_table,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_kv=actual_seq_lengths_key,
query_rope=q_pe,
key_rope=key_rope,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_output
def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: torch.Tensor | None = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)
cos = attn_metadata.cos
sin = attn_metadata.sin
slot_mapping = attn_metadata.slot_mapping
slot_mapping_cp = None
if self.enable_dsa_cp:
assert attn_metadata.dsa_cp_context is not None
slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
else:
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
# Inputs and outputs may be padded for CUDA graphs
num_input_tokens = attn_metadata.num_input_tokens
output_padded = output
# Handle for the async all-gather of the o_proj weight in the prefill stage on a PD-mix node
o_proj_full_handle = None
# In the PD-mix stage, use the original TP o_proj weight and additionally
# all-gather the full o_proj weight for the prefill stage.
full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in {
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding,
}
# Run the mlapo op only when it is enabled (DSA-CP disables it) and num_input_tokens is within the op's token limit
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo(
hidden_states=hidden_states,
kv_cache=kv_cache,
cos=cos,
sin=sin,
slot_mapping=slot_mapping,
num_input_tokens=num_input_tokens,
)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
# native
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
q_c = self.q_a_layernorm(q_c)
k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
wait_for_kv_layer_from_connector(layer_name)
if self.enable_dsa_cp:
assert slot_mapping_cp is not None
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata)
else:
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata)
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
assert k_li is not None
async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
# All-gather the KV asynchronously to overlap communication with computation
if not self.use_sparse_c8_indexer:
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k_li.view(-1, k_li.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
else:
# Due to the different dtypes, the communication has to be split into separate all-gathers
assert k_li_scale is not None
fused_kv_no_split, _ = all_gather_async(
torch.cat(
[
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
],
dim=1,
),
get_tp_group(),
async_op=async_op,
)
k_li, _ = all_gather_async(
k_li,
get_tp_group(),
async_op=async_op,
)
k_li_scale, kv_ag_handle = all_gather_async(
k_li_scale,
get_tp_group(),
async_op=async_op,
)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
if self.enable_dsa_cp:
if kv_ag_handle is not None:
kv_ag_handle.wait()
if self.enable_dsa_cp_with_layer_shard:
for layer in self.layer_sharding_kwargs or []:
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
elif full_gather_o_proj_enabled:
_, o_proj_full_handle = all_gather_async(
self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool
)
if kv_cache is not None:
assert fused_kv_no_split is not None
if not self.use_sparse_c8_indexer:
k_pe, k_nope, k_li = fused_kv_no_split.split(
[self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
)
else:
k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
k_nope = k_nope.view(k_nope.shape[0], 1, -1)
k_pe = k_pe.view(k_pe.shape[0], 1, -1)
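# Scatter the all-gathered latent K and RoPE K for all ranks' tokens back into
# this rank's paged KV cache using the global slot mapping.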
DeviceOperator.reshape_and_cache(
key=k_nope[: attn_metadata.num_actual_tokens],
value=k_pe[: attn_metadata.num_actual_tokens],
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens],
)
k_li = self._get_full_kv(k_li, attn_metadata)
if kv_cache is not None:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
) # b, s, n, d
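# Under C8 the per-token fp16 key scales live in a fourth cache tensor
# (kv_cache[3]) and are scattered with the same slot mapping.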
if self.use_sparse_c8_indexer:
assert len(kv_cache) == 4
torch_npu.npu_scatter_nd_update_(
kv_cache[3].view(-1, k_li_scale.shape[-1]),
slot_mapping.view(-1, 1),
k_li_scale.view(-1, k_li_scale.shape[-1]),
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
topk_indices = self.indexer_select_post_process(
x=hidden_states,
q_c=q_c,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
)
attn_output = self._execute_sparse_flash_attention_process(
ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
)
attn_output = self._v_up_proj(attn_output)
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
linear_layer=self.o_proj,
)
if self.enable_dsa_cp_with_o_proj_tp:
# When using SFA-CP in the PD-mix setup, o_proj has two cases:
# 1. prefill: o_proj holds a TP-sharded weight, so all-gather the full weight to switch to TP=1.
# 2. decode: all-to-all the hidden states before the o_proj forward.
result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
attn_output=attn_output,
output=output,
o_proj_full_handle=o_proj_full_handle,
should_shard_weight=full_gather_o_proj_enabled,
)
if not require_o_proj_forward:
return result
attn_output = result
output[...] = self.o_proj(attn_output)[0]
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded