2025-04-19 17:38:18 +08:00
from dataclasses import dataclass
2025-07-26 15:43:29 +08:00
from typing import TYPE_CHECKING , Optional , Tuple , Type , TypeVar
2025-04-19 17:38:18 +08:00
2025-05-12 19:14:07 +08:00
import numpy as np
2025-04-19 17:38:18 +08:00
import torch
import torch_npu
from vllm . attention . backends . abstract import ( AttentionBackend , AttentionLayer ,
2025-07-26 15:43:29 +08:00
AttentionMetadata ,
2025-04-19 17:38:18 +08:00
MLAAttentionImpl )
2025-05-12 19:14:07 +08:00
from vllm . attention . backends . utils import PAD_SLOT_ID
2025-06-09 22:21:42 +08:00
from vllm . config import get_current_vllm_config
2025-06-25 19:56:49 +08:00
from vllm . distributed import get_tensor_model_parallel_world_size
2025-06-04 20:26:44 +08:00
from vllm . model_executor . layers . linear import ( LinearBase ,
2025-04-19 17:38:18 +08:00
UnquantizedLinearMethod )
2025-06-14 22:31:16 +08:00
from vllm . utils import cdiv , round_down
2025-04-19 17:38:18 +08:00
2025-07-26 17:15:47 +08:00
from vllm_ascend import envs
2025-06-05 16:28:01 +08:00
from vllm_ascend . ascend_config import get_ascend_config
2025-04-19 17:38:18 +08:00
from vllm_ascend . attention . attention_v1 import AscendAttentionState
2025-06-07 16:46:58 +08:00
from vllm_ascend . multistream . base import MSAttentionMetadataSplitConfig
from vllm_ascend . multistream . context import get_multistream_comm_context
from vllm_ascend . multistream . ms_split import model_input_split_v1_mla_attn
2025-04-19 17:38:18 +08:00
from vllm_ascend . ops . attention import vanilla_chunked_prefill_mla
2025-07-21 19:43:30 +08:00
from vllm_ascend . torchair . utils import npu_stream_switch , npu_wait_tensor
2025-07-26 15:43:29 +08:00
from vllm_ascend . utils import npu_prefetch
2025-07-02 17:42:53 +08:00
from vllm_ascend . worker . npu_input_batch import InputBatch
2025-04-19 17:38:18 +08:00
if TYPE_CHECKING :
from vllm . v1 . core . sched . output import SchedulerOutput
class AscendMLABackend(AttentionBackend):
    """Backend descriptor tying together the Ascend MLA metadata, builder
    and attention implementation classes."""

    # Flag: this backend accepts a caller-provided output buffer.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        backend_name = "ASCEND_MLA"
        return backend_name

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step metadata dataclass used by this backend."""
        metadata_type = AscendMLAMetadata
        return metadata_type

    @staticmethod
    def get_builder_cls():
        """Return the class that builds :class:`AscendMLAMetadata`."""
        builder_type = AscendMLAMetadataBuilder
        return builder_type

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        """Return the per-layer KV-cache shape:
        ``(num_blocks, block_size, num_kv_heads, head_size)``."""
        cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def get_impl_cls() -> Type["MLAAttentionImpl"]:
        """Return the attention implementation class."""
        impl_type = AscendMLAImpl
        return impl_type
@dataclass
class AscendMLAPrefillMetadata:
    """Prefill-specific metadata for the Ascend MLA backend.

    One instance describes the prefill requests of a batch, i.e. the
    requests that follow the decode requests after batch reordering.
    """

    @dataclass
    class ChunkedContextMetadata:
        """Bookkeeping for attending over already-cached context in chunks.

        New for MLA (compared to FlashAttention); used for chunked prefill.
        """
        # Cumulative per-request chunk lengths, one row per chunk.
        cu_seq_lens: torch.Tensor
        # Start offset of each chunk within each request's cached context.
        starts: torch.Tensor
        # Total number of context tokens loaded for each chunk.
        seq_tot: list[int]
        # Longest per-request context slice for each chunk.
        max_seq_lens: list[int]
        # Pre-allocated buffer shared by all chunks.
        workspace: torch.Tensor
        # Per-chunk, per-request number of context tokens.
        chunk_seq_lens: torch.Tensor

    attn_mask: torch.Tensor
    # NOTE(review): the builder passes tensor slices for query_lens and
    # seq_lens despite the list annotations; confirm intended types.
    query_lens: list[int]
    seq_lens: list[int]
    context_lens: torch.Tensor
    # Token positions of the prefill tokens (used to gather rope cos/sin).
    input_positions: torch.Tensor
    # Per-request query start offsets, rebased so the first prefill is 0.
    query_start_loc: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    # None when there is no cached context to attend over.
    chunked_context: Optional[ChunkedContextMetadata] = None
    # Rotary embedding tables gathered at input_positions; None until set.
    sin: Optional[torch.Tensor] = None
    cos: Optional[torch.Tensor] = None
2025-04-19 17:38:18 +08:00
@dataclass
class AscendMLADecodeMetadata:
    """Decode-specific metadata for the Ascend MLA backend."""
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend.
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int
    # Python-list copy of seq_lens (the builder fills it via seq_lens.tolist()).
    seq_lens_list: list[int]
    attn_mask: Optional[torch.Tensor] = None
    # Rotary embedding tables gathered at input_positions; None until set.
    sin: Optional[torch.Tensor] = None
    cos: Optional[torch.Tensor] = None
2025-04-19 17:38:18 +08:00
@dataclass
class AscendMLAMetadata:
    """Per-step attention metadata for the Ascend MLA backend.

    Holds batch-wide tensors plus optional prefill/decode sub-metadata.
    Requests are ordered with decode requests first, followed by prefill
    requests (see ``AscendMLAMetadataBuilder.reorder_batch``).
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: Optional[torch.Tensor] = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    decode: Optional[AscendMLADecodeMetadata] = None
    prefill: Optional[AscendMLAPrefillMetadata] = None

    enable_dbo_across_dp: bool = False

    def __post_init__(self):
        # Intentionally a no-op: validation of head_dim against a list of
        # supported head sizes is currently disabled.
        pass

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendMLAMetadata"]:
        """Split this metadata into per-stream pieces for multi-stream
        execution, delegating to ``model_input_split_v1_mla_attn``."""
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendMLAMetadata,
        )


