Files
xc-llm-ascend/vllm_ascend/torchair/torchair_sfa.py
Mengqing Cao 8abe517870 [Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)
### What this PR does / why we need it?
Adapt deepseek-v3.2 to vllm 0.11.0, removing the useless patch.

The final goal is to remove all the patches and align the code architecture
with vLLM, so the following work is needed in subsequent PRs.
TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with the existing tests.
2. Tests pass with deepseek-v3.2-exp.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-10-15 17:48:58 +08:00

1334 lines
60 KiB
Python

from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import is_enable_nz
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class AscendSFATorchairBackend(AttentionBackend):
    """Backend descriptor that wires together the SFA torchair components.

    Each static accessor returns the class (or value) the attention layer
    needs: the backend name, the implementation class, the metadata class,
    the metadata builder class and the KV-cache shape.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "ASCEND_SFA_TORCHAIR"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["MLAAttentionImpl"]:
        """Return the attention implementation class."""
        impl_cls = AscendSFATorchairImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        metadata_cls = AscendSFATorchairMetadata
        return metadata_cls

    @staticmethod
    def get_builder_cls():
        """Return the class that builds the attention metadata."""
        builder_cls = AscendSFATorchairMetadataBuilder
        return builder_cls

    # NOTE: is that ok?
    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        """Return the KV-cache shape.

        The shape is ``(num_blocks, block_size, num_kv_heads, head_size)``,
        i.e. the arguments in the order they are given.
        """
        kv_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape
@dataclass
class AscendSFATorchairPrefillMetadata:
    """Prefill-specific attention metadata for the Ascend SFA torchair backend.

    Instances are created by the metadata builder's ``build`` method when the
    step contains at least one prefill request.
    """

    @dataclass
    class TorchairChunkedContextMetadata:
        """Chunk layout for prefill requests that already have cached context."""
        # New for SFA (compared to FlashAttention)
        # For handling chunked prefill
        # Per-chunk cumulative chunk lengths with a leading zero column
        # (built with torch.cumsum over chunk_seq_lens in ``build``).
        cu_seq_lens: torch.Tensor
        # Start offset of every chunk for every prefill request.
        starts: torch.Tensor
        # Total number of context tokens in each chunk.
        seq_tot: list[int]
        # Largest per-request chunk length in each chunk.
        max_seq_lens: list[int]
        # Scratch buffer shared by all chunks (the builder's
        # chunked_prefill_workspace).
        workspace: torch.Tensor
        # Per-chunk, per-request number of context tokens.
        chunk_seq_lens: torch.Tensor

    attn_mask: torch.Tensor
    # NOTE(review): ``build`` stores int32 device tensors in query_lens and
    # seq_lens (query_lens is a cumulative sum), although both are annotated
    # as list[int] - confirm which type consumers expect.
    query_lens: list[int]  # Check!!
    seq_lens: list[int]  # Check!!
    # NOTE(review): ``build`` fills this from seq_lens, not from the number of
    # already-computed tokens - confirm the intended meaning.
    context_lens: torch.Tensor
    # Token positions of the prefill tokens.
    input_positions: torch.Tensor
    # Query start offsets re-based so the first prefill request starts at 0.
    query_start_loc: torch.Tensor
    # Block-table rows of the prefill requests.
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    # Rotary sin/cos values gathered at input_positions.
    sin: torch.Tensor
    cos: torch.Tensor
    # Only set when chunked prefill is enabled and some request has context.
    chunked_context: Optional[TorchairChunkedContextMetadata] = None
@dataclass
class AscendSFATorchairDecodeMetadata:
    """Decode-specific attention metadata for the Ascend SFA torchair backend."""
    # Input positions for rotary embeddings since for SFA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    # Block-table rows of the decode requests (padded in graph mode).
    block_table: torch.Tensor
    # Per-request sequence lengths as a tensor ...
    seq_lens: torch.Tensor
    max_seq_lens: int
    # ... and the same lengths as a plain Python list.
    seq_lens_list: list[int]
    # Cumulative query lengths; the builder fills this with multiples of
    # decode_token_per_req.
    actual_seq_lengths_q: torch.Tensor
    # Rotary sin/cos values gathered at input_positions.
    sin: torch.Tensor
    cos: torch.Tensor
    # The builder passes the speculative-decoding mask (spec_attn_mask) here.
    attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendSFATorchairMetadata:
    """Per-step attention metadata for the SFA torchair backend.

    Holds the fields shared by all requests of a step plus the optional
    ``prefill`` and ``decode`` sub-metadata objects.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor
    # New for SFA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: Optional[torch.Tensor] = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    # Sub-metadata; each is None when the step has no request of that kind.
    decode: Optional[AscendSFATorchairDecodeMetadata] = None
    prefill: Optional[AscendSFATorchairPrefillMetadata] = None
    enable_dbo_across_dp: bool = False
    # Flags mirroring whether ``prefill`` / ``decode`` were populated.
    is_prefill: bool = False
    is_decode: bool = False

    def __post_init__(self):
        # Intentionally a no-op: the head-size validation below is disabled.
        pass
        # supported_head_sizes = AscendSFABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendSFATorchairMetadata"]:
        """Split metadata for multi-stream with AscendSFATorchairMetadata

        Delegates to ``model_input_split_v1_mla_attn``, passing this object
        and this class as the metadata type to construct.
        """
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendSFATorchairMetadata,
        )
