xc-llm-ascend/vllm_ascend/attention/sfa_v1.py
Qiu 64669c4243 [misc][FlashComm1][ACLGraph] Incompatibility between Flashcomm1 and FULL_DECODE_ONLY. (#5200)
### What this PR does / why we need it?
Currently, Flashcomm1 and FULL_DECODE_ONLY are incompatible. When both
features are enabled, graph capture errors occur without clear error
messages.

After discussion, it has been determined that enabling FULL_DECODE_ONLY
with Flashcomm1 in mixed deployment scenarios provides almost no TPOT
benefit. Additionally, a reconstruction of the decode phase for
flashcomm1 is currently underway. Therefore, related adaptation work is
temporarily postponed and will be addressed after the decode phase
reconstruction plan is finalized.

For now, an assertion is added to raise a clear error message with the
correct deployment recommendation.
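
For reference, a minimal sketch of the added check (the standalone helper shown here is illustrative; the actual guard sits inline in `AscendSFAMetadataBuilder.__init__` in the file below):

```python
from vllm.config import CUDAGraphMode

def _check_flashcomm1_compat(enable_sfa_cp: bool,
                             cudagraph_mode: CUDAGraphMode) -> None:
    # Hypothetical helper mirroring the guard added by this PR.
    assert not (enable_sfa_cp
                and cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY), (
        "FlashComm1 is not compatible with FULL_DECODE_ONLY. "
        "Please set graph_mode to 'piecewise' or disable FlashComm1.")
```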

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
NO

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
2025-12-22 14:33:32 +08:00


from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
import torch
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_sp, maybe_trans_nz, replace_layer)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class AscendSFABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ASCEND_SFA"
@staticmethod
def get_builder_cls():
return AscendSFAMetadataBuilder
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_impl_cls() -> Type["AscendSFAImpl"]:
return AscendSFAImpl
@dataclass
class SfaCpContext:
num_tokens: int
num_tokens_pad: int
local_start: int
local_end: int
local_end_with_pad: int
slot_mapping_cp: torch.Tensor
actual_seq_lengths_query: torch.Tensor
actual_seq_lengths_key: torch.Tensor
@dataclass
class AscendSFAMetadata:
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
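# Example: a request with 8 tokens already in the KV cache that receives 4
# new tokens this step has context_len=8, query_len=4, seq_len=12.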
has_prefill: bool
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
seq_lens: torch.Tensor
cum_query_lens: torch.Tensor
block_tables: torch.Tensor
sin: torch.Tensor
cos: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads
head_dim: Optional[int] = None
attn_mask: Optional[torch.Tensor] = None
# Defaults to chunked prefill if no attn_state is passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
sfa_cp_context: Optional[SfaCpContext] = None
M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendSFAMetadata] = None):
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
if metadata_cls is not None else AscendSFAMetadata # type: ignore
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, (
f"decode_threshold exceeds npu_fused_infer_attention_score's "
f"TND layout limit of 16, got {self.decode_threshold}")
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_config, "index_topk")
assert not (
self.enable_sfa_cp
and self.vllm_config.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA
return False
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendSFAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_input_tokens = common_attn_metadata.num_input_tokens
block_table = common_attn_metadata.block_table_tensor[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
input_positions = common_attn_metadata.positions[:num_input_tokens].long()
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
has_prefill = any(query_lens_cpu > self.decode_threshold)
if self.cos_cache is None:
self.cos_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos, sin = get_cos_and_sin_mla()
assert self.cos_cache is not None and self.sin_cache is not None
new_cos = self.cos_cache[input_positions][:, None, None]
new_sin = self.sin_cache[input_positions][:, None, None]
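# Write the fresh cos/sin values into the preallocated global buffers when
# they are large enough (keeps tensor addresses stable across steps);
# otherwise fall back to the freshly gathered tensors.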
if (cos is not None and sin is not None
and num_input_tokens <= cos.shape[0]
and num_input_tokens <= sin.shape[0]):
cos[:num_input_tokens] = new_cos
sin[:num_input_tokens] = new_sin
else:
cos, sin = new_cos, new_sin
sfa_cp_context = None
if self.enable_sfa_cp:
global_tp_size = get_tp_group().world_size
num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size)
num_tokens_per_device = num_tokens_pad // global_tp_size
local_start = get_tp_group().rank_in_group * num_tokens_per_device
local_end_with_pad = local_start + num_tokens_per_device
local_end = min(local_end_with_pad, num_actual_tokens)
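# Example: with tp_size=4 and 10 input tokens, num_tokens_pad=12 and each
# rank owns 3 token slots; rank 2 covers global positions [6, 9) and rank 3
# covers [9, 10) plus 2 padding slots.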
pad_size = num_tokens_pad - cos.shape[0]
assert cos.shape == sin.shape, \
f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
if pad_size > 0:
cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
if pad_size_slot > 0:
slot_mapping = nn.functional.pad(slot_mapping,
(0, pad_size_slot),
value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
assert cos.shape[0] == num_tokens_per_device, (
f"cos.shape[0] must equal num_tokens_per_device, "
f"got {cos.shape[0]} and {num_tokens_per_device}")
assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
f"slot_mapping_cp.shape[0] must equal num_tokens_per_device, "
f"got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}")
assert slot_mapping.shape[0] == num_tokens_pad, (
f"slot_mapping.shape[0] must equal num_tokens_pad, "
f"got {slot_mapping.shape[0]} and {num_tokens_pad}")
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
actual_seq_lengths_key = torch.empty_like(seq_lens)
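# For each request, clip its query span [global_start, global_end) to this
# rank's token window: actual_seq_lengths_query accumulates the local query
# tokens (TND layout), and actual_seq_lengths_key drops the key positions
# that correspond to query tokens beyond local_end.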
num_segs = cum_query_lens.shape[0]
last_token = 0
cum = 0
for i in range(0, num_segs):
global_start = last_token
global_end = cum_query_lens[i].item()
last_token = global_end
local_start = max(global_start, local_start)
local_end = min(global_end, local_end_with_pad)
num_local_tokens = local_end - local_start
if num_local_tokens > 0:
cum += num_local_tokens
actual_seq_lengths_query[i] = cum
offset = global_end - local_end
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
else:
actual_seq_lengths_query[i] = cum
actual_seq_lengths_key[i] = 0
sfa_cp_context = SfaCpContext(
num_tokens=num_tokens,
num_tokens_pad=num_tokens_pad,
local_start=local_start,
local_end=local_end,
local_end_with_pad=local_end_with_pad,
slot_mapping_cp=slot_mapping_cp,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
)
return self.metadata_cls( # type: ignore
has_prefill=has_prefill,
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
cum_query_lens=cum_query_lens,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
block_tables=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
sfa_cp_context=sfa_cp_context)
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state in {
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
}:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
model=model,
)
else:
raise NotImplementedError(
"Dummy metadata for graph capture can only be built for decode "
"states (DecodeOnly or SpecDecoding).")
attn_metadata.attn_state = attn_state
return attn_metadata
class AscendSFAImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
# MLA Args
self.q_lora_rank = kwargs['q_lora_rank']
self.kv_lora_rank = kwargs['kv_lora_rank']
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
'q_b_proj']
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.indexer = kwargs['indexer']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
self.num_heads_per_rank = self.num_heads // self.tp_size
self.q_b_proj = kwargs['q_b_proj']
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
assert self.indexer is not None, "Indexer is required for DSA."
self.enable_sfa_cp = enable_sp()
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
# TODO: Temporary SFA-CP adaptation; remove after PCP is adapted. --clrs97
self._replace_linear_class_for_sfa_cp()
from vllm_ascend.distributed.parallel_state import \
get_shared_weight_group
if is_hidden_layer(self.vllm_config, self.q_proj):
register_layer_to_shared_weight_series(
series_name="q_proj",
group=get_shared_weight_group(),
layer=self.q_proj,
prefetch_step=1)
if is_hidden_layer(self.vllm_config, self.o_proj):
register_layer_to_shared_weight_series(
series_name="o_proj",
group=get_shared_weight_group(),
layer=self.o_proj,
prefetch_step=1)
# indexer param
self.n_head: int = self.indexer.n_head # 64
self.head_dim: int = self.indexer.head_dim # 128
self.wq_b = self.indexer.wq_b
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.cp_size = 1
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support a quantized kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: The weight is reshaped below; cast it back to ND format and transpose it first.
kv_b_proj_weight = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.local_num_heads *
(self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.local_num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.local_num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
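# Weight absorption: W_UK_T folds the key up-projection into the query
# (q_nope @ W_UK_T yields the query in the latent space, see
# _q_proj_and_k_up_proj), and W_UV lifts the attention output from the
# latent space back to v_head_dim (see _v_up_proj).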
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp:
if is_hidden_layer(self.vllm_config, self.q_proj):
post_process_after_loading_for_shared_weight_series(
self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
post_process_after_loading_for_shared_weight_series(
self.o_proj)
if self.enable_mlapo:
quant_method = getattr(
getattr(self.fused_qkv_a_proj, "quant_method", None),
"quant_method",
None,
)
reasons = []
if self.fused_qkv_a_proj is None or not isinstance(
quant_method, AscendW8A8LinearMethod):
reasons.append(
"Currently mlapo only supports W8A8 quantization in the SFA "
"scenario. Some layers in your model are not quantized with "
"W8A8, thus mlapo is disabled for these layers.")
if self.enable_sfa_cp:
reasons.append("Currently mlapo does not support SFA with CP,"
"thus mlapo is disabled for these layers.")
if reasons:
self.enable_mlapo = False
for msg in reasons:
logger.warning_once(msg)
else:
self._process_weights_for_fused_mlapo(act_dtype)
if not self.enable_mlapo:
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
def _v_up_proj(self, x):
forward_context = get_forward_context()
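# Fast path: use the fused batch_matmul_transpose kernel for fp16/bf16
# decode without SFA-CP; otherwise fall back to an explicit per-head
# torch.bmm against W_UV.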
if x.dtype in [torch.float16, torch.bfloat16] \
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
and not self.enable_sfa_cp \
and not forward_context.with_prefill:
x = x.view(-1, self.num_heads, self.kv_lora_rank)
b, _, _ = x.shape
res = torch.empty((b, self.num_heads, self.v_head_dim),
dtype=x.dtype,
device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.num_heads * self.v_head_dim)
else:
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.local_num_heads,
self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
return x
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.local_num_heads, self.qk_head_dim)\
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def exec_kv(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
slots_cp: Optional[torch.Tensor],
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
if self.enable_sfa_cp:
assert slots_cp is not None
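# SFA-CP: the fused kernel normalizes/ropes only this rank's token slice
# and writes it at slots_cp; the returned k_pe/k_nope are then all-gathered
# across TP and scattered into the cache at the global slots below.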
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots_cp.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
#TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
k_pe = get_tp_group().all_gather(k_pe, 0)
k_nope = get_tp_group().all_gather(k_nope, 0)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
k_nope.view(-1, k_nope.shape[-1]))
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
k_pe.view(-1, k_pe.shape[-1]))
else:
torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
def rope_single(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)
# Process the weights for the fused MLAPO kernel: reorder and transpose the
# QKV (and Q up-projection) weights, apply the RoPE-related dimension
# transformations, and repack the quantization parameters.
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
assert self.kv_a_proj_with_mqa is None
assert self.fused_qkv_a_proj is not None
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.q_lora_rank].contiguous()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.q_lora_rank].contiguous()
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
-1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer:
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()
def _sfa_preprocess_decode(
self,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool,
num_input_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv)
k_nope, k_pe = kv_cache[0], kv_cache[1]
ql_nope = torch.empty(
(num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_pe = torch.empty(
(num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_c = torch.empty(
(num_input_tokens, self.q_lora_rank),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
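# Fused MLAPO decode preprocessing: a single kernel performs the quantized
# QKV down-projection, RMSNorms, query up-projection absorbed with W_UK_T,
# RoPE, and the KV-cache write, producing ql_nope/q_pe plus the latent q_c
# consumed by the indexer.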
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
attn_metadata.cos,
attn_metadata.sin,
self.W_UK_T,
k_nope,
k_pe,
attn_metadata.slot_mapping,
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=ql_nope,
kv_cache_out0=k_nope,
q_out1=q_pe,
kv_cache_out1=k_pe,
inner_out=q_c,
)
return hidden_states, ql_nope, q_pe, q_c
def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
forward_context = get_forward_context()
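# Forward flow: build q/kv either via the fused MLAPO decode kernel or via
# fused_qkv_a_proj + q_a_layernorm (with the optional FlashComm1 all-gather),
# write the KV cache, up-project and rope the query, pick top-k tokens with
# the lightning indexer, run sparse flash attention over the latent cache,
# up-project the values, and apply o_proj.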
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_sfa_cp:
need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
if self.enable_mlapo and not forward_context.with_prefill:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
need_gather_q_kv=need_gather_q_kv,
num_input_tokens=attn_metadata.num_input_tokens,
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
# Process for Flash Comm V1
if need_gather_q_kv:
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping
slot_mapping_cp = None
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
slot_mapping_cp)
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
topk_indices = self.indexer_select(
x=hidden_states,
qr=q_c,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
need_gather_q_kv=need_gather_q_kv)
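# Sparse flash attention restricted to the indexer-selected tokens; key and
# value both come from the latent (nope) cache, with the rope part supplied
# separately via key_rope.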
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
key=kv_cache[0],
value=kv_cache[0],
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=attn_metadata.block_tables,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_kv=actual_seq_lengths_key,
query_rope=q_pe,
key_rope=kv_cache[1],
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(attn_output)[0]
return output_padded
def indexer_select(
self,
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
need_gather_q_kv: bool = False,
):
# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k = k.view(-1, 1, self.head_dim)
if HAS_TRITON:
cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
k,
cos,
sin,
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
else:
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe, q_nope = torch.split(
q,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]
q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
if self.enable_sfa_cp:
k = get_tp_group().all_gather(k, 0)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(-1, 1),
k.view(-1, k.shape[-1]))  # b, s, n, d
weights, _ = self.weights_proj(x)
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
weights, need_gather_q_kv)
block_table = attn_metadata.block_tables
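# The lightning indexer scores the cached indexer keys against the per-head
# queries (weighted by weights_proj) and returns the top sparse_count=2048
# token indices per query, which the caller feeds into
# npu_sparse_flash_attention.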
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
query=q,
key=kv_cache[2],
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3)
return topk_indices
def _replace_linear_class_for_sfa_cp(self):
vllm_config = get_current_vllm_config()
# Dispose tensor from the original q_proj
dispose_layer(self.q_proj)
# Construct the new q_proj using ReplicatedLinear
new_q_proj = ReplicatedLinear(self.q_lora_rank,
self.local_num_heads * self.qk_head_dim,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.q_proj.prefix)
# Replace the q_proj with the new one
replace_layer(self.q_proj, new_q_proj)
# Dispose tensor from the original kv_b_proj
dispose_layer(self.kv_b_proj)
# Construct the new kv_b_proj using ReplicatedLinear
new_kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.kv_b_proj.prefix)
# Replace the kv_b_proj with the new one
replace_layer(self.kv_b_proj, new_kv_b_proj)
# Dispose tensor from the original o_proj
dispose_layer(self.o_proj)
# Construct the new o_proj using ReplicatedLinear
config = vllm_config.model_config.hf_config
new_o_proj = ReplicatedLinear(config.num_attention_heads *
config.v_head_dim,
config.hidden_size,
bias=False,
quant_config=vllm_config.quant_config,
prefix=self.o_proj.prefix)
# Replace the o_proj with the new one
replace_layer(self.o_proj, new_o_proj)