M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder:
    """Build :class:`AscendMLAMetadata` from the model runner's buffers.

    Per scheduler step:

    1. ``reorder_batch`` moves decode requests to the front of the persistent
       batch and records the decode/prefill split on ``self``.
    2. ``build`` slices the runner's CPU/device buffers into prefill and
       decode metadata using that recorded split.

    ``build_torchair_graph_dummy`` produces a fixed-shape, decode-only
    metadata object for torchair graph runs.
    """

    def __init__(self,
                 runner,
                 metadata_cls: Optional["AscendMLAMetadata"] = None):
        """Capture runner configuration and pre-allocate buffers.

        Args:
            runner: The model runner; its configs, device and CPU staging
                buffers are read here and in ``build``.
            metadata_cls: Metadata class instantiated by ``build``; defaults
                to :class:`AscendMLAMetadata`.
        """
        self.metadata_cls = (metadata_cls if metadata_cls is not None else
                             AscendMLAMetadata)  # type: ignore
        self.runner = runner
        scheduler_config = runner.scheduler_config
        model_config = runner.model_config
        self.block_size = runner.block_size
        self.chunked_prefill_enabled = runner.chunked_prefill_enabled
        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough for 8 full-length requests or at
                # least 4 pages of cache per request.
                max(8 * model_config.max_model_len,
                    4 * scheduler_config.max_num_seqs * self.block_size),
                # For long-context models, cap the workspace at 128k tokens
                # so it does not over-allocate at the expense of KV-cache
                # space. Assuming a 576-wide MLA latent and fp16 this is
                #   2 * 576 * (128 * 1024) bytes = 144 MiB,
                # and the up-projected context (192 QK head dim, 128 heads)
                #   2 * (192 * 128) * (128 * 1024) bytes = 6 GiB.
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * self.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 model_config.get_head_size()),
                dtype=model_config.dtype,
                device=runner.device,
            )
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
        # Rotary cos/sin tables, fetched lazily from the model in ``build``.
        self.cos_cache = None
        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests to the front of ``input_batch``.

        "Decode" loosely means attention that is likely memory-bound and
        "prefill" attention that is likely compute-bound (TODO(lucas): find
        better naming). The resulting split is stored on ``self`` for the
        next ``build`` call.

        Returns:
            True if any requests were swapped, otherwise False.
        """
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0
        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            if self.torchair_graph_enabled:
                # Torchair graph mode treats spec decoding as decode.
                is_decode = num_tokens - num_spec_tokens == 1
            else:
                # Eager mode treats spec decoding as chunked prefill.
                is_decode = num_tokens == 1
            if is_decode:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # Keep swaps minimal: decodes tend to persist across iterations and
        # new requests are appended at the back, so walk decodes from the
        # back (descending) and swap any that sit past the decode region
        # with the earliest prefills. `decodes` and `prefills` are already
        # in ascending order from the loop above.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False
        for i in range(1, min(num_decodes, num_prefills) + 1):
            if decodes[num_decodes - i] < num_decodes:
                break
            input_batch.swap_states(prefills[first_prefill],
                                    decodes[num_decodes - i])
            first_prefill += 1
            modified_batch = True

        # Save for next `build` call.
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens
        return modified_batch

    def _get_graph_runner_block_tables(
            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
        """Copy ``block_tables`` into the runner's fixed-width graph table.

        Columns beyond the graph table's width are truncated. When the
        runner's table is a NumPy array a fresh zero tensor is used instead.
        Returns the first ``num_seqs`` rows at the full graph width.
        """
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs
        if isinstance(self.runner.graph_block_tables, np.ndarray):
            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
                                             dtype=block_tables.dtype,
                                             device=block_tables.device)
        else:
            graph_block_tables = self.runner.graph_block_tables.to(
                device=block_tables.device, dtype=block_tables.dtype)
        num_cols = min(block_tables.size(1), max_blocks)
        graph_block_tables[:num_seqs, :num_cols] = \
            block_tables[:num_seqs, :num_cols]
        return graph_block_tables[:num_seqs, :max_blocks]

    def build_torchair_graph_dummy(
            self, num_reqs: int, num_actual_tokens: int) -> "AscendMLAMetadata":
        """Build decode-only placeholder metadata for torchair graph runs.

        All requests get seq_len 1, position 0, padded slot ids, and unit
        rotary tables.
        """
        device = self.runner.device
        _, max_blocks = self.runner.graph_block_tables.shape
        block_table = torch.zeros((num_reqs, max_blocks),
                                  dtype=torch.int32,
                                  device=device)
        block_table = self._get_graph_runner_block_tables(
            num_reqs, block_table)
        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
        input_positions = torch.zeros(num_reqs,
                                      dtype=torch.int32,
                                      device=device).long()
        slot_mapping = torch.full((num_reqs, ),
                                  PAD_SLOT_ID,
                                  dtype=torch.int32,
                                  device=device)
        query_start_loc = torch.full((num_reqs, ),
                                     -1,
                                     dtype=torch.int32,
                                     device=device)
        rope_shape = (num_reqs, 1, 1, self.rope_dim)
        sin = torch.ones(rope_shape, dtype=self.runner.dtype, device=device)
        cos = torch.ones(rope_shape, dtype=self.runner.dtype, device=device)
        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens.tolist(),
            max_seq_lens=1,
            attn_mask=self.runner.spec_attn_mask,
            sin=sin,
            cos=cos)
        # NOTE(review): num_decodes / num_decode_tokens are reported as 1
        # regardless of num_reqs; confirm this is intended for the dummy run.
        return self.metadata_cls(  # type: ignore
            num_input_tokens=num_actual_tokens,
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=1,
            num_decode_tokens=1,
            num_prefills=0,
            attn_mask=self.runner.attn_mask,
            attn_state=AscendAttentionState.DecodeOnly,
            prefill=None,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            block_tables=block_table,
        )

    def _ensure_rope_cache(self) -> None:
        """Fetch the rotary cos/sin tables from the model's first layer on
        first use and cast them to the runner dtype if needed."""
        if self.cos_cache is None:
            rotary_emb = self.runner.get_model(
            ).model.layers[0].self_attn.rotary_emb
            self.cos_cache = rotary_emb.cos_cached
            self.sin_cache = rotary_emb.sin_cached
        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(  # type: ignore
                self.runner.dtype)
            self.sin_cache = self.sin_cache.to(  # type: ignore
                self.runner.dtype)

    def _build_chunked_context_metadata(
        self, reqs_start: int, num_reqs: int, device
    ) -> Optional["AscendMLAPrefillMetadata.ChunkedContextMetadata"]:
        """Split prefill requests' cached context into workspace-sized chunks.

        Returns None when chunked prefill is disabled or no prefill request
        has cached (already computed) tokens.
        """
        context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
            reqs_start:num_reqs]
        max_context_len_cpu = context_lens_cpu.max().item()
        num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
        if not self.chunked_prefill_enabled or max_context_len_cpu <= 0:
            return None
        # Share the workspace among requests that have context, rounded
        # down to whole cache blocks.
        max_context_chunk = (self.chunked_prefill_workspace_size //
                             num_prefills_with_context_cpu)
        max_context_chunk = round_down(max_context_chunk, self.block_size)
        assert max_context_chunk > 0
        num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
        # chunk_starts[c, r] == c * max_context_chunk for every request r.
        chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
            .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
        chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                               chunk_starts + max_context_chunk)
        chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
        cu_seq_lens_cpu = torch.zeros(num_chunks,
                                      self._num_prefills + 1,
                                      dtype=torch.int32,
                                      pin_memory=True)
        torch.cumsum(chunk_seq_lens,
                     dim=1,
                     out=cu_seq_lens_cpu[:, 1:],
                     dtype=torch.int32)
        return AscendMLAPrefillMetadata.ChunkedContextMetadata(
            cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
            starts=chunk_starts.to(device, non_blocking=True),
            seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
            max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
            chunk_seq_lens=chunk_seq_lens,
            workspace=self.chunked_prefill_workspace,
        )

    def build(
        self,
        num_reqs: int,
        num_actual_tokens: int,
        max_query_len: int,
        graph_pad_size: int = -1,
        query_start_loc: Optional[torch.Tensor] = None,
        enable_dbo_across_dp: bool = False,
    ) -> "AscendMLAMetadata":
        """Build attention metadata for the current (reordered) batch.

        Args:
            num_reqs: Number of requests; must equal the decode + prefill
                counts recorded by the last ``reorder_batch`` call.
            num_actual_tokens: Number of unpadded tokens in the batch.
            max_query_len: Unused; recomputed from the runner's buffers.
            graph_pad_size: -1 disables torchair graph padding; otherwise
                the number of padding entries appended to the decode
                seq_lens, slot_mapping, block_table and positions.
            query_start_loc: Per-request cumulative query offsets; required
                when the batch contains prefill requests.
            enable_dbo_across_dp: Forwarded unchanged into the metadata.
        """
        assert self._num_decodes + self._num_prefills == num_reqs
        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device

        block_table = (self.runner.input_batch.block_table[0].
                       get_device_tensor()[:num_reqs])
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
            device, non_blocking=True).long()
        # Per-request lengths (indexed by request, not by token).
        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
        query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
                                                                                       num_reqs]
        seq_lens = seq_lens_cpu
        max_query_len = query_lens.max().item()
        max_seq_lens = seq_lens.max().item()
        self._ensure_rope_cache()

        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # first prefill request
            tokens_start = self._num_decode_tokens  # first prefill token
            # query_lens / seq_lens are per-request, so slice them by request
            # offset; slicing by token offset misaligns them with
            # block_table / query_start_loc when decode requests carry more
            # than one token (spec decoding in torchair mode).
            max_query_len = query_lens[reqs_start:].max().item()
            max_seq_lens = seq_lens[reqs_start:].max().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]
            chunked_context_metadata = self._build_chunked_context_metadata(
                reqs_start, num_reqs, device)
            # Positions are per-token, so these use the token offset.
            prefill_input_positions = input_positions[tokens_start:]
            cos = self.cos_cache[  # type: ignore
                prefill_input_positions].unsqueeze(1).unsqueeze(2)
            sin = self.sin_cache[  # type: ignore
                prefill_input_positions].unsqueeze(1).unsqueeze(2)
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
                query_lens=query_lens[reqs_start:],
                seq_lens=seq_lens,
                context_lens=seq_lens[reqs_start:],
                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                sin=sin,
                cos=cos,
            )

        decode_metadata = None
        use_torchair_graph = graph_pad_size != -1
        if self._num_decodes > 0:
            max_seq_lens = seq_lens[:self._num_decodes].max().item()
            # NOTE(review): seq_lens and block_table are per-request but are
            # sliced by decode *token* count here; this equals num_decodes
            # only when each decode request has one token. Confirm for
            # mixed batches with spec decoding.
            seq_lens = seq_lens[:self._num_decode_tokens]
            input_positions = input_positions[:self._num_decode_tokens]
            block_table = block_table[:self._num_decode_tokens, ...]
            if use_torchair_graph and self.runner.attn_state in [
                    AscendAttentionState.DecodeOnly,
                    AscendAttentionState.SpecDecoding
            ]:
                num_seqs = len(seq_lens)
                # Padded sequences get length 1 and padded slot ids.
                padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size
                seq_lens = torch.from_numpy(
                    np.array(padded_seq_lens).astype(np.int32))
                slot_padding = torch.full((graph_pad_size, ),
                                          PAD_SLOT_ID,
                                          dtype=slot_mapping.dtype,
                                          device=slot_mapping.device)
                slot_mapping = torch.cat([slot_mapping, slot_padding])
                block_table_padding = torch.zeros(
                    (graph_pad_size, ) + block_table.shape[1:],
                    dtype=block_table.dtype,
                    device=block_table.device)
                block_table = torch.cat([block_table, block_table_padding],
                                        dim=0)
                block_table = self._get_graph_runner_block_tables(
                    num_seqs + graph_pad_size, block_table)
                position_padding = torch.zeros(graph_pad_size,
                                               dtype=input_positions.dtype,
                                               device=input_positions.device)
                input_positions = torch.cat([input_positions, position_padding])
            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)

            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
                seq_lens_list=seq_lens.tolist(),
                max_seq_lens=max_seq_lens,
                attn_mask=self.runner.spec_attn_mask,
                sin=sin,
                cos=cos)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            attn_mask=self.runner.attn_mask,
            attn_state=self.runner.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
            enable_dbo_across_dp=enable_dbo_across_dp,
        )
class AscendMLAImpl ( MLAAttentionImpl ) :
"""
NOTE : Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    """Store MLA dimensions, projection layers and backend options.

    ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``, ``attn_type``
    and ``kv_sharing_target_layer_name`` are accepted for interface
    compatibility but not used here. MLA-specific arguments are read from
    ``kwargs``; ``kv_a_proj_with_mqa`` and ``kv_a_layernorm`` are optional.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    # MLA Args
    self.q_lora_rank = kwargs['q_lora_rank']
    self.kv_lora_rank = kwargs['kv_lora_rank']
    self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
    self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
    self.qk_head_dim = kwargs['qk_head_dim']
    self.v_head_dim = kwargs['v_head_dim']
    self.rotary_emb = kwargs['rotary_emb']
    self.q_proj = kwargs['q_proj']
    self.kv_b_proj = kwargs['kv_b_proj']
    self.o_proj = kwargs['o_proj']
    self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
    self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.tp_size = get_tensor_model_parallel_world_size()

    ascend_config = get_ascend_config()
    self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
    self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

    # Adapt torch air graph mode with spec decoding. Default to 0 so the
    # attribute always exists, even without a speculative config.
    self.spec_token_num = 0
    speculative_config = get_current_vllm_config().speculative_config
    if speculative_config is not None:
        self.spec_token_num = speculative_config.num_speculative_tokens
        assert self.spec_token_num > 0
2025-05-12 19:14:07 +08:00
2025-07-11 08:51:17 +08:00
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
    """Up-project the latent attention output with W_UV, then apply o_proj.

    ``x`` is viewed as (B, N, L); the result of ``o_proj`` on the
    (B, N * V) up-projected tensor is returned. The o_proj weight is
    prefetched when ``enable_multistream_mla`` is set.
    """
    MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
    # (B, N, L) -> (N, B, L) so bmm batches over heads.
    heads_first = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    # (N, B, L) @ (N, L, V) -> (N, B, V)
    per_head_v = torch.bmm(heads_first, self.W_UV)
    # (N, B, V) -> (B, N * V)
    flat_v = per_head_v.transpose(0, 1).reshape(
        -1, self.num_heads * self.v_head_dim)
    npu_prefetch(self.o_proj.weight,
                 flat_v,
                 max_size=MAX_O_PROJ_PREFETCH_SIZE,
                 enabled=enable_multistream_mla)
    projected = self.o_proj(flat_v, is_prefill=False)
    return projected[0]
2025-04-19 17:38:18 +08:00
def _q_proj_and_k_up_proj(self, x):
    """Project ``x`` to queries and absorb W_UK into the no-rope part.

    Returns:
        ``(ql_nope, q_pe)``: ``ql_nope`` is (B, N, L) after multiplying the
        no-rope query part by ``W_UK_T``; ``q_pe`` is the rope part (B, N, R).
    """
    projected = self.q_proj(x)[0]
    per_head = projected.view(-1, self.num_heads, self.qk_head_dim)
    q_nope, q_pe = per_head.split(
        [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # (B, N, P) -> (N, B, P), then (N, B, P) @ (N, P, L) -> (N, B, L)
    heads_first = torch.bmm(q_nope.transpose(0, 1), self.W_UK_T)
    # (N, B, L) -> (B, N, L)
    ql_nope = heads_first.transpose(0, 1)
    return ql_nope, q_pe
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Derive per-head ``W_UV`` and ``W_UK_T`` from ``kv_b_proj``.

    Quantized ``kv_b_proj`` weights are dequantized by applying the layer
    to an identity matrix. The split results are stored in 16-bit
    activation precision because no quantized bmm is available.
    """
    WEIGHT_NAMES = ("weight", "qweight", "weight_packed")

    def get_layer_weight(layer):
        # Return the first recognized weight attribute of `layer`.
        for name in WEIGHT_NAMES:
            if not hasattr(layer, name):
                continue
            return getattr(layer, name)
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute: "
            f"{WEIGHT_NAMES}.")

    def get_and_maybe_dequant_weights(layer: LinearBase):
        # Unquantized layers expose their weight directly.
        if isinstance(layer.quant_method, UnquantizedLinearMethod):
            return layer.weight
        # NOTE: This should only be used offline, since it's O(N^3)
        eye = torch.eye(layer.input_size_per_partition,
                        dtype=act_dtype,
                        device=get_layer_weight(layer).device)
        dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
        del eye
        # standardize to (output, input)
        return dequant_weights.T

    # There are no quantized bmm's for `W_UV` and `W_UK_T`, so fp16/bf16
    # copies are stored and the bmm's run in 16-bit; the extra memory
    # overhead is fairly low.
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    per_head_dim = self.qk_nope_head_dim + self.v_head_dim
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank, self.num_heads * per_head_dim), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    per_head = kv_b_proj_weight.view(self.kv_lora_rank, self.num_heads,
                                     per_head_dim)
    W_UK, W_UV = per_head.split([self.qk_nope_head_dim, self.v_head_dim],
                                dim=-1)
    # (L, N, V) -> (N, L, V)
    self.W_UV = W_UV.transpose(0, 1).contiguous()
    # (L, N, P) -> (N, P, L)
    self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
    # TODO: once BMM supports the NZ format, cast both with
    # torch_npu.npu_format_cast(<tensor>.data, 29).