# Type variable for attention-metadata arguments; any bound type must be
# AscendSFATorchairMetadata or a subclass of it.
M = TypeVar("M", bound=AscendSFATorchairMetadata)
class AscendSFATorchairMetadataBuilder:
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# _attn_mask_builder = None
def __init__(self,
             kv_cache_spec,
             layer_names,
             vllm_config: VllmConfig,
             device: torch.device,
             metadata_cls: Optional[AscendSFATorchairMetadata] = None):
    """Cache configuration values and allocate the chunked-prefill workspace.

    Args:
        kv_cache_spec: Not used by this constructor.
        layer_names: Not used by this constructor.
        vllm_config: Source of the model, cache and scheduler configuration.
        device: Device on which the workspace and metadata tensors live.
        metadata_cls: Metadata class to instantiate; defaults to
            ``AscendSFATorchairMetadata`` when ``None``.
    """
    if metadata_cls is None:
        metadata_cls = AscendSFATorchairMetadata  # type: ignore
    self.metadata_cls: Optional[AscendSFATorchairMetadata] = metadata_cls
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device

    scheduler_config = vllm_config.scheduler_config
    self.block_size = vllm_config.cache_config.block_size
    # Number of blocks needed to hold max_model_len tokens (ceil division).
    rounded_up_len = vllm_config.model_config.max_model_len + \
        self.block_size - 1
    self.max_blocks = rounded_up_len // self.block_size

    self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
    if self.chunked_prefill_enabled:
        # Make sure there is enough room for 8 full-length requests or at
        # least 4 pages of cache per request ...
        desired_tokens = max(
            8 * self.model_config.max_model_len,
            4 * scheduler_config.max_num_seqs * self.block_size)
        # ... but cap the workspace at 128 * 1024 tokens so that
        # long-context models do not over-allocate and eat KV-cache space.
        token_cap = 128 * 1024
        self.chunked_prefill_workspace_size = min(desired_tokens, token_cap)
        one_page_per_seq = scheduler_config.max_num_seqs * self.block_size
        assert self.chunked_prefill_workspace_size >= one_page_per_seq
        workspace_shape = (self.chunked_prefill_workspace_size,
                           self.model_config.get_head_size())
        self.chunked_prefill_workspace = torch.empty(
            workspace_shape,
            dtype=self.model_config.dtype,
            device=device,
        )

    torchair_graph_config = get_ascend_config().torchair_graph_config
    self.torchair_graph_enabled = torchair_graph_config.enabled
    self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
    # Filled lazily from the model's rotary embedding on the first build().
    self.cos_cache = None
    self.sin_cache = None
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Move decode requests to the front of the batch with minimal swaps.

    A request counts as "decode" when it schedules exactly one token; in
    torchair graph mode the speculative tokens are subtracted first, so
    speculative decoding is also treated as decode.  Everything else is
    treated as "prefill".

    Args:
        input_batch: Persistent batch; provides ``req_ids`` and
            ``swap_states(i, j)``.
        scheduler_output: Provides ``num_scheduled_tokens`` and
            ``scheduled_spec_decode_tokens`` keyed by request id.

    Returns:
        ``True`` if at least one pair of requests was swapped.
    """
    decode_slots: list[int] = []
    prefill_slots: list[int] = []
    for slot, req_id in enumerate(input_batch.req_ids):
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        spec_count = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        # Graph mode ignores speculative tokens when classifying; eager mode
        # treats any multi-token request as (chunked) prefill.
        if self.torchair_graph_enabled:
            effective = scheduled - spec_count
        else:
            effective = scheduled
        bucket = decode_slots if effective == 1 else prefill_slots
        bucket.append(slot)

    # Both lists are in ascending order.  Walk the decodes from the back and
    # the prefills from the front: a decode sitting beyond the first
    # len(decode_slots) positions is exchanged with the earliest prefill.
    # Stop at the first decode that is already inside the front region.
    decode_count = len(decode_slots)
    swapped_any = False
    for prefill_slot, decode_slot in zip(prefill_slots,
                                         reversed(decode_slots)):
        if decode_slot < decode_count:
            break
        input_batch.swap_states(prefill_slot, decode_slot)
        swapped_any = True
    return swapped_any
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
max_blocks = self.max_blocks
graph_block_tables = torch.zeros((num_seqs, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
num_blocks = block_tables.size(1)
if num_blocks <= max_blocks:
graph_block_tables[:num_seqs, :
num_blocks] = block_tables[:num_seqs, :
num_blocks]
else:
graph_block_tables[:num_seqs, :
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables[:, :max_blocks]
def build_torchair_graph_dummy(
    self,
    common_attn_metadata: TorchairCommonAttentionMetadata,
) -> AscendSFATorchairMetadata:
    """Build placeholder decode metadata for a torchair graph run.

    All tensors are filled with neutral values (zeros, ones, ``PAD_SLOT_ID``
    or ``-1``) sized for ``num_reqs`` requests of ``decode_token_per_req``
    tokens each.

    Args:
        common_attn_metadata: Provides ``num_reqs``, ``decode_token_per_req``,
            ``num_actual_tokens`` and the attention masks.

    Returns:
        A decode-only metadata object (``prefill`` is ``None``).
    """
    device = self.device
    num_reqs = common_attn_metadata.num_reqs
    tokens_per_req = common_attn_metadata.decode_token_per_req
    num_tokens = num_reqs * tokens_per_req

    # All-zero block table, normalised to the graph runner's width.
    zero_table = torch.zeros((num_reqs, self.max_blocks),
                             dtype=torch.int32,
                             device=device)
    block_table = self._get_graph_runner_block_tables(num_reqs, zero_table)

    seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
    seq_lens_list = [0] * num_reqs
    input_positions = torch.zeros(num_tokens,
                                  dtype=torch.int32,
                                  device=device).long()
    slot_mapping = torch.full((num_tokens, ),
                              PAD_SLOT_ID,
                              dtype=torch.int32,
                              device=device)
    query_start_loc = torch.full((num_reqs, ),
                                 -1,
                                 dtype=torch.int32,
                                 device=device)

    # Rotary tables of ones with shape (num_tokens, 1, 1, rope_dim).
    rope_dtype = self.model_config.dtype
    sin = torch.ones(num_tokens,
                     1,
                     1,
                     self.rope_dim,
                     dtype=rope_dtype,
                     device=device)
    cos = torch.ones(num_tokens,
                     1,
                     1,
                     self.rope_dim,
                     dtype=rope_dtype,
                     device=device)

    speculative_config = self.vllm_config.speculative_config
    uses_mtp = (speculative_config is not None
                and speculative_config.method == 'deepseek_mtp')
    if uses_mtp:
        attn_state = AscendAttentionState.SpecDecoding
        num_decode_tokens = 2
    else:
        attn_state = AscendAttentionState.DecodeOnly
        num_decode_tokens = 1

    # Cumulative query lengths: every request contributes tokens_per_req.
    request_ordinals = torch.arange(1, num_reqs + 1).to(torch.int32).npu()
    actual_seq_lengths_q = request_ordinals * tokens_per_req

    decode_metadata = AscendSFATorchairDecodeMetadata(
        input_positions=input_positions,
        block_table=block_table,
        seq_lens=seq_lens,
        seq_lens_list=seq_lens_list,
        max_seq_lens=1,
        attn_mask=common_attn_metadata.spec_attn_mask,
        actual_seq_lengths_q=actual_seq_lengths_q,
        sin=sin,
        cos=cos,
    )
    return self.metadata_cls(  # type: ignore
        num_input_tokens=common_attn_metadata.num_actual_tokens,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        slot_mapping=slot_mapping,
        head_dim=self.model_config.get_head_size(),
        num_decodes=num_tokens,
        num_decode_tokens=num_decode_tokens,
        num_prefills=0,
        attn_mask=common_attn_metadata.attn_mask,
        attn_state=attn_state,
        prefill=None,
        decode=decode_metadata,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_tables=block_table,
        is_prefill=False,
        is_decode=True)
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: AscendCommonAttentionMetadata,
    model: nn.Module,
) -> AscendSFATorchairMetadata:
    """Build the attention metadata for one step.

    The batch is expected to be ordered decodes-first.  Requests are split
    into decodes and prefills, and a prefill and/or decode sub-metadata
    object is created for whichever kind is present.

    Args:
        common_prefix_len: Not used by this builder.
        common_attn_metadata: Per-step inputs (query start locations,
            sequence lengths, block table, slot mapping, positions, ...).
        model: Model whose first layer's rotary embedding provides the
            cos/sin caches (read once and cached on the builder).

    Returns:
        The populated metadata object.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    if self.torchair_graph_enabled and common_attn_metadata.attn_state in [
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.SpecDecoding
    ]:
        decode_threshold = common_attn_metadata.decode_token_per_req
    else:
        # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
        decode_threshold = 1
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

    # Note(simon): be careful about the CPU <> GPU memory movement in this
    # function. We should avoid GPU -> CPU sync as much as possible because
    # it blocks on all previous kernels.
    device = self.device
    block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
    slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens].to(
        device, non_blocking=True)
    input_positions = common_attn_metadata.positions[:num_actual_tokens].long(
    )

    # Lazily pick up the rotary caches from the model and make sure they
    # use the model dtype.
    if self.cos_cache is None:
        self.cos_cache = model.model.layers[
            0].self_attn.rotary_emb.cos_cached
        self.sin_cache = model.model.layers[
            0].self_attn.rotary_emb.sin_cached
    if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
        self.cos_cache = self.cos_cache.to(  # type: ignore
            self.model_config.dtype)  # type: ignore
        self.sin_cache = self.sin_cache.to(  # type: ignore
            self.model_config.dtype)  # type: ignore

    # Per-request lengths, computed on the CPU copies.
    query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    query_lens = query_seq_lens_cpu[:num_reqs]
    seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
    num_computed_tokens_cpu = (seq_lens - query_lens)

    prefill_metadata = None
    chunked_context_metadata = None
    is_prefill = False
    is_decode = False
    if num_prefills > 0:
        # `reqs_start` indexes per-request tensors, `tokens_start` indexes
        # per-token tensors.  The two differ whenever a decode request
        # carries more than one token (speculative decoding), so per-request
        # tensors must never be sliced with `tokens_start`.
        reqs_start = num_decodes  # prefill_start
        tokens_start = num_decode_tokens
        max_query_len = query_lens[reqs_start:].max().item()
        max_seq_lens = seq_lens[reqs_start:].max().item()
        prefill_query_start_loc = query_start_loc[
            reqs_start:] - query_start_loc[reqs_start]

        context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
        max_context_len_cpu = context_lens_cpu.max().item()
        num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
        if self.chunked_prefill_enabled and max_context_len_cpu > 0:
            # Split the workspace evenly between the requests that have
            # context, rounded down to whole blocks.
            max_context_chunk = (self.chunked_prefill_workspace_size //
                                 num_prefills_with_context_cpu)
            max_context_chunk = round_down(max_context_chunk,
                                           self.block_size)
            assert max_context_chunk > 0
            num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
            chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
            chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                   chunk_starts + max_context_chunk)
            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
            cu_seq_lens_cpu = torch.zeros(num_chunks,
                                          num_prefills + 1,
                                          dtype=torch.int32,
                                          pin_memory=True)
            torch.cumsum(chunk_seq_lens,
                         dim=1,
                         out=cu_seq_lens_cpu[:, 1:],
                         dtype=torch.int32)
            chunked_context_metadata = \
                AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata(
                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                    starts=chunk_starts.to(device, non_blocking=True),
                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                    chunk_seq_lens=chunk_seq_lens,
                    workspace=self.chunked_prefill_workspace,
                )
        prefill_input_positions = input_positions[tokens_start:]
        cos = self.cos_cache[
            prefill_input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
        sin = self.sin_cache[
            prefill_input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
        # `.to(dtype)` instead of `torch.tensor(tensor)`, which warns about
        # copy-constructing from an existing tensor.
        actual_query_lens = query_lens[reqs_start:].to(
            torch.int32).npu()  # int64->int32
        query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
                                              dim=0).to(torch.int32).npu()
        # NOTE(review): all requests' seq_lens are passed here, not only the
        # prefill ones; kept as in the original - confirm against the kernel.
        seq_lens_prefill_sfa = seq_lens.to(torch.int32).npu()
        prefill_metadata = AscendSFATorchairPrefillMetadata(
            attn_mask=common_attn_metadata.attn_mask,
            query_lens=query_lens_prefill_sfa,
            seq_lens=seq_lens_prefill_sfa,
            context_lens=seq_lens[reqs_start:],
            input_positions=prefill_input_positions,
            block_table=block_table[reqs_start:, ...],
            max_query_len=max_query_len,
            max_seq_lens=max_seq_lens,
            query_start_loc=prefill_query_start_loc,
            chunked_context=chunked_context_metadata,
            sin=sin,
            cos=cos,
        )
        is_prefill = True

    decode_metadata = None
    graph_pad_size = common_attn_metadata.graph_pad_size
    use_torchair_graph = graph_pad_size != -1
    if num_decodes > 0:
        # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
        actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
            torch.int32).npu()
        max_seq_lens = seq_lens[:num_decodes].max().item()
        seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
        # input_positions = input_positions[:num_decode_tokens]
        block_table = block_table[:num_decodes, ...]
        num_token_pad_size = 0
        if use_torchair_graph and common_attn_metadata.attn_state in [
                AscendAttentionState.DecodeOnly,
                AscendAttentionState.SpecDecoding
        ]:
            # Pad every per-request / per-token tensor up to the graph size.
            num_reqs_pad_size = 0
            if graph_pad_size != 0:
                pad_value = 0
                num_token_pad_size = graph_pad_size - num_decode_tokens
                num_reqs_pad_size = (
                    graph_pad_size //
                    common_attn_metadata.decode_token_per_req - num_reqs)
                padded_seq_lens = seq_lens.tolist(
                ) + [pad_value] * num_reqs_pad_size
            else:
                padded_seq_lens = seq_lens.tolist()

            seq_lens = torch.from_numpy(
                np.array(padded_seq_lens).astype(np.int32)).npu()
            seq_lens_list = padded_seq_lens
            slot_padding = torch.full((num_token_pad_size, ),
                                      PAD_SLOT_ID,
                                      dtype=slot_mapping.dtype,
                                      device=slot_mapping.device)
            slot_mapping = torch.cat([slot_mapping, slot_padding])
            block_table_padding = torch.zeros(
                (num_reqs_pad_size, ) + block_table.shape[1:],
                dtype=block_table.dtype,
                device=block_table.device)
            block_table = torch.cat([block_table, block_table_padding],
                                    dim=0)
            block_table = self._get_graph_runner_block_tables(
                num_reqs + num_reqs_pad_size, block_table)
            position_padding = torch.zeros(num_token_pad_size,
                                           dtype=input_positions.dtype,
                                           device=input_positions.device)
            input_positions = torch.cat(
                [input_positions, position_padding])
            actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to(
                torch.int32).npu(
                ) * common_attn_metadata.decode_token_per_req
            # MTP ignored
            # actual_seq_lengths_q = self.pad_actual_seq_len_q(
            #     num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
            #     common_attn_metadata)
        else:
            seq_lens_list = seq_lens.tolist()
            # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
            batch_size = num_decode_tokens + num_token_pad_size
            if actual_seq_lengths_q[-1] != batch_size \
                and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                actual_seq_lengths_q[-1] = batch_size

        cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)
        sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)
        # The cumulative query lengths handed to the kernel are always
        # recomputed from the (possibly padded) token count.
        padded_token_num = input_positions.shape[0]
        actual_seq_lengths_q = torch.arange(
            1,
            (padded_token_num // common_attn_metadata.decode_token_per_req)
            + 1).to(torch.int32).npu(
            ) * common_attn_metadata.decode_token_per_req
        decode_metadata = AscendSFATorchairDecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens_list,
            max_seq_lens=max_seq_lens,
            attn_mask=common_attn_metadata.spec_attn_mask,
            actual_seq_lengths_q=actual_seq_lengths_q,
            sin=sin,
            cos=cos)
        is_decode = True

    return self.metadata_cls(  # type: ignore
        num_actual_tokens=num_actual_tokens,
        query_lens=query_lens.tolist(),
        slot_mapping=slot_mapping,
        head_dim=self.model_config.get_head_size(),
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        attn_mask=common_attn_metadata.attn_mask,
        attn_state=common_attn_metadata.attn_state,
        prefill=prefill_metadata,
        decode=decode_metadata,
        query_start_loc=query_start_loc,
        block_tables=block_table,
        seq_lens=seq_lens,
        enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
        is_prefill=is_prefill,
        is_decode=is_decode)
