### What this PR does / why we need it?
This PR adds W8A8C8 support (W8A8 linear quantization plus an int8 "C8" cache) for dsv3.2/glm5 via the lightning_indexer_quant ops, mainly targeting the PD-mix stage.
Because the code for the PD-disaggregated scenario is still being refactored and cleaned up, this PR prioritizes C8 functionality in the PD-mix scenario.
The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Only once all the operators are fully ready can we ensure that
this feature causes no performance degradation. At that point, we will
enable the feature by default and remove the switch in `additional_config`
(a usage sketch follows below).
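
A usage sketch for that switch, for reference. The key name below is an assumption inferred from the `enable_sparse_c8` flag read via `get_ascend_config()` in the implementation; verify the exact spelling against your vllm-ascend version.

```python
# Hypothetical example: enabling the C8 sparse-indexer path through
# additional_config. The "enable_sparse_c8" key is an assumed name.
from vllm import LLM

llm = LLM(
    model="path/to/deepseek-v3.2-w8a8c8",  # placeholder model path
    additional_config={"enable_sparse_c8": True},
)
```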
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar

import scipy  # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backend import (
    AttentionBackend,  # type: ignore
    AttentionCGSupport,
    MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (
    AscendCommonAttentionMetadata,
    ascend_chunked_prefill_workspace_size,
    enable_cp,
    maybe_save_kv_layer_to_connector,
    trans_rope_weight,
    transdata,
    wait_for_kv_layer_from_connector,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
    is_hidden_layer,
    post_process_after_loading_for_shard_weight_series,
    reach_layer_for_shard_weight_series,
    register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
    ACL_FORMAT_FRACTAL_ND,
    _round_up,
    dispose_layer,
    enable_dsa_cp,
    enable_dsa_cp_with_layer_shard,
    enable_dsa_cp_with_o_proj_tp,
    get_weight_prefetch_method,
    maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

# Token count limit within the bmm_transpose operator.
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024


class AscendSFABackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): the vllm `initialize_kv_cache` method in model runner v2
        # asserts on the attention backend name, so we report FLASH_ATTN to avoid the
        # assertion error. Rectify this once vllm drops the assertion.
        return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_builder_cls():
        if enable_cp():
            from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPMetadataBuilder

            return AscendSFACPMetadataBuilder
        return AscendSFAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]:
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_impl_cls() -> type["AscendSFAImpl"]:
        if enable_cp():
            from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPImpl

            return AscendSFACPImpl
        return AscendSFAImpl

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        return [128]


@dataclass
class DSACPContext:
    num_tokens: int
    num_tokens_pad: int
    local_start: int
    local_end: int
    local_end_with_pad: int
    slot_mapping_cp: torch.Tensor
    actual_seq_lengths_query: torch.Tensor
    actual_seq_lengths_key: torch.Tensor


@dataclass
class AscendSFAMetadata:
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class.
    """

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    seq_lens: torch.Tensor
    cum_query_lens: torch.Tensor
    block_table: torch.Tensor
    sin: torch.Tensor
    cos: torch.Tensor

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
    # The dimension of the attention heads.
    head_dim: int | None = None
    attn_mask: torch.Tensor | None = None
    # Chunked prefill by default if no attn_state is passed.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    dsa_cp_context: DSACPContext | None = None
    reshape_cache_event: torch.npu.Event | None = None
    sfa_cp_metadata: AscendPCPMetadata | None = None
    num_decodes: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0


M = TypeVar("M", bound=AscendSFAMetadata)


class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class.
    """

    def __init__(
        self,
        kv_cache_spec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[AscendSFAMetadata] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        super().__init__(
            kv_cache_spec,
            layer_names,
            vllm_config,
            device,
            metadata_cls if metadata_cls is not None else AscendSFAMetadata,
            supports_dcp_with_varlen,
        )

        self.block_size = vllm_config.cache_config.block_size
        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size

        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeded npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}"
            )
        self.reorder_batch_threshold = self.decode_threshold
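        # Illustrative example (not executed): with num_speculative_tokens=3,
        # decode_threshold becomes 4, so a request scheduled with at most 4
        # tokens in one step is classified as a decode by the reorder logic.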
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.enable_dsa_cp = enable_dsa_cp()

        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
        self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        return ascend_chunked_prefill_workspace_size(vllm_config)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendSFAMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override is omitted only because of a mypy limitation with type variables.
        return AttentionCGSupport.UNIFORM_BATCH

    def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
        # No need to reorder for Ascend SFA.
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendSFAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_input_tokens = common_attn_metadata.num_input_tokens

        block_table = common_attn_metadata.block_table_tensor[:num_reqs]
        slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
        input_positions = common_attn_metadata.positions[:num_input_tokens].long()

        cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
        seq_lens = common_attn_metadata.seq_lens[:num_reqs]

        cos, sin = get_cos_and_sin_mla(input_positions, True)

        dsa_cp_context = None
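        # Illustrative example of the DSA-CP partition computed below (not
        # executed): with num_input_tokens=10 and a TP world size of 4,
        # num_tokens_pad=12 and every rank owns num_tokens_per_device=3
        # tokens; rank 2 gets local_start=6, local_end_with_pad=9, and
        # local_end=min(9, num_actual_tokens). Padded slot_mapping entries
        # are set to -1 so the cache write ignores them.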
        if self.enable_dsa_cp:
            global_tp_size = get_tp_group().world_size
            num_tokens = num_input_tokens
            num_tokens_pad = _round_up(num_tokens, global_tp_size)
            num_tokens_per_device = num_tokens_pad // global_tp_size
            local_start = get_tp_group().rank_in_group * num_tokens_per_device
            local_end_with_pad = local_start + num_tokens_per_device
            local_end = min(local_end_with_pad, num_actual_tokens)

            pad_size = num_tokens_pad - cos.shape[0]
            assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"

            if pad_size > 0:
                cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))

            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
            if pad_size_slot > 0:
                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
            else:
                slot_mapping = slot_mapping[:num_tokens_pad]
            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]

            cos = cos[local_start:local_end_with_pad]
            sin = sin[local_start:local_end_with_pad]

            assert cos.shape[0] == num_tokens_per_device, (
                f"cos.shape[0] must be equal to num_tokens_per_device, "
                f"got {cos.shape[0]} and {num_tokens_per_device}"
            )
            assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
                f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, "
                f"got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
            )
            assert slot_mapping.shape[0] == num_tokens_pad, (
                f"slot_mapping.shape[0] must be equal to num_tokens_pad, "
                f"got {slot_mapping.shape[0]} and {num_tokens_pad}"
            )

            actual_seq_lengths_query = self.actual_seq_lengths_query
            actual_seq_lengths_key = self.actual_seq_lengths_key

            num_segs = cum_query_lens.shape[0]
            last_token = 0
            cum = 0
            for i in range(num_segs):
                global_start = last_token
                global_end = cum_query_lens[i].item()
                last_token = global_end

                req_local_start = max(global_start, local_start)
                req_local_end = min(global_end, local_end_with_pad)
                num_local_tokens = req_local_end - req_local_start

                if num_local_tokens > 0:
                    cum += num_local_tokens
                    actual_seq_lengths_query[i] = cum

                    offset = global_end - req_local_end
                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
                else:
                    actual_seq_lengths_query[i] = cum
                    actual_seq_lengths_key[i] = 0

            actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
            actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]

            dsa_cp_context = DSACPContext(
                num_tokens=num_tokens,
                num_tokens_pad=num_tokens_pad,
                local_start=local_start,
                local_end=local_end,
                local_end_with_pad=local_end_with_pad,
                slot_mapping_cp=slot_mapping_cp,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
            )

        return self.metadata_cls(  # type: ignore
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            cum_query_lens=cum_query_lens,
            seq_lens=seq_lens,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
            attn_state=common_attn_metadata.attn_state,
            block_table=block_table,
            sin=sin[:num_input_tokens],
            cos=cos[:num_input_tokens],
            dsa_cp_context=dsa_cp_context,
        )

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError("Currently we only support building dummy metadata for DecodeOnly state")

        attn_metadata.attn_state = attn_state
        return attn_metadata


class AscendSFAImpl(MLAAttentionImpl):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class.
    """

    # Supports forward using the all-gathered o_proj weight for prefill requests when sharded CP is enabled.
    o_proj_full_pool: torch.Tensor | None = None

    # qk_hadamard tensor shared across layers when DSA C8 is enabled.
    qk_hadamard: torch.Tensor | None = None

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA args
        self.q_lora_rank = kwargs["q_lora_rank"]
        self.kv_lora_rank = kwargs["kv_lora_rank"]
        self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
        self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
        self.qk_head_dim = kwargs["qk_head_dim"]
        self.v_head_dim = kwargs["v_head_dim"]
        self.rotary_emb = kwargs["rotary_emb"]
        self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
        self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
        self.kv_b_proj = kwargs["kv_b_proj"]
        self.o_proj = kwargs["o_proj"]
        self.indexer = kwargs["indexer"]
        self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
        self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
        self.q_a_layernorm = kwargs.get("q_a_layernorm")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group
        self.q_b_proj = kwargs["q_b_proj"]

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        # The MLAPO operator fuses the MLA pre-processing steps on Q/K/V into a single operator.
        # NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm.
        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

        assert self.indexer is not None, "Indexer is required for DSA."

        self.local_num_heads = self.num_heads
        self.vllm_config = get_current_vllm_config()
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        )

        # Indexer params.
        self.n_head: int = self.indexer.n_head  # 64
        self.head_dim: int = self.indexer.head_dim  # 128
        self.wq_b = self.indexer.wq_b
        self.wk = self.indexer.wk
        self.weights_proj = self.indexer.weights_proj
        self.k_norm = self.indexer.k_norm
        self.cp_size = 1
        self.is_rope_neox_style = True
        self.use_torch_npu_lightning_indexer = False
        if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
            self.is_rope_neox_style = False
            self.use_torch_npu_lightning_indexer = True

        # DSA C8
        self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
        if self.use_sparse_c8_indexer:
            self.c8_k_cache_dtype = torch.int8
            self.c8_k_scale_cache_dtype = torch.float16

        # Effective in SFA when FlashComm is enabled.
        self.enable_dsa_cp = enable_dsa_cp()

        # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup.
        self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard()

        # Use the original TP o_proj weight in the PD-mix stage, and fully gather
        # the o_proj weight for the prefill stage.
        self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp()

        if self.enable_dsa_cp:
            self.local_num_heads = self.num_heads * self.tp_size
            if self.enable_dsa_cp_with_layer_shard:
                self.layer_sharding_kwargs = []
                for layer_name in get_ascend_config().layer_sharding or []:
                    if layer_name in kwargs:
                        self.layer_sharding_kwargs.append(kwargs[layer_name])
                    else:
                        logger.warning_once(
                            f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, "
                            "skipping sharding configuration"
                        )
                register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # NOTE: We currently do not support a quantized kv_b_proj.
        assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
        # NOTE: The weight will be reshaped next; we need to revert and transpose it.
        kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.local_num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.local_num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1).contiguous()
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()

        # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support NZ weights.
        # self.W_UV = maybe_trans_nz(self.W_UV)

        # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T, to save memory.
        dispose_layer(self.kv_b_proj)
        if self.enable_dsa_cp:
            if self.enable_dsa_cp_with_layer_shard:
                for layer in self.layer_sharding_kwargs or []:
                    if is_hidden_layer(layer):
                        post_process_after_loading_for_shard_weight_series(layer)
            else:
                self._init_o_proj_tp_full_params()

        if self.enable_mlapo:
            quant_method = getattr(
                getattr(self.fused_qkv_a_proj, "quant_method", None),
                "quant_method",
                None,
            )
            reasons = []
            if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod):
                reasons.append(
                    "Currently mlapo only supports W8A8 quantization in the SFA scenario. "
                    "Some layers in your model are not quantized with W8A8, "
                    "so mlapo is disabled for these layers."
                )
            if self.enable_dsa_cp:
                reasons.append("Currently mlapo does not support SFA with CP, so mlapo is disabled for these layers.")
            if reasons:
                self.enable_mlapo = False
                for msg in reasons:
                    logger.warning_once(msg)
            else:
                self._process_weights_for_fused_mlapo(act_dtype)
        if not self.enable_mlapo:
            # If mlapo is enabled, W_UK_T cannot be transformed to NZ.
            self.W_UK_T = maybe_trans_nz(self.W_UK_T)
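
        # Note on the Hadamard rotation below (our reading, not from the
        # original comments): H / sqrt(128) is orthonormal, so q @ H and
        # k @ H preserve inner products while spreading per-channel outliers
        # across all 128 indexer dims, which reduces the error of the
        # per-token int8 dynamic quantization that follows.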
        if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
            AscendSFAImpl.qk_hadamard = torch.tensor(
                scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu"
            ) / (128**0.5)

    # Process the input parameters for MLAPO by reordering and transposing the
    # QKV (and part of the Q) weights, applying RoPE-related dimension
    # transformations, and handling quantization parameters.
    def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
        assert self.kv_a_proj_with_mqa is None
        assert self.fused_qkv_a_proj is not None

        kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
        q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()

        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
        kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
        wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
        wd_qkv = wd_qkv.t().contiguous()
        wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

        kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous()
        q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous()
        kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
        kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
        kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
        self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()

        kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous()
        q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous()

        kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
        kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
        kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
        self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()

        wu_q = self.q_proj.weight.data
        wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
        wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
        wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

        qb_deq_scl = self.q_proj.deq_scale.data
        qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
        self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

        qb_qt_bias = self.q_proj.quant_bias.data
        qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
        self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

        device = self.q_proj.weight.device
        self.gamma1 = self.q_a_layernorm.weight.data  # type: ignore[union-attr]
        self.beta1 = self.q_a_layernorm.bias.data  # type: ignore[union-attr]
        self.gamma2 = self.kv_a_layernorm.weight.data  # type: ignore[union-attr]
        self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
        self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
        self.quant_scale1 = self.q_proj.input_scale.data
        self.quant_offset1 = self.q_proj.input_offset.data
        self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
        self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

        # On KV consumers (decode-only), MLAPO uses the transformed weights built above;
        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
        # referenced, so drop them to save memory.
        if (
            self.vllm_config.kv_transfer_config is not None
            and self.vllm_config.kv_transfer_config.is_kv_consumer
            and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
        ):
            self.fused_qkv_a_proj.weight = None
            self.fused_qkv_a_proj.deq_scale = None
            self.fused_qkv_a_proj.quant_bias = None
            self.q_proj.weight = None
            self.q_proj.deq_scale = None
            self.q_proj.quant_bias = None
            torch.npu.empty_cache()

    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.")

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        layer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.")

    def rope_single(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        B, N, D = x.shape
        S = 1
        x = x.view(B, N, S, D)
        x = torch_npu.npu_interleave_rope(x, cos, sin)
        return x.view(B, N, D)

    def _init_o_proj_tp_full_params(self):
        """
        Initialize TP-mode and Full-mode parameters for the o_proj weight,
        preparing for weight switching in the PD-mix stage.

        For the PD-mix stage:
        - Use the original TP o_proj weight for the decode phase.
        - Fully gather the o_proj weight from all TP ranks for the prefill phase.
        """
        if AscendSFAImpl.o_proj_full_pool is None:
            sample = self.o_proj.weight
            AscendSFAImpl.o_proj_full_pool = torch.empty(
                (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device
            )

        # Save TP-mode parameters (original sharded weights).
        self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
        self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach()
        self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach()
        self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach()

        # Initially switch to TP mode for graph capture.
        self.o_proj.weight.set_(self.o_proj_tp_weight)
        self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
        self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
        self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)

        # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks.
        self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size)
        self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size)
        self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size)

    def _handle_o_proj_weight_switch_and_forward(
        self,
        attn_output: torch.Tensor,
        output: torch.Tensor,
        o_proj_full_handle: torch.distributed.Work | None,
        should_shard_weight: bool,
    ) -> tuple[torch.Tensor, bool]:
        """
        Handle o_proj weight switching between TP-mode and Full-mode, and execute the forward computation.
        """
        # Gather the o_proj weight from all TP ranks for Full-mode computation.
        if should_shard_weight:
            # Wait for the o_proj weight all-gather to complete.
            if o_proj_full_handle is not None:
                o_proj_full_handle.wait()

            # Switch o_proj to Full-mode (weight gathered from all TP ranks).
            self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
            self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale)
            self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal)
            self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset)

            # Apply the quantization method and execute the forward computation.
            output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output)

            # Switch o_proj back to TP-mode for subsequent decode operations.
            self.o_proj.weight.set_(self.o_proj_tp_weight)
            self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
            self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
            self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)

            return output, False
        else:
            # Decode scenario: perform all-to-all communication on the o_proj input activations.
            # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
            send = (
                attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim)
                .permute(1, 0, 2)
                .reshape(-1, self.num_heads * self.v_head_dim)
            )

            attn_output = torch.empty_like(send)
            torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group)

            return attn_output, True

    def _get_full_kv(self, k, attn_metadata):
        return k

    def exec_kv(
        self,
        kv_no_split: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        kv_cache: tuple,
        slots: torch.Tensor,
        attn_metadata: M,
    ):
        B = kv_no_split.shape[0]
        N = self.num_kv_heads
        S = 1
        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
        kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
        cache_mode = "PA"

        if self.enable_dsa_cp:
            _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
                kv_no_split,
                self.kv_a_layernorm.weight,  # type: ignore[union-attr]
                cos,
                sin,
                slots.to(torch.int64),
                kv_cache[1],
                kv_cache[0],
                epsilon=self.kv_a_layernorm.variance_epsilon,  # type: ignore[union-attr]
                cache_mode=cache_mode,
                is_output_kv=True,
            )
            return k_pe, k_nope
        else:
            torch_npu.npu_kv_rmsnorm_rope_cache(
                kv_no_split,
                self.kv_a_layernorm.weight,  # type: ignore[union-attr]
                cos,
                sin,
                slots.to(torch.int64),
                kv_cache[1],
                kv_cache[0],
                epsilon=self.kv_a_layernorm.variance_epsilon,  # type: ignore[union-attr]
                cache_mode=cache_mode,
            )
            return None, None

    # Returns `ql_nope`, `q_pe`.
    def _q_proj_and_k_up_proj(self, x):
        q_nope, q_pe = (
            self.q_proj(x)[0]
            .view(-1, self.local_num_heads, self.qk_head_dim)
            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        )

        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        return ql_nope.transpose(0, 1), q_pe

    def _v_up_proj(self, x):
        num_input_tokens, _, _ = x.shape
        if (
            x.dtype in [torch.float16, torch.bfloat16]
            and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
            and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS
        ):
            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
            res = torch.empty(
                (num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device
            )
            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
            x = res.reshape(-1, self.local_num_heads * self.v_head_dim)
        else:
            # Convert from (B, N, L) to (N, B, L)
            x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1)
            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            x = torch.bmm(x, self.W_UV)
            # Convert from (N, B, V) to (B, N * V)
            x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
        return x
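
    # NOTE (illustrative summary, not executed): _q_proj_and_k_up_proj and
    # _v_up_proj together implement MLA weight absorption. Query nope dims are
    # mapped P -> L with W_UK_T so attention runs in the compact kv_lora_rank
    # latent space against the cached k_nope, and _v_up_proj maps the attention
    # output back L -> V with W_UV afterwards.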

    def _sfa_preprocess_with_mlapo(
        self,
        hidden_states: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        cos: torch.Tensor,
        sin: torch.Tensor,
        slot_mapping: torch.Tensor,
        num_input_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        k_nope, k_pe = kv_cache[0], kv_cache[1]
        ql_nope = torch.empty(
            (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        q_pe = torch.empty(
            (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        q_c = torch.empty(
            (num_input_tokens, self.q_lora_rank),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops._C_ascend.mla_preprocess(
            hidden_states,
            self.wd_qkv,
            self.deq_scale_qkv,
            self.gamma1,
            self.beta1,
            self.wu_q,
            self.qb_deq_scl,
            self.gamma2,
            cos,
            sin,
            self.W_UK_T,
            k_nope,
            k_pe,
            slot_mapping,
            quant_scale0=self.quant_scale0,
            quant_offset0=self.quant_offset0,
            bias0=self.quant_bias_qkv,
            quant_scale1=self.quant_scale1,
            quant_offset1=self.quant_offset1,
            bias1=self.qb_qt_bias,
            ctkv_scale=self.ctkv_scale,
            q_nope_scale=self.q_nope_scale,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            enable_inner_out=True,
            q_out0=ql_nope,
            kv_cache_out0=k_nope,
            q_out1=q_pe,
            kv_cache_out1=k_pe,
            inner_out=q_c,
        )
        return hidden_states, ql_nope, q_pe, q_c

    def indexer_select_pre_process(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        k_li, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
        k_li = self.k_norm(k_li).unsqueeze(1)
        k_li = k_li.view(-1, 1, self.head_dim)

        if HAS_TRITON:
            cos = cos.view(-1, self.qk_rope_head_dim)
            sin = sin.view(-1, self.qk_rope_head_dim)
            k_li = rope_forward_triton_siso(
                k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
            )
        else:
            k_li_pe, k_li_nope = torch.split(
                k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
            )

            cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
            sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

            k_li_pe = k_li_pe.unsqueeze(2)
            k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin)
            k_li_pe = k_li_pe.squeeze(2)

            k_li = torch.cat([k_li_pe, k_li_nope], dim=-1)  # [b*s,128]

        if self.use_sparse_c8_indexer:
            k_li = k_li @ AscendSFAImpl.qk_hadamard
            k_li, k_li_scale = torch_npu.npu_dynamic_quant(
                k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype
            )
            k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype)  # [b*s,]
            k_li_scale = k_li_scale.unsqueeze(-1)  # [b*s,1]
        else:
            k_li_scale = None

        return k_li, k_li_scale

    def indexer_select_post_process(
        self,
        x: torch.Tensor,
        q_c: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        cos: torch.Tensor,
        sin: torch.Tensor,
        actual_seq_lengths_query: torch.Tensor,
        actual_seq_lengths_key: torch.Tensor,
    ):
        weights, _ = self.weights_proj(x)

        q_li, _ = self.wq_b(q_c)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
        q_li = q_li.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
        if HAS_TRITON:
            q_li = rope_forward_triton_siso(
                q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
            )
        else:
            q_li_pe, q_li_nope = torch.split(
                q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
            )  # [b,s,64,64+64]

            q_li_pe = q_li_pe.unsqueeze(2)
            q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin)
            q_li_pe = q_li_pe.squeeze(2)
            q_li = torch.cat([q_li_pe, q_li_nope], dim=-1)  # [b*s,64,128]

        if self.use_sparse_c8_indexer:
            q_li_shape_ori = q_li.shape
            q_li = q_li @ AscendSFAImpl.qk_hadamard
            q_li, q_li_scale = torch_npu.npu_dynamic_quant(
                q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype
            )
            q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)

        # DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer,
        # so two branches are maintained temporarily.
        # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
        if self.use_sparse_c8_indexer:
            assert len(kv_cache) == 4
            weights = weights.to(torch.float16)
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
                query=q_li.view(q_li_shape_ori),
                key=kv_cache[2],
                weights=weights,
                query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
                key_dequant_scale=kv_cache[3].squeeze(2),  # B S N D -> B S D
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                query_quant_mode=0,
                key_quant_mode=0,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        elif self.use_torch_npu_lightning_indexer:
            topk_indices, _ = torch_npu.npu_lightning_indexer(
                query=q_li,
                key=kv_cache[2],
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        else:
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
                query=q_li,
                key=kv_cache[2],
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        return topk_indices

    def _execute_sparse_flash_attention_process(
        self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
    ):
        block_table = attn_metadata.block_table
        kv = kv_cache[0]
        key_rope = kv_cache[1]

        attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
            query=ql_nope,
            key=kv,
            value=kv,
            sparse_indices=topk_indices,
            scale_value=self.scale,
            sparse_block_size=1,
            block_table=block_table,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_kv=actual_seq_lengths_key,
            query_rope=q_pe,
            key_rope=key_rope,
            layout_query="TND",
            layout_kv="PA_BSND",
            sparse_mode=3,
        )
        return attn_output
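
    # NOTE (illustrative): with sparse_block_size=1, the kernel attends only to
    # the per-query token set selected by the indexer (at most sparse_count=2048
    # entries), so attention cost scales with the selected set rather than with
    # the full context length.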

    def forward(
        self,
        layer_name,
        hidden_states: torch.Tensor,  # query in unified attn
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
                for layer in self.layer_sharding_kwargs or []:
                    if is_hidden_layer(layer):
                        reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

        cos = attn_metadata.cos
        sin = attn_metadata.sin
        slot_mapping = attn_metadata.slot_mapping
        slot_mapping_cp = None
        if self.enable_dsa_cp:
            assert attn_metadata.dsa_cp_context is not None
            slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp
            actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
            actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
        else:
            actual_seq_lengths_query = attn_metadata.cum_query_lens
            actual_seq_lengths_key = attn_metadata.seq_lens

        # Inputs and outputs may be padded for CUDA graphs.
        num_input_tokens = attn_metadata.num_input_tokens
        output_padded = output

        # All-gather handle for the o_proj weight in the prefill stage of a PD-mix node.
        o_proj_full_handle = None
        # In the PD-mix stage, use the original TP o_proj weight and fully gather
        # the o_proj weight for the prefill stage.
        full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in {
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.SpecDecoding,
        }

        # Run the mlapo op when DSA-CP is disabled and num_tokens satisfies the count limitation.
        if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo(
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                cos=cos,
                sin=sin,
                slot_mapping=slot_mapping,
                num_input_tokens=num_input_tokens,
            )
            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
        # Native path.
        else:
            assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
            weight_prefetch_method = get_weight_prefetch_method()
            weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
            )
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_no_split = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
            q_c = self.q_a_layernorm(q_c)

            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)

            wait_for_kv_layer_from_connector(layer_name)

            if self.enable_dsa_cp:
                assert slot_mapping_cp is not None
                k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata)
            else:
                k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata)

            if self.enable_dsa_cp:
                assert k_pe is not None
                assert k_nope is not None
                assert k_li is not None
                async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
                # All-gather the KV asynchronously to overlap communication with computation.
                if not self.use_sparse_c8_indexer:
                    fused_kv_no_split, kv_ag_handle = all_gather_async(
                        torch.cat(
                            [
                                k_pe.view(-1, k_pe.shape[-1]),
                                k_nope.view(-1, k_nope.shape[-1]),
                                k_li.view(-1, k_li.shape[-1]),
                            ],
                            dim=1,
                        ),
                        get_tp_group(),
                        async_op=async_op,
                    )
                else:
                    # Due to the different dtypes, the communication has to be split.
                    assert k_li_scale is not None
                    fused_kv_no_split, _ = all_gather_async(
                        torch.cat(
                            [
                                k_pe.view(-1, k_pe.shape[-1]),
                                k_nope.view(-1, k_nope.shape[-1]),
                            ],
                            dim=1,
                        ),
                        get_tp_group(),
                        async_op=async_op,
                    )
                    k_li, _ = all_gather_async(
                        k_li,
                        get_tp_group(),
                        async_op=async_op,
                    )
                    k_li_scale, kv_ag_handle = all_gather_async(
                        k_li_scale,
                        get_tp_group(),
                        async_op=async_op,
                    )

            ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
            q_pe = self.rope_single(q_pe, cos, sin)

            if self.enable_dsa_cp:
                if kv_ag_handle is not None:
                    kv_ag_handle.wait()

                if self.enable_dsa_cp_with_layer_shard:
                    for layer in self.layer_sharding_kwargs or []:
                        if is_hidden_layer(layer):
                            reach_layer_for_shard_weight_series(layer)
                elif full_gather_o_proj_enabled:
                    _, o_proj_full_handle = all_gather_async(
                        self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool
                    )

                if kv_cache is not None:
                    assert fused_kv_no_split is not None
                    if not self.use_sparse_c8_indexer:
                        k_pe, k_nope, k_li = fused_kv_no_split.split(
                            [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
                        )
                    else:
                        k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
                    k_nope = k_nope.view(k_nope.shape[0], 1, -1)
                    k_pe = k_pe.view(k_pe.shape[0], 1, -1)
                    DeviceOperator.reshape_and_cache(
                        key=k_nope[: attn_metadata.num_actual_tokens],
                        value=k_pe[: attn_metadata.num_actual_tokens],
                        key_cache=kv_cache[0],
                        value_cache=kv_cache[1],
                        slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens],
                    )

        k_li = self._get_full_kv(k_li, attn_metadata)

        if kv_cache is not None:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            torch_npu.npu_scatter_nd_update_(
                kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
            )  # b, s, n, d
            if self.use_sparse_c8_indexer:
                assert len(kv_cache) == 4
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[3].view(-1, k_li_scale.shape[-1]),
                    slot_mapping.view(-1, 1),
                    k_li_scale.view(-1, k_li_scale.shape[-1]),
                )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()

        topk_indices = self.indexer_select_post_process(
            x=hidden_states,
            q_c=q_c,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            cos=cos,
            sin=sin,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_key=actual_seq_lengths_key,
        )

        attn_output = self._execute_sparse_flash_attention_process(
            ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
        )

        attn_output = self._v_up_proj(attn_output)
        weight_prefetch_method = get_weight_prefetch_method()
        weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
            inputs=self.o_proj.weight,
            dependency=attn_output,
            max_size=MAX_O_PROJ_PREFETCH_SIZE,
            linear_layer=self.o_proj,
        )

        if self.enable_dsa_cp_with_o_proj_tp:
            # With SFA-CP in PD-mix mode, o_proj has two cases:
            # 1. prefill: o_proj is a TP weight; all-gather the o_proj weight to switch to TP=1.
            # 2. decode: all-to-all the hidden states before the o_proj forward.
            result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
                attn_output=attn_output,
                output=output,
                o_proj_full_handle=o_proj_full_handle,
                should_shard_weight=full_gather_o_proj_enabled,
            )
            if not require_o_proj_forward:
                return result
            attn_output = result

        output[...] = self.o_proj(attn_output)[0]

        maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))

        return output_padded