2025-04-19 17:38:18 +08:00
2025-06-14 22:31:16 +08:00
def _compute_prefill_context(
    self,
    query: torch.Tensor,
    kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
    rope_dim: int,
    attn_metadata: "AscendMLAMetadata",
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
):
    """Merge attention over already-cached context into the prefill output.

    For each context chunk described by
    ``attn_metadata.prefill.chunked_context``, the cached latent KV and
    k_pe are loaded with ``npu_paged_cache_load`` and up-projected through
    ``kv_b_proj``. The chunk is then attended with ``npu_ring_mla``, which
    updates ``prefix_output`` / ``prefix_lse`` in place.

    Args:
        query: Prefill queries; the first ``qk_nope_head_dim`` features
            are the no-rope part and the rest are the rope part.
        kv_c_and_k_pe_cache: ``(kv_c cache, k_pe cache)`` pair.
        rope_dim: Last-dimension size of the loaded k_pe buffer.
        attn_metadata: Batch metadata; only ``.prefill`` is read.
        prefix_output: Running attention output, updated in place.
        prefix_lse: Running log-sum-exp, updated in place.

    Returns:
        ``(prefix_output, prefix_lse)``, returned unchanged when there is
        no prefill or no chunked context.
    """
    assert len(kv_c_and_k_pe_cache) > 1
    prefill_metadata = attn_metadata.prefill
    if prefill_metadata is None or prefill_metadata.chunked_context is None:
        return prefix_output, prefix_lse

    chunked_context = prefill_metadata.chunked_context
    iters = len(chunked_context.seq_tot)
    q_pe = query[..., self.qk_nope_head_dim:]
    q_nope = query[..., :self.qk_nope_head_dim]
    seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
    cache_kv_c = kv_c_and_k_pe_cache[0]
    cache_k_pe = kv_c_and_k_pe_cache[1]
    num_heads = cache_k_pe.size(2)
    latent_kv_dim = cache_kv_c.size(-1)
    # Loop-invariant: build once rather than per chunk (the kernel is
    # invoked with mask_type="no_mask").
    mask = torch.triu(
        torch.ones(512, 512, device=query.device, dtype=query.dtype), 1)
    for i in range(iters):
        toks = chunked_context.seq_tot[i]
        seq_len2 = chunked_context.chunk_seq_lens[i]
        # Row 0: query lengths; row 1: this chunk's context lengths.
        seq_len = torch.stack([seq_len1, seq_len2])
        kv_c_normed = torch.empty(toks,
                                  num_heads,
                                  latent_kv_dim,
                                  dtype=query.dtype,
                                  device=query.device)
        k_pe = torch.empty(toks,
                           num_heads,
                           rope_dim,
                           dtype=query.dtype,
                           device=query.device)
        torch_npu.atb.npu_paged_cache_load(
            cache_kv_c,
            cache_k_pe,
            prefill_metadata.block_table,
            seq_len2.to(query.device),
            seq_starts=chunked_context.starts[i],
            key=kv_c_normed,
            value=k_pe,
        )
        # Drop only the head axis so a single-token chunk keeps its
        # token dimension (a bare squeeze() would drop it too).
        kv_c_normed = kv_c_normed.squeeze(1)
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim],
                                  dim=-1)
        k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
        torch_npu.atb.npu_ring_mla(
            q_nope=q_nope,
            q_rope=q_pe,
            k_nope=k_nope,
            k_rope=k_pe,
            value=v,
            mask=mask,
            seqlen=seq_len,
            head_num=self.num_heads,
            kv_head_num=self.num_heads,
            pre_out=prefix_output,
            prev_lse=prefix_lse,
            qk_scale=self.scale,
            kernel_type="kernel_type_high_precision",
            mask_type="no_mask",
            input_layout="type_bsnd",
            calc_type="calc_type_default",
            output=prefix_output,
            softmax_lse=prefix_lse)
    return prefix_output, prefix_lse