def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
                         actual_seq_lengths_q, common_attn_metadata):
    """Extend ``actual_seq_lengths_q`` with entries for the padded requests.

    The padding entries normally come straight from
    ``common_attn_metadata.actual_seq_lengths_q[num_reqs:num_reqs +
    num_reqs_pad_size]``.  If the first padding entry exceeds the last real
    entry by more than 16, the padding entries are instead spread evenly
    (rounded ``np.linspace``) between the last real entry and the last
    padding entry, which is kept unchanged.

    Example from the original documentation: with batch_size=36,
    num_reqs_pad_size=2 and num_reqs=16, a real prefix of
    ``[1, 2, ..., 16]`` is padded to something like
    ``[1, 2, ..., 16, 32, 36]``.

    Args:
        num_reqs_pad_size: Number of padding requests to append.
        num_reqs: Number of real requests.
        actual_seq_lengths_q: Cumulative query lengths of the real requests;
            combined with the padding via ``+`` (list concatenation when a
            list is passed).
        common_attn_metadata: Provides the full ``actual_seq_lengths_q``.

    Returns:
        The extended lengths as an int32 tensor moved to the NPU.
    """
    FIA_SEQ_LEN_LIMIT = 16
    reference_lens = common_attn_metadata.actual_seq_lengths_q
    padding_seq_len_q = reference_lens[num_reqs:num_reqs + num_reqs_pad_size]

    need_padding = (
        num_reqs_pad_size != 0 and len(reference_lens) > num_reqs
        and reference_lens[num_reqs] - actual_seq_lengths_q[-1] >
        FIA_SEQ_LEN_LIMIT)

    if not need_padding:
        tail = padding_seq_len_q
    else:
        first_val = actual_seq_lengths_q[-1]
        last_val = padding_seq_len_q[-1]
        step_count = len(padding_seq_len_q)
        # Evenly spaced values after the start point, ending at last_val.
        evenly_spaced = np.linspace(first_val, last_val, step_count + 1)[1:]
        tail = np.round(evenly_spaced).astype(int).tolist()
        assert tail[-1] == last_val
        assert len(tail) == len(padding_seq_len_q)

    extended = actual_seq_lengths_q + tail
    return torch.Tensor(extended).to(torch.int32).npu()