2025-04-19 17:38:18 +08:00
def _forward_prefill(
    self,
    query: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c_and_k_pe_cache: Tuple[torch.Tensor],
    attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
    """Compute MLA attention for the prefill tokens and apply ``o_proj``.

    The attention state selects one of three paths:

    * ``ChunkedPrefill`` / ``SpecDecoding`` / ``PrefillCacheHit`` with
      ``ascend_config.chunked_prefill_for_mla`` disabled: delegate to
      ``vanilla_chunked_prefill_mla``. It is handed ``kv_b_proj`` and the
      paged cache and fills the output on its own.
    * The same three states with ``chunked_prefill_for_mla`` enabled:
      1. Run ``npu_ring_mla`` over the new tokens with a triangular mask.
      2. Fold in the cached context through ``_compute_prefill_context``,
         which receives and returns the running output and LSE.
    * ``PrefillNoCache``: a single ``_npu_flash_attention`` call over the
      new tokens.

    Any other state raises ``RuntimeError``.

    Args:
        query: prefill queries. The last dim is split into
            ``[:qk_nope_head_dim]`` (nope part) and ``[qk_nope_head_dim:]``
            (rope part).
        kv_c_normed: normalized latent KV. On the non-vanilla paths it is
            up-projected through ``kv_b_proj`` into per-head ``k_nope``
            and ``value``.
        k_pe: rotary key part, expanded across heads to match ``k_nope``.
        kv_c_and_k_pe_cache: cache tuple; at least two entries are required.
        attn_metadata: metadata whose ``prefill`` section must be set.

    Returns:
        ``self.o_proj(...)[0]`` applied to the attention output flattened to
        ``[num_tokens, num_heads * v_head_dim]``.

    Raises:
        RuntimeError: if ``attn_metadata.attn_state`` is not one of the four
            supported prefill states.
    """
    assert attn_metadata.prefill is not None
    assert len(kv_c_and_k_pe_cache) > 1

    num_tokens = query.size(0)
    ascend_config = get_ascend_config()
    # Four states can reach this method: PrefillNoCache, plus the three
    # below. The three below also consult the paged KV cache (either via the
    # vanilla kernel or via _compute_prefill_context).
    uses_cached_context = attn_metadata.attn_state in (
        AscendAttentionState.ChunkedPrefill,
        AscendAttentionState.SpecDecoding,
        AscendAttentionState.PrefillCacheHit,
    )

    if uses_cached_context and not ascend_config.chunked_prefill_for_mla:
        # Fallback kernel: it receives kv_b_proj and the cache and performs
        # the up-projection itself. The k_nope/value computation used by the
        # other paths is therefore skipped here instead of computed and
        # discarded.
        attn_output = torch.empty(num_tokens,
                                  self.num_heads * self.v_head_dim,
                                  dtype=query.dtype,
                                  device=query.device)
        vanilla_chunked_prefill_mla(
            output=attn_output,
            query=query,
            kv_cache=kv_c_and_k_pe_cache,
            block_tables=attn_metadata.prefill.block_table,
            query_lens=attn_metadata.prefill.query_lens,
            context_lens=attn_metadata.prefill.context_lens,
            kv_b_proj=self.kv_b_proj,
            max_query_len=attn_metadata.prefill.max_query_len,
            max_context_len=attn_metadata.prefill.max_seq_lens,
            nope_dim=self.qk_nope_head_dim,
            rope_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            scale=self.scale,
            alibi_slopes=None,
            causal=True)
    else:
        attn_output = torch.empty(num_tokens,
                                  self.num_heads,
                                  self.v_head_dim,
                                  dtype=query.dtype,
                                  device=query.device)
        # Up-project the latent KV into per-head nope keys and values.
        k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = k_pe.expand((*k_nope.shape[:-1], -1))

        if uses_cached_context:
            attn_lse = torch.empty(self.num_heads,
                                   num_tokens,
                                   dtype=torch.float32,
                                   device=query.device)
            q_pe = query[..., self.qk_nope_head_dim:]
            q_nope = query[..., :self.qk_nope_head_dim]
            # 512: the ring-MLA mask only supports a 512x512 shape.
            mask = torch.triu(
                torch.ones(512, 512, device=query.device, dtype=query.dtype),
                1)
            if attn_metadata.num_prefills > 1:
                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
                                                1)
            # First ring pass: causal attention among the new tokens only.
            # The output and LSE are then updated with the cached context.
            torch_npu.atb.npu_ring_mla(
                q_nope=q_nope,
                q_rope=q_pe,
                k_nope=k_nope,
                k_rope=k_pe,
                value=value,
                mask=mask,
                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
                                    dtype=torch.int32),
                head_num=self.num_heads,
                kv_head_num=self.num_heads,
                pre_out=None,
                prev_lse=None,
                qk_scale=self.scale,
                kernel_type="kernel_type_high_precision",
                mask_type="mask_type_triu",
                input_layout="type_bsnd",
                calc_type="calc_type_first_ring",
                output=attn_output,
                softmax_lse=attn_lse)
            attn_output, attn_lse = self._compute_prefill_context(
                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim,
                attn_metadata, attn_output, attn_lse)
        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            key = torch.cat((k_nope, k_pe), dim=-1)
            torch_npu._npu_flash_attention(
                query=query,
                key=key,
                value=value,
                mask=attn_metadata.attn_mask,
                seq_len=attn_metadata.prefill.context_lens,
                scale_value=self.scale,
                num_heads=self.num_heads,
                num_kv_heads=self.num_heads,
                out=attn_output)
        else:
            raise RuntimeError(
                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
            )
        attn_output = attn_output.reshape(
            [num_tokens, self.num_heads * self.v_head_dim])

    current_ms_metadata = get_multistream_comm_context()
    if current_ms_metadata is None:
        return self.o_proj(attn_output, is_prefill=True)[0]
    # Multistream (DBO): record an event on the current stream, make the
    # comm stream wait on it, then run o_proj on the comm stream.
    current_ms_metadata.before_comm_event.record()
    with torch.npu.stream(current_ms_metadata.comm_stream):
        current_ms_metadata.before_comm_event.wait()
        return self.o_proj(attn_output, is_prefill=True)[0]