class PrefillSFAPreprocessResult(NamedTuple):
    """Container for the tensors produced by prefill preprocessing.

    Every field defaults to ``None``.
    """
    # Query split into its non-rotary ("nope") and rotary ("pe") parts.
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    # Key split into its non-rotary ("nope") and rotary ("pe") parts.
    k_nope: Optional[torch.Tensor] = None
    k_pe: Optional[torch.Tensor] = None
    # Indices selected by the indexer - presumably top-k KV positions; verify
    # against the indexer implementation.
    topk_indices: Optional[torch.Tensor] = None
    # NOTE(review): the decode path stores (nope, pe) tuples in the
    # equivalent fields despite the Tensor annotation - confirm for prefill.
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
class DecodeSFAPreprocessResult(NamedTuple):
    """Container for the tensors produced by ``_sfa_decode_preprocess``.

    Every field defaults to ``None``.
    """
    # Query split into its non-rotary ("nope") and rotary ("pe") parts.
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    # nope_cache: Optional[torch.Tensor] = None
    # rope_cache: Optional[torch.Tensor] = None
    # Result of ``indexer_select`` for the decode tokens.
    topk_indices: Optional[torch.Tensor] = None
    # NOTE(review): ``_sfa_decode_preprocess`` stores 2-tuples here
    # ((q_nope, q_pe) and (kv_cache[0], kv_cache[1])), not single tensors.
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
    # First dimension of the hidden states passed to preprocessing.
    bsz: Optional[int] = None
class AscendSFATorchairImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    """Store the attention hyper-parameters and the layer modules.

    Args:
        num_heads: Total number of attention heads.
        head_size: Size of each attention head.
        scale: Attention scaling factor (stored as ``float``).
        num_kv_heads: Number of KV heads.
        alibi_slopes: Not used by this implementation.
        sliding_window: Not used by this implementation.
        kv_cache_dtype: KV-cache dtype name.
        logits_soft_cap: Not used by this implementation.
        attn_type: Not used by this implementation.
        kv_sharing_target_layer_name: Not used by this implementation.
        **kwargs: Must contain ``q_lora_rank``, ``kv_lora_rank``,
            ``qk_nope_head_dim``, ``qk_rope_head_dim``, ``qk_head_dim``,
            ``v_head_dim``, ``rotary_emb``, ``q_proj``, ``kv_b_proj``,
            ``o_proj`` and ``indexer``; may contain ``kv_a_proj_with_mqa``,
            ``kv_a_layernorm``, ``q_a_proj``, ``q_a_layernorm`` and
            ``decoder_layer``.

    Raises:
        KeyError: If a required keyword argument is missing.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    # MLA Args
    self.q_lora_rank = kwargs['q_lora_rank']
    self.kv_lora_rank = kwargs['kv_lora_rank']
    self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
    self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
    self.qk_head_dim = kwargs['qk_head_dim']
    self.v_head_dim = kwargs['v_head_dim']
    self.rotary_emb = kwargs['rotary_emb']
    self.q_proj = kwargs['q_proj']
    self.kv_b_proj = kwargs['kv_b_proj']
    self.o_proj = kwargs['o_proj']
    self.indexer = kwargs['indexer']
    self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
    self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
    self.q_a_proj = kwargs.get('q_a_proj', None)
    self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
    self.decoder_layer = kwargs.get('decoder_layer', None)
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_heads_per_rank = self.num_heads // self.tp_size
    # With a low-rank query down-projection, q_proj acts as the
    # up-projection.
    if self.q_a_proj is not None:
        self.q_b_proj = self.q_proj
    else:
        self.q_b_proj = None

    ascend_config = get_ascend_config()
    self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
    self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
    self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
    if ascend_config.torchair_graph_config.enabled:
        self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[
            0]
        self.actual_seq_length = torch.arange(1, self.graph_batch_size +
                                              1).to(torch.int32).npu()
    vllm_config = get_current_vllm_config()
    self.ring_mla_mask_size = 512
    self.prefill_mask = None

    # indexer param
    self.dim = self.indexer.dim
    self.n_heads: int = self.indexer.n_heads  # 64
    self.head_dim: int = self.indexer.head_dim  # 128
    self.index_topk: int = self.indexer.index_topk  # 2048
    self.wq_b = self.indexer.wq_b
    self.wk = self.indexer.wk
    self.weights_proj = self.indexer.weights_proj
    self.k_norm = self.indexer.k_norm
    self.softmax_scale = self.indexer.softmax_scale

    # Adapt torch air graph mode with spec decoding.
    speculative_config = vllm_config.speculative_config
    # Default to 0 so the attribute always exists, even when speculative
    # decoding is disabled.
    self.spec_token_num = 0
    if speculative_config is not None:
        self.spec_token_num = speculative_config.num_speculative_tokens
        assert self.spec_token_num > 0
    self.cp_size = 1

    if self.q_a_proj is not None:
        self.prefix = self.q_a_proj.prefix
    else:
        self.prefix = 0
    # The layer index is the third dot-separated component of the prefix.
    # Without q_a_proj the prefix is the integer placeholder 0, which has no
    # ``split``; fall back to index 0 instead of raising AttributeError.
    if isinstance(self.prefix, str):
        self.debug_layer_idx = int(self.prefix.split(".")[2])
    else:
        self.debug_layer_idx = 0
    self.layers = vllm_config.model_config.hf_config.num_hidden_layers
    self.first_k_dense_replace = vllm_config.model_config.hf_config.first_k_dense_replace
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Split the ``kv_b_proj`` weight into per-head K and V projections.

    Stores ``self.kv_b_proj_w_k`` and ``self.kv_b_proj_w_v`` and, when
    ``VLLM_ASCEND_ENABLE_MLAPO`` is set, additionally prepares the weights
    for the fused preprocessing path.

    Args:
        act_dtype: dtype of the identity matrix used to dequantize a
            quantized layer.
    """

    def get_layer_weight(layer):
        # Return the first weight-like attribute the layer exposes.
        WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
        for attr in WEIGHT_NAMES:
            if hasattr(layer, attr):
                return getattr(layer, attr)
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute:"
            f" {WEIGHT_NAMES}.")

    def get_and_maybe_dequant_weights(layer: LinearBase):
        # For a quantized layer, recover a dense weight by pushing an
        # identity matrix through the layer's quantized apply().
        if not isinstance(layer.quant_method, UnquantizedLinearMethod):
            # NOTE: This should only be used offline, since it's O(N^3)
            eye = torch.eye(layer.input_size_per_partition,
                            dtype=act_dtype,
                            device=get_layer_weight(layer).device)
            dequant_weights = layer.quant_method.apply(layer,
                                                       eye,
                                                       bias=None)
            del eye
            # standardize to (output, input)
            return dequant_weights.T
        return layer.weight

    # we currently do not have quantized bmm's which are needed for
    # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
    # the bmm's in 16-bit, the extra memory overhead of this is fairly low
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    # (kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )
    self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    # Convert from (L, N, V) to (N, L, V)
    self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
    # Convert from (L, N, P) to (N, P, L)
    self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()
    # Waiting for BMM NZ support
    # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
    # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
    if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO:
        self._process_weights_for_fused_mlapo(act_dtype)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
    """Prepare the weights, scales and biases read by the fused decode path.

    Sets the ``wd_qkv``, ``deq_scale_qkv``, ``quant_bias_qkv``, ``wu_q``,
    ``qb_deq_scl``, ``qb_qt_bias``, ``gamma*``/``beta*`` and
    ``quant_scale*``/``quant_offset*`` attributes that
    ``_sfa_decode_preprocess`` passes to ``torch_npu.npu_mla_process``.

    Args:
        act_dtype: Not used by this method.

    NOTE(review): ``trans_rope_weight`` and ``transdata`` are not among the
    imports visible at the top of this file - confirm they are defined or
    imported elsewhere in the module.
    """
    # --- fused down-projection weight: kv_a_proj (rope part rearranged)
    # concatenated with q_a_proj ---
    kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data.clone()
    kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
    kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
    kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
    wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()),
                       dim=-1)
    wd_qkv = wd_qkv.t().contiguous()
    wd_qkv = transdata(wd_qkv,
                       block_size=(16, 32)).unsqueeze(0).contiguous()
    # NOTE(review): self.wd_qkv is only assigned when is_enable_nz() is true,
    # yet _sfa_decode_preprocess reads it unconditionally - confirm the
    # intended behaviour when NZ is disabled.  29 is presumably the NZ
    # format id; verify against the torch_npu documentation.
    if is_enable_nz():
        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

    # --- dequant scale of the fused down-projection ---
    kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
    kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
        self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
    kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
                                          self.qk_rope_head_dim)
    kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
        self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
    self.deq_scale_qkv = torch.cat(
        (kv_a_proj_deq_scl, self.q_a_proj.deq_scale.clone()),
        dim=-1).contiguous()

    # --- quant bias of the fused down-projection ---
    kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias.clone()
    kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
        self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
    kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
                                          self.qk_rope_head_dim)
    kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
        self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
    self.quant_bias_qkv = torch.cat(
        (kv_a_proj_qt_bias, self.q_a_proj.quant_bias.clone()),
        dim=-1).contiguous()

    # --- query up-projection weight (rope part rearranged per head) ---
    wu_q = self.q_proj.weight.data.clone()
    wu_q = wu_q.t().reshape(self.num_heads,
                            self.qk_nope_head_dim + self.qk_rope_head_dim,
                            -1)
    wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
    wu_q = wu_q.reshape(
        self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
        -1)
    wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
    # NOTE(review): same as wd_qkv above - self.wu_q is only assigned when
    # is_enable_nz() is true.
    if is_enable_nz():
        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

    # --- dequant scale and quant bias of the query up-projection ---
    qb_deq_scl = self.q_proj.deq_scale.data.clone()
    qb_deq_scl = qb_deq_scl.reshape(
        self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
    qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
    self.qb_deq_scl = qb_deq_scl.reshape(
        self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

    qb_qt_bias = self.q_proj.quant_bias.data.clone()
    qb_qt_bias = qb_qt_bias.reshape(
        self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
    qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
    self.qb_qt_bias = qb_qt_bias.reshape(
        self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

    # --- layernorm parameters: 0 = decoder input layernorm,
    # 1 = q_a_layernorm, 2 = kv_a_layernorm ---
    self.gamma0 = self.decoder_layer.input_layernorm.weight.data
    self.beta0 = self.decoder_layer.input_layernorm.bias.data
    self.gamma1 = self.q_a_layernorm.weight.data
    self.beta1 = self.q_a_layernorm.bias.data
    self.gamma2 = self.kv_a_layernorm.weight.data

    # --- input quantization parameters: 0 = q_a_proj, 1 = q_proj ---
    self.quant_scale0 = self.q_a_proj.input_scale.data
    self.quant_offset0 = self.q_a_proj.input_offset.data
    self.quant_scale1 = self.q_proj.input_scale.data
    self.quant_offset1 = self.q_proj.input_offset.data
def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata,
                           need_gather_q_kv):
    """Build decode-step query/key states and top-k indices via the fused op.

    Calls ``torch_npu.npu_mla_process`` with the packed weights, scales and
    biases stored on ``self`` and with ``kv_cache[0]`` / ``kv_cache[1]`` as
    cache operands, then runs :meth:`indexer_select` on the layer-normed
    hidden states to obtain the sparse top-k indices.

    Args:
        hidden_states: Decode input; ``shape[0]`` is taken as the batch size.
        kv_cache: Indexable cache container. Entries 0 and 1 are handed to
            the fused op and returned as the key states; the whole container
            is forwarded to :meth:`indexer_select`.
        attn_metadata: Must provide ``decode.cos``, ``decode.sin`` and
            ``slot_mapping``.
        need_gather_q_kv: Accepted for interface compatibility; not read by
            this method.

    Returns:
        A ``DecodeSFAPreprocessResult`` carrying ``q_nope``, ``q_pe``,
        ``topk_indices``, the ``(q_nope, q_pe)`` and ``(k_nope, k_pe)``
        tuples and the batch size.
    """
    batch = hidden_states.shape[0]

    # Collapse cos/sin to 2-D: (first dim, last dim) of the stored tensor.
    rope_shape = attn_metadata.decode.cos.shape
    cos = attn_metadata.decode.cos.view(rope_shape[0], rope_shape[-1])
    sin = attn_metadata.decode.sin.view(rope_shape[0], rope_shape[-1])

    # Both scales are the constant 1 in the activations' dtype/device.
    like_hidden = dict(dtype=hidden_states.dtype,
                       device=hidden_states.device)
    ctkv_scale = torch.tensor([1], **like_hidden)
    q_nope_scale = torch.tensor([1], **like_hidden)

    # Only outputs 0 and 2 of the fused op are consumed here.
    q_nope, _, q_pe, _ = torch_npu.npu_mla_process(
        hidden_states,
        self.gamma0,
        self.beta0,
        self.wd_qkv,
        self.deq_scale_qkv,
        self.gamma1,
        self.beta1,
        self.wu_q,
        self.qb_deq_scl,
        self.gamma2,
        cos,
        sin,
        self.kv_b_proj_w_k,
        kv_cache[0],
        kv_cache[1],
        attn_metadata.slot_mapping.flatten(),
        quant_scale0=self.quant_scale0,
        quant_offset0=self.quant_offset0,
        bias0=self.quant_bias_qkv,
        quant_scale1=self.quant_scale1,
        quant_offset1=self.quant_offset1,
        bias1=self.qb_qt_bias,
        ctkv_scale=ctkv_scale,
        q_nope_scale=q_nope_scale,
        cache_mode_opt="krope_ctkv",
        quant_mode_opt="per_tensor_quant_asymm",
    )

    # The key states are the cache tensors themselves.
    k_nope = kv_cache[0]
    k_pe = kv_cache[1]

    q_nope = q_nope.view(batch, self.num_heads, self.kv_lora_rank)
    q_pe = q_pe.view(batch, self.num_heads, -1)

    # The indexer needs the normed hidden states and the normed
    # down-projected query, so those are recomputed outside the fused op.
    hidden_states = self.decoder_layer.input_layernorm(hidden_states)
    q_down = self.q_a_proj(hidden_states)  # q down
    q_compressed = self.q_a_layernorm(q_down)  # q down layernorm
    topk_indices = self.indexer_select(hidden_states,
                                       q_compressed,
                                       attn_metadata=attn_metadata,
                                       kv_cache=kv_cache,
                                       is_prefill=False)

    return DecodeSFAPreprocessResult(
        q_nope=q_nope,
        q_pe=q_pe,
        topk_indices=topk_indices,
        query_states=(q_nope, q_pe),
        key_states=(k_nope, k_pe),
        bsz=batch,
    )
def forward(
    self,
    hidden_states: torch.Tensor,  # query in unified attn
    kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    attn_metadata: M,
    need_gather_q_kv: bool = False,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run sparse flash attention for one layer and fill ``output`` in place.

    Dispatches on ``attn_metadata``: prefill metadata takes precedence over
    decode metadata; with no metadata at all (profiling run) ``output`` is
    returned untouched.

    Args:
        hidden_states: Layer input activations.
        kv_cache: Indexable cache container; entries 0 and 1 are used by the
            attention kernels, entry 2 by :meth:`indexer_select`.
        attn_metadata: Attention metadata, or ``None`` for a profiling run.
        need_gather_q_kv: Only forwarded to :meth:`_sfa_decode_preprocess`.
        output: Pre-allocated tensor that receives the result; required.

    Returns:
        ``output`` after it has been written (or untouched on profiling).

    Raises:
        AssertionError: If ``output`` is ``None``.
    """
    assert output is not None, "Output tensor must be provided."
    if attn_metadata is None:
        # Profiling run.
        return output
    if attn_metadata.prefill is not None:
        return self._forward_sfa_prefill(hidden_states, kv_cache,
                                         attn_metadata, output)
    if attn_metadata.decode is not None:
        return self._forward_sfa_decode(hidden_states, kv_cache,
                                        attn_metadata, need_gather_q_kv,
                                        output)
    # NOTE(review): with metadata present but neither prefill nor decode set,
    # nothing is written to ``output`` and None is returned, exactly as the
    # pre-refactor code did by falling through. Confirm this state is
    # unreachable or decide on explicit handling.
    return None

def _forward_sfa_prefill(self, hidden_states, kv_cache, attn_metadata,
                         output):
    """Prefill branch of :meth:`forward`; writes the result into ``output``."""
    assert attn_metadata.num_decodes is not None and \
        attn_metadata.num_prefills is not None and \
        attn_metadata.num_decode_tokens is not None

    slot_mapping = attn_metadata.slot_mapping
    q_down = self.q_a_proj(hidden_states)  # q down
    q_c = self.q_a_layernorm(q_down)  # q down layernorm
    kv_no_split = self.kv_a_proj_with_mqa(hidden_states)  # c_kv

    # Evaluated once; both gathers below share this condition.
    gather_over_tp = (self.enable_shared_expert_dp
                      and self.debug_layer_idx > self.first_k_dense_replace
                      and self.debug_layer_idx < self.layers)
    qr = q_c
    if gather_over_tp:
        # Collective order (kv first, then q) must stay the same on all ranks.
        token_span = slice(attn_metadata.num_decode_tokens,
                           attn_metadata.num_actual_tokens)
        kv_no_split = get_tp_group().all_gather(kv_no_split, 0)[token_span]
        qr = get_tp_group().all_gather(qr, 0)[token_span]

    # Query up-projection, split into no-rope / rope parts per head.
    q = self.q_b_proj(qr).view(-1, self.num_heads, self.qk_head_dim)
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

    # Multiply the no-rope query by kv_b_proj_w_k, batched over heads; the
    # result is viewed with kv_lora_rank as its last dimension.
    q_nope = q_nope.view(-1, self.num_heads,
                         self.qk_nope_head_dim).transpose(0, 1)
    q_nope = torch.matmul(q_nope, self.kv_b_proj_w_k).transpose(1, 0).view(
        -1, self.num_heads, self.kv_lora_rank)

    nope_cache = kv_cache[0]
    rope_cache = kv_cache[1]
    cos = attn_metadata.prefill.cos
    sin = attn_metadata.prefill.sin

    # Rotary embedding on the rope part of the query.
    q_pe = torch_npu.npu_interleave_rope(q_pe.unsqueeze(2), cos, sin)
    q_pe = q_pe.squeeze(2)

    # Must run before the indexer/attention calls below, as in the original
    # ordering; it receives both caches and the slot mapping.
    k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
        kv_no_split.view(-1, 1, 1,
                         self.kv_lora_rank + self.qk_rope_head_dim),
        self.kv_a_layernorm.weight,
        cos.view(-1, 1, 1, self.qk_rope_head_dim),
        sin.view(-1, 1, 1, self.qk_rope_head_dim),
        slot_mapping.to(torch.int64),
        rope_cache,
        nope_cache,
        k_rope_scale=None,
        c_kv_scale=None,
        k_rope_offset=None,
        c_kv_offset=None,
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode="PA")

    topk_indices = self.indexer_select(x=hidden_states,
                                       qr=qr,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata,
                                       is_prefill=True)

    prefill_metadata = attn_metadata.prefill
    # The same tensor is passed as both key and value.
    attn = torch.ops.custom.npu_sparse_flash_attention(
        query=q_nope,
        key=k_nope,
        value=k_nope,
        sparse_indices=topk_indices,
        scale_value=self.scale,
        sparse_block_size=1,
        block_table=prefill_metadata.block_table,
        actual_seq_lengths_query=prefill_metadata.query_lens,
        actual_seq_lengths_kv=prefill_metadata.seq_lens,
        query_rope=q_pe,
        key_rope=k_rope if False else k_pe,
        layout_query="TND",
        layout_kv="PA_BSND",
        sparse_mode=3,
    )
    attn = attn.transpose(0, 1)
    # input shape [N//attn_tp_size, T(bs*q_len), D]
    # output shape [T(bs*q_len), N//attn_tp_size, D]
    attn_output = torch.matmul(attn, self.kv_b_proj_w_v).transpose(
        1, 0).reshape(-1, self.num_heads * self.v_head_dim)
    output[...] = self.o_proj(attn_output, is_force_scatter=True)
    return output

def _forward_sfa_decode(self, hidden_states, kv_cache, attn_metadata,
                        need_gather_q_kv, output):
    """Decode branch of :meth:`forward`; writes the result into ``output``."""
    if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO:
        # Fused preprocessing path.
        prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache,
                                               attn_metadata,
                                               need_gather_q_kv)
        q_nope, q_pe = prep_res.query_states
        k_nope, k_rope = prep_res.key_states
        topk_indices = prep_res.topk_indices
    else:
        q_down = self.q_a_proj(hidden_states)  # q down
        q_c = self.q_a_layernorm(q_down)  # q down layernorm
        kv_no_split = self.kv_a_proj_with_mqa(hidden_states)  # c_kv
        slot_mapping = attn_metadata.slot_mapping

        q = self.q_b_proj(q_c)
        bsz, _ = q.shape
        # One query position per sequence.
        q = q.view(bsz, self.num_heads, 1, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Multiply the no-rope query by kv_b_proj_w_k, batched over heads.
        q_nope = q_nope.view(-1, self.num_heads,
                             self.qk_nope_head_dim).transpose(0, 1)
        q_nope = torch.matmul(q_nope, self.kv_b_proj_w_k).transpose(
            1, 0).view(bsz, 1, self.num_heads, self.kv_lora_rank)

        nope_cache = kv_cache[0]
        rope_cache = kv_cache[1]
        cos = attn_metadata.decode.cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = attn_metadata.decode.sin.view(-1, 1, 1, self.qk_rope_head_dim)

        # Must run before the indexer/attention calls below, as in the
        # original ordering; it receives both caches and the slot mapping.
        kv_no_split = kv_no_split.unsqueeze(1).unsqueeze(1)
        k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv_no_split,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            slot_mapping.to(torch.int64),
            rope_cache,
            nope_cache,
            c_kv_scale=None,
            epsilon=self.kv_a_layernorm.variance_epsilon,
            cache_mode='PA')  # adapter NZ

        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin)  # BNSD

        q_nope = q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
        q_pe = q_pe.view(bsz, self.num_heads, -1)

        topk_indices = self.indexer_select(hidden_states,
                                           q_c,
                                           attn_metadata=attn_metadata,
                                           kv_cache=kv_cache,
                                           is_prefill=False)

    decode_metadata = attn_metadata.decode
    # The same tensor is passed as both key and value.
    attn = torch.ops.custom.npu_sparse_flash_attention(
        query=q_nope,
        key=k_nope,
        value=k_nope,
        sparse_indices=topk_indices,
        scale_value=self.scale,
        sparse_block_size=1,
        block_table=decode_metadata.block_table,
        actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
        actual_seq_lengths_kv=decode_metadata.seq_lens,
        query_rope=q_pe,
        key_rope=k_rope,
        layout_query="TND",
        layout_kv="PA_BSND",
        sparse_mode=3,
    )
    attn = attn.squeeze(1)
    attn = attn.transpose(0, 1)
    # input shape [N//attn_tp_size, T(bs*q_len), D]
    # output shape [T(bs*q_len), N//attn_tp_size, D]
    attn_output = torch.matmul(attn, self.kv_b_proj_w_v).transpose(
        1, 0).reshape(-1, self.num_heads * self.v_head_dim)
    output[...] = self.o_proj(attn_output)
    return output
def mla_epilog(self,
               attn_output: torch.Tensor = None,
               absorb: bool = False):
    """Apply the output projection ``self.o_proj`` to ``attn_output``.

    Args:
        attn_output: Attention result to project.
        absorb: Currently unused (see TODO).

    Returns:
        Whatever ``self.o_proj(attn_output)`` returns.
    """
    # TODO: ``absorb`` is accepted but has no effect yet.
    return self.o_proj(attn_output)
def indexer_select(
    self,
    x: torch.Tensor,
    qr: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    attn_metadata: M,
    is_prefill: bool = True,
):
    """Compute sparse top-k indices with the lightning indexer.

    Builds the indexer query from ``qr`` (``self.wq_b``) and the indexer key
    from ``x`` (``self.wk`` + ``self.k_norm``), applies rotary embedding to
    the leading ``qk_rope_head_dim`` channels of each, scatters the key into
    ``kv_cache[2]`` at ``attn_metadata.slot_mapping`` and calls
    ``torch.ops.custom.npu_lightning_indexer``.

    Args:
        x: Hidden states fed to the key and weight projections.
        qr: Input to the indexer query projection.
        kv_cache: Indexable cache container; entry 2 is the indexer key cache.
        attn_metadata: Source of cos/sin, sequence lengths and block table;
            prefill metadata is preferred over decode metadata.
        is_prefill: Enables the tensor-parallel all-gather of the key and
            weight projections (together with the shared-expert-DP flags).

    Returns:
        The tensor returned by ``npu_lightning_indexer``.
    """
    rope_dim = self.qk_rope_head_dim

    if attn_metadata.prefill is not None:
        rope_cos = attn_metadata.prefill.cos
        rope_sin = attn_metadata.prefill.sin
    elif attn_metadata.decode is not None:
        rope_cos = attn_metadata.decode.cos
        rope_sin = attn_metadata.decode.sin

    # The query uses cos/sin as stored; the key uses a (-1, 1, 1, rope) view.
    key_cos = rope_cos.view(-1, 1, 1, rope_dim)
    key_sin = rope_sin.view(-1, 1, 1, rope_dim)

    # Rope channels come first in both the query and the key.
    split_sizes = [rope_dim, self.head_dim - rope_dim]

    # --- query ---
    q = self.wq_b(qr)
    q = q.view(-1, self.n_heads, self.head_dim)
    q_pe, q_nope = torch.split(q, split_sizes, dim=-1)
    q_pe = torch_npu.npu_interleave_rope(q_pe.unsqueeze(2), rope_cos,
                                         rope_sin)
    q = torch.cat([q_pe.squeeze(2), q_nope], dim=-1)

    def _gather_prefill_tokens(local):
        # All-gather along dim 0, then keep [num_decode_tokens, num_actual_tokens).
        gathered = get_tp_group().all_gather(local, 0)
        return gathered[attn_metadata.num_decode_tokens:attn_metadata.
                        num_actual_tokens]

    # --- key ---
    k_proj = self.wk(x)
    gather_over_tp = (self.enable_shared_expert_dp and is_prefill
                      and self.debug_layer_idx > self.first_k_dense_replace
                      and self.debug_layer_idx < self.layers)
    if gather_over_tp:
        k_proj = _gather_prefill_tokens(k_proj)
    k = self.k_norm(k_proj).unsqueeze(1)
    k_pe, k_nope = torch.split(k, split_sizes, dim=-1)
    k_pe = torch_npu.npu_interleave_rope(k_pe.unsqueeze(2), key_cos, key_sin)
    k = torch.cat([k_pe.squeeze(2), k_nope], dim=-1)

    if kv_cache is not None:
        # Write the new keys into the indexer cache at their slots.
        key_width = k.shape[-1]
        torch_npu.npu_scatter_nd_update_(
            kv_cache[2].view(-1, key_width),
            attn_metadata.slot_mapping.view(-1, 1),
            k.view(-1, key_width))

    # --- per-head weights ---
    weights = self.weights_proj(x)
    if gather_over_tp:
        weights = _gather_prefill_tokens(weights)

    query_lens = key_lens = block_table = None
    if attn_metadata.prefill is not None:
        query_lens = attn_metadata.prefill.query_lens
        key_lens = attn_metadata.prefill.seq_lens
        block_table = attn_metadata.prefill.block_table
    elif attn_metadata.decode is not None:
        query_lens = attn_metadata.decode.actual_seq_lengths_q
        key_lens = attn_metadata.decode.seq_lens.to(torch.int32)
        block_table = attn_metadata.decode.block_table

    return torch.ops.custom.npu_lightning_indexer(
        query=q,
        key=kv_cache[2],
        weights=weights,
        actual_seq_lengths_query=query_lens,
        actual_seq_lengths_key=key_lens,
        block_table=block_table,
        layout_query="TND",
        layout_key="PA_BSND",
        sparse_count=2048,
        sparse_mode=3)
def round_up(val: int, align: int) -> int:
    """Return ``ceil(val / align) * align``, or 0 when ``align`` is 0.

    For a positive ``align`` this is the smallest multiple of ``align`` that
    is greater than or equal to ``val``.
    """
    if not align:
        return 0
    quotient, remainder = divmod(val, align)
    if remainder:
        # Floor division undershoots the ceiling by one when inexact.
        quotient += 1
    return quotient * align
def trans_rope_weight(weight, rope_dim):
    """De-interleave the last ``rope_dim`` rows along the second-to-last axis.

    Within the trailing ``rope_dim`` rows, every other row starting at the
    first is placed ahead of every other row starting at the second.
    ``weight`` is modified in place; the return value is
    ``weight.contiguous()``.
    """
    first_of_pairs = weight[..., slice(-rope_dim, None, 2), :].contiguous()
    second_of_pairs = weight[..., slice(-rope_dim + 1, None, 2), :].contiguous()
    regrouped = torch.cat([first_of_pairs, second_of_pairs], dim=-2)
    weight[..., -rope_dim:, :] = regrouped
    return weight.contiguous()
def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Zero-pad a 2-D matrix to block multiples and regroup it by column block.

    The matrix is padded at the bottom/right so that its row count is a
    multiple of ``block_size[0]`` and its column count a multiple of
    ``block_size[1]``, then rearranged so that
    ``result[cb, row, ci] == padded[row, cb * block_size[1] + ci]``.

    Args:
        nd_mat: 2-D tensor.
        block_size: ``(row_block, col_block)`` tile sizes.

    Returns:
        Tensor of shape ``(padded_cols // col_block, padded_rows, col_block)``.
    """
    block_rows, block_cols = block_size[0], block_size[1]
    rows, cols = nd_mat.shape[0], nd_mat.shape[1]

    # Ceiling-round each extent to its block size.
    padded_rows = -(rows // -block_rows) * block_rows
    padded_cols = -(cols // -block_cols) * block_cols

    # F.pad lists pad widths starting from the LAST dimension, so the column
    # pair comes first and the row pair second. (The previous order swapped
    # them, which broke the reshape whenever the two deficits differed.)
    nd_mat = F.pad(nd_mat, (0, padded_cols - cols, 0, padded_rows - rows))

    row_blocks = padded_rows // block_rows
    col_blocks = padded_cols // block_cols
    tiled = torch.reshape(nd_mat,
                          (row_blocks, block_rows, col_blocks, block_cols))
    # Bring the column-block axis to the front, then merge the two row axes.
    by_col_block = torch.permute(tiled, [2, 0, 1, 3])
    return torch.reshape(by_col_block,
                         (col_blocks, row_blocks * block_rows, block_cols))