2025-04-19 17:38:18 +08:00
2025-05-12 19:14:07 +08:00
def exec_kv(
    self,
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    kv_cache: Tuple,
    slots: torch.Tensor,
):
    """Project hidden states to latent KV and run the fused norm/RoPE/cache kernel.

    Decode-path variant. It uses cache mode ``"PA_NZ"`` when
    ``self.enable_kv_nz`` is set, otherwise ``"PA"``.

    Args:
        hidden_states: input rows. The first dim is treated as the batch
            size B.
        cos, sin: rotary tables forwarded to the kernel.
        kv_cache: cache tuple. Index 1 is passed before index 0, in the
            argument order the kernel expects.
        slots: slot mapping, cast to int64 for the kernel.

    Returns:
        A tuple ``(k_pe, k_nope, kv)``:
        * ``k_pe``, ``k_nope``: the first two outputs of
          ``npu_kv_rmsnorm_rope_cache``.
        * ``kv``: the projected input, viewed as
          ``[B, num_kv_heads, 1, kv_lora_rank + qk_rope_head_dim]``.
    """
    batch = hidden_states.shape[0]
    head_dim = self.kv_lora_rank + self.qk_rope_head_dim
    projected = self.kv_a_proj_with_mqa(hidden_states)[0]
    # The fused kernel expects a [B, N, S, D] layout; here S == 1.
    projected = projected.view(batch, self.num_kv_heads, 1, head_dim)
    mode = "PA_NZ" if self.enable_kv_nz else "PA"
    outputs = torch_npu.npu_kv_rmsnorm_rope_cache(
        projected,
        self.kv_a_layernorm.weight,
        cos,
        sin,
        slots.to(torch.int64),
        kv_cache[1],
        kv_cache[0],
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode=mode,
    )
    rope_key, nope_key = outputs[0], outputs[1]
    return rope_key, nope_key, projected
2025-06-11 14:09:28 +08:00
def exec_kv_prefill(
    self,
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    kv_cache: Tuple,
    slots: torch.Tensor,
):
    """Prefill-path variant of the fused KV projection/norm/RoPE/cache kernel.

    It differs from the decode variant in three ways:
    * it uses cache mode ``"PA_BLK_NZ"`` when ``self.enable_kv_nz`` is set
      (otherwise ``"PA"``);
    * it passes ``is_output_kv=True``;
    * it keeps the last two kernel outputs.

    Args:
        hidden_states: prefill input rows. The first dim is treated as the
            batch size B.
        cos, sin: rotary tables forwarded to the kernel.
        kv_cache: cache tuple. Index 1 is passed before index 0.
        slots: slot mapping, cast to int64 for the kernel.

    Returns:
        ``(k_pe, k_nope)``: the third and fourth outputs of
        ``npu_kv_rmsnorm_rope_cache``.
    """
    batch = hidden_states.shape[0]
    head_dim = self.kv_lora_rank + self.qk_rope_head_dim
    # npu_kv_rmsnorm_rope_cache needs a [B, N, S, D] layout with S == 1.
    projected = self.kv_a_proj_with_mqa(hidden_states)[0].view(
        batch, self.num_kv_heads, 1, head_dim)
    mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
    outputs = torch_npu.npu_kv_rmsnorm_rope_cache(
        projected,
        self.kv_a_layernorm.weight,
        cos,
        sin,
        slots.to(torch.int64),
        kv_cache[1],
        kv_cache[0],
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode=mode,
        is_output_kv=True,
    )
    rope_key, nope_key = outputs[2], outputs[3]
    return rope_key, nope_key

def rope_single(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply interleaved RoPE to a 3-D tensor via ``npu_interleave_rope``.

    ``x`` is unpacked as ``(B, N, D)``. The kernel is called on a 4-D
    ``(B, N, 1, D)`` view, and its result is viewed back to ``(B, N, D)``.
    """
    batch, heads, dim = x.shape
    rotated = torch_npu.npu_interleave_rope(x.view(batch, heads, 1, dim), cos,
                                            sin)
    return rotated.view(batch, heads, dim)
2025-04-19 17:38:18 +08:00
def _forward_decode(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c_and_k_pe_cache: Tuple[torch.Tensor],
    attn_metadata: AscendMLAMetadata,
    enable_multistream_mla: bool = False,
) -> torch.Tensor:
    """Run MLA decode attention, then the V up-projection and ``o_proj``.

    In TorchAir graph mode (``self.running_in_graph``):
    * queries are reshaped for ``npu_fused_infer_attention_score``;
    * they attend over ``k_nope`` / ``k_pe``, viewed as paged blocks.

    Otherwise the paged cache in ``kv_c_and_k_pe_cache`` is read through one
    of two kernels:
    * ``torch_npu.atb.npu_multi_head_latent_attention``, when
      ``envs.VLLM_ASCEND_MLA_PA`` is set;
    * ``torch_npu._npu_paged_attention_mla``, otherwise.

    Args:
        q_nope: nope part of the decode queries.
        q_pe: rope part of the decode queries.
        k_nope: nope keys. Only used on the graph path.
        k_pe: rope keys. Only used on the graph path.
        kv_c_and_k_pe_cache: ``(nope_cache, rope_cache)``. Index 0 also
            provides ``block_size`` via ``shape[1]`` on the graph path.
        attn_metadata: metadata whose ``decode`` section must be set.
        enable_multistream_mla: forwarded to ``_v_up_proj_and_o_proj``, but
            only when no multistream comm context is active.

    Returns:
        The result of ``_v_up_proj_and_o_proj`` on the attention output.
    """
    decode_meta = attn_metadata.decode
    assert decode_meta is not None
    num_tokens = q_nope.size(0)
    if self.running_in_graph:
        # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
            # Tokens are grouped into rows of (spec_token_num + 1) query
            # positions (presumably one row per request). The divisibility
            # check must therefore use the same group size as the view
            # below. The previous `% self.spec_token_num` check rejected
            # valid batches, e.g. spec_token_num=2 with 3 tokens.
            q_len = self.spec_token_num + 1
            assert num_tokens % q_len == 0
            batch = num_tokens // q_len
            q_nope = q_nope.view(batch, q_len, self.num_heads, -1)
            q_pe = q_pe.view(batch, q_len, self.num_heads, -1)
            if not self.enable_kv_nz:
                # BNSD layout: move heads ahead of the query-length axis.
                q_nope = q_nope.transpose(1, 2).contiguous()
                q_pe = q_pe.transpose(1, 2).contiguous()
            sparse_mode = 3
            spec_attn_mask = decode_meta.attn_mask  # type:ignore
        else:
            if self.enable_kv_nz:
                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
            else:
                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
            sparse_mode = 0
            spec_attn_mask = None
        # shape of knope/k_pe for npu graph mode should be:
        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
        block_size = kv_c_and_k_pe_cache[0].shape[1]
        if self.enable_kv_nz:
            k_nope = k_nope.view(-1, self.num_kv_heads,
                                 self.kv_lora_rank // 16, block_size, 16)
            k_pe = k_pe.view(-1, self.num_kv_heads,
                             self.qk_rope_head_dim // 16, block_size, 16)
            input_layout = "BSND"
        else:
            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
                                 self.kv_lora_rank)
            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
                             self.qk_rope_head_dim)
            input_layout = "BNSD"

        # k_nope is passed as both key and value (latent MLA attention).
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            q_nope,
            k_nope,
            k_nope,
            query_rope=q_pe,
            key_rope=k_pe,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout=input_layout,
            atten_mask=spec_attn_mask,
            sparse_mode=sparse_mode,
            scale=self.scale,
            antiquant_mode=0,
            antiquant_scale=None,
            block_table=decode_meta.block_table,
            block_size=block_size,
            actual_seq_lengths_kv=decode_meta.seq_lens_list,
        )
    else:
        # The MLA_PA path will become the default in the future.
        # `_npu_paged_attention_mla` will be removed once
        # `torch_npu.atb.npu_multi_head_latent_attention` is publicly
        # available in torch_npu.
        assert len(kv_c_and_k_pe_cache) > 1
        if envs.VLLM_ASCEND_MLA_PA:
            attn_output = torch_npu.atb.npu_multi_head_latent_attention(
                q_nope, q_pe, kv_c_and_k_pe_cache[0],
                kv_c_and_k_pe_cache[1], decode_meta.block_table,
                decode_meta.seq_lens, self.num_heads, self.scale,
                self.num_kv_heads)
        else:
            q = torch.cat([q_nope, q_pe], dim=-1)
            attn_output = torch.empty(
                [num_tokens, self.num_heads, self.kv_lora_rank],
                dtype=q.dtype,
                device=q.device)
            # NOTE(review): this concatenates the full nope and rope caches
            # on every call; presumably costly for large caches.
            k_cache = torch.cat(
                [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
            torch_npu._npu_paged_attention_mla(
                query=q,
                key_cache=k_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=decode_meta.block_table,
                context_lens=decode_meta.seq_lens,
                mla_vheadsize=self.kv_lora_rank,
                out=attn_output)
    current_ms_metadata = get_multistream_comm_context()
    if current_ms_metadata is None:
        return self._v_up_proj_and_o_proj(attn_output,
                                          enable_multistream_mla)
    # Multistream (DBO): record on the current stream, wait on the comm
    # stream, then project there. The multistream-MLA flag is not forwarded
    # on this path (unchanged behavior; NOTE(review): confirm intended).
    current_ms_metadata.before_comm_event.record()
    with torch.npu.stream(current_ms_metadata.comm_stream):
        current_ms_metadata.before_comm_event.wait()
        return self._v_up_proj_and_o_proj(attn_output)
2025-04-19 17:38:18 +08:00
def forward(
    self,
    layer: AttentionLayer,
    hidden_states_or_q_c: torch.Tensor,  # query in unified attn
    hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: Tuple[torch.Tensor],
    attn_metadata: M,
    output: Optional[torch.Tensor] = None,
    enable_multistream_mla: bool = False,
    ckq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run MLA attention for a batch that may mix decode and prefill tokens.

    Tokens are laid out decode-first, as the slicing below shows:
    * ``[:num_decode_tokens]`` are decode tokens;
    * ``[num_decode_tokens:]`` are prefill tokens.

    Prefill and decode results are written into the matching rows of
    ``output``.

    KV cache updates:
    * outside TorchAir mode: ``_npu_reshape_and_cache`` writes the cache here;
    * in TorchAir mode: the cache and slot mapping are handed to the fused
      kernels in ``exec_kv`` / ``exec_kv_prefill``. For ``PrefillNoCache``,
      ``_npu_reshape_and_cache`` is additionally called here.

    Side effect: sets ``self.running_in_graph`` for this call.

    Args:
        layer: attention layer (unused in this method).
        hidden_states_or_q_c: query-side input.
        hidden_states_or_kv_c_normed: raw hidden states when ``k_pe`` is
            None (outside graph mode); otherwise the normalized latent KV.
        k_pe: rope key, or None to derive it from
            ``hidden_states_or_kv_c_normed``.
        kv_cache: ``(nope_cache, rope_cache)``.
        attn_metadata: batch metadata. None means a profiling run.
        output: preallocated output buffer (required).
        enable_multistream_mla: enables the secondary-stream overlap on the
            graph decode path.
        ckq: tensor the secondary-stream KV work waits on when multistream
            MLA is enabled.

    Returns:
        * In TorchAir graph mode with decode tokens: the result of
          ``_forward_decode`` directly.
        * Otherwise: the full (possibly padded) ``output`` buffer.
    """
    assert output is not None, "Output tensor must be provided."
    if attn_metadata is None:
        # Profiling run.
        return output
    self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
    ]
    num_actual_toks = attn_metadata.num_actual_tokens
    if k_pe is None and not self.running_in_graph:
        # Derive the latent KV and rope key from the hidden states here.
        kv_c, k_pe = self.kv_a_proj_with_mqa(
            hidden_states_or_kv_c_normed)[0].split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
    else:
        kv_c_normed = hidden_states_or_kv_c_normed

    assert attn_metadata.num_decodes is not None and \
        attn_metadata.num_prefills is not None and \
        attn_metadata.num_decode_tokens is not None

    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Bind the returned buffer unconditionally. Previously it was only bound
    # outside graph mode, so graph mode without decode tokens reached the
    # final `return output_padded` with an unbound name.
    output_padded = output
    if not self.running_in_graph:
        # Inputs and outputs may be padded for CUDA graphs. Work on the
        # unpadded prefix and return the padded buffer.
        output = output[:num_actual_toks, ...]
        if not self.torchair_graph_enabled:
            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
        prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
        k_pe = k_pe[:num_actual_toks, ...]
        k_pe = k_pe.unsqueeze(1)
        decode_k_pe = k_pe[:num_decode_tokens]
        prefill_k_pe = k_pe[num_decode_tokens:]
    else:
        decode_hs_or_q_c = hidden_states_or_q_c

    if has_decode:
        decode_k_nope = None
        assert attn_metadata.decode is not None
        if self.running_in_graph:
            cos = attn_metadata.decode.cos
            sin = attn_metadata.decode.sin
            with npu_stream_switch("mla_secondary",
                                   0,
                                   enabled=enable_multistream_mla):
                npu_wait_tensor(hidden_states_or_kv_c_normed,
                                ckq,
                                enabled=enable_multistream_mla)
                decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                    attn_metadata.slot_mapping)
            # Without explicitly controlling the order, IndexByTensor operations
            # would be placed after `matmul W_KV_T` hindering the overlapping of
            # KvRmsNormRopeCache and SingleRope.
            npu_wait_tensor(decode_hs_or_q_c,
                            cos,
                            enabled=enable_multistream_mla)
            npu_wait_tensor(decode_hs_or_q_c,
                            sin,
                            enabled=enable_multistream_mla)
            npu_wait_tensor(decode_hs_or_q_c,
                            decode_kv,
                            enabled=enable_multistream_mla)

        decode_ql_nope, decode_q_pe = \
            self._q_proj_and_k_up_proj(decode_hs_or_q_c)
        if self.running_in_graph:
            with npu_stream_switch("mla_secondary",
                                   0,
                                   enabled=enable_multistream_mla):
                npu_wait_tensor(decode_q_pe,
                                decode_k_pe,
                                enabled=enable_multistream_mla)
                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
        else:
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                attn_metadata.decode.input_positions,
                decode_q_pe.contiguous(),
                decode_k_pe,
                max_seq_len=attn_metadata.decode.max_seq_lens)

    if has_prefill:
        assert attn_metadata.prefill is not None
        prefill_q = self.q_proj(prefill_hs_or_q_c)[0] \
            .view(-1, self.num_heads, self.qk_head_dim)
        prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
        prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
        if self.torchair_graph_enabled:
            num_tokens = prefill_hs_or_q_c.shape[0]
            cos = attn_metadata.prefill.cos
            sin = attn_metadata.prefill.sin

            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
            prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
                prefill_hs, cos, sin, kv_cache,
                attn_metadata.slot_mapping[num_decode_tokens:])

            kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
            prefill_k_c_normed = prefill_k_nope
            prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                             -1)
            prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
        else:
            # In-place RoPE: prefill_q_pe is a view into prefill_q.
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                attn_metadata.prefill.input_positions,
                prefill_q_pe.contiguous(),
                prefill_k_pe,
                max_seq_len=attn_metadata.prefill.max_seq_lens)

    assert len(
        kv_cache
    ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
    if self.torchair_graph_enabled:
        if kv_cache[0].numel(
        ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            slots = attn_metadata.slot_mapping
            # NOTE: Separate the kv cache in advance to avoid OOM or other issues
            torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
                num_tokens, self.num_kv_heads, -1),
                                             value=prefill_k_pe,
                                             key_cache=kv_cache[0],
                                             value_cache=kv_cache[1],
                                             slot_indices=slots)
    else:
        kv_c_normed = kv_c_normed.view(
            [num_actual_toks, self.num_kv_heads, -1])
        torch_npu._npu_reshape_and_cache(
            key=kv_c_normed,
            value=k_pe,
            key_cache=kv_cache[0],
            value_cache=kv_cache[1],
            slot_indices=attn_metadata.slot_mapping)

    if has_prefill:
        # FIX: aicore move should be also placed on the comm stream in dbo,
        # otherwise it may affect the accuracy
        # TODO: use an elegant way to overlap
        output_prefill = self._forward_prefill(prefill_q,
                                               prefill_k_c_normed,
                                               prefill_k_pe, kv_cache,
                                               attn_metadata)
        current_ms_metadata = get_multistream_comm_context()
        if current_ms_metadata is not None:
            with torch.npu.stream(current_ms_metadata.comm_stream):
                output[num_decode_tokens:] = output_prefill
                current_ms_metadata.after_comm_event.record()
        else:
            output[num_decode_tokens:] = output_prefill

    if has_decode:
        if self.running_in_graph:
            return self._forward_decode(decode_ql_nope, decode_q_pe,
                                        decode_k_nope, decode_k_pe,
                                        kv_cache, attn_metadata,
                                        enable_multistream_mla)
        else:
            output_decode = self._forward_decode(decode_ql_nope,
                                                 decode_q_pe,
                                                 decode_k_nope,
                                                 decode_k_pe, kv_cache,
                                                 attn_metadata)
            current_ms_metadata = get_multistream_comm_context()
            if current_ms_metadata is not None:
                with torch.npu.stream(current_ms_metadata.comm_stream):
                    output[:num_decode_tokens] = output_decode
                    current_ms_metadata.after_comm_event.record()
            else:
                output[:num_decode_tokens] = output_decode

    return output_padded