from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar

import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

class AscendMLABackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "VLLM_ASCEND_MLA"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AscendMLAMetadata

    @staticmethod
    def get_builder_cls():
        return AscendMLAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_impl_cls() -> Type["MLAAttentionImpl"]:
        return AscendMLAImpl
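
# For MLA, the shape from get_kv_cache_shape appears to be used with a single
# logical KV head per token whose head_size holds the compressed latent
# (kv_lora_rank) concatenated with the rotary key dims (qk_rope_head_dim);
# see the cache write in AscendMLAImpl.forward below.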

@dataclass
class AscendMLAPrefillMetadata:
    """Prefill Specific Metadata for Ascend"""
    attn_mask: torch.Tensor
    query_lens: list[int]
    context_lens: torch.Tensor
    input_positions: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int

@dataclass
class AscendMLADecodeMetadata:
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int

@dataclass
class AscendMLAMetadata:
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: Optional[torch.Tensor] = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    decode: Optional[AscendMLADecodeMetadata] = None
    prefill: Optional[AscendMLAPrefillMetadata] = None

    def __post_init__(self):
        pass
        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")

M = TypeVar("M", bound=AscendMLAMetadata)


class AscendMLAMetadataBuilder:
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # _attn_mask_builder = None
    def __init__(self,
                 runner: "NPUModelRunner",
                 metadata_cls: Optional[AscendMLAMetadata] = None):
        self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
        self.runner = runner
        scheduler_config = runner.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill"
        # to mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # For now treat 1 scheduled token as "decode" even if it's not.
            # We should update this to something like < 8 in the future, but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1.
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch (see the worked example
        # below).
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
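        # Worked example: with batch order [P0, D1, P2, D3] (D* = decode
        # requests with 1 scheduled token, P* = prefills) we get
        # decodes == [1, 3], prefills == [0, 2] and num_decodes == 2.
        # D3 sits at index 3 >= 2, so it is swapped with P0; D1 already lies
        # inside the first num_decodes slots, so the loop breaks.
        # Result: [D3, D1, P2, P0], i.e. all decodes at the front.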

        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i.e. beyond the
            # first num_decodes slots, we can swap it with the prefill
            # closest to the front of the batch
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build(self,
              num_reqs: int,
              num_actual_tokens: int,
              max_query_len: int,
              common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device
        block_table = (
            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
            device, non_blocking=True).long()

        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
        query_lens = seq_lens_cpu - \
            self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
        seq_lens = seq_lens_cpu
        max_query_len = query_lens.max().item()
        max_seq_lens = seq_lens.max().item()

        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # prefill_start
            tokens_start = self._num_decode_tokens
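            # NOTE: each decode request contributes exactly one token (see
            # reorder_batch), so _num_decode_tokens == _num_decodes and the
            # token-count slices applied to per-request tensors here and in
            # the decode branch below are equivalent to request-count slices.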
            max_query_len = query_lens[tokens_start:].max().item()
            max_seq_lens = seq_lens[tokens_start:].max().item()

            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
                query_lens=query_lens[tokens_start:],
                context_lens=seq_lens[tokens_start:],
                input_positions=input_positions[tokens_start:],
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
            )

        decode_metadata = None
        if self._num_decodes > 0:
            max_seq_lens = seq_lens[:self._num_decodes].max().item()
            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions[:self._num_decode_tokens],
                block_table=block_table[:self._num_decode_tokens, ...],
                seq_lens=seq_lens[:self._num_decode_tokens],
                max_seq_lens=max_seq_lens)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            attn_mask=self.runner.attn_mask,
            attn_state=self.runner.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )


class AscendMLAImpl(MLAAttentionImpl):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        # TODO: below padding should be removed after kernel is ready
        # we found npu_flash_attention only works when head_dim is divisible
        # by 128, so we pad it to the target size here and slice the final
        # result to guarantee its functionality.
        self.padding_head_dim = (
            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
            1) * 128
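        # For example, with DeepSeek-style head dims (qk_nope_head_dim=128,
        # qk_rope_head_dim=64) this rounds 192 up to 256; the extra 64 dims
        # are zero-padded before _npu_flash_attention and sliced off again
        # afterwards (see _forward_prefill). These numbers are illustrative
        # and depend on the served model's config.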

        # Hack for V1 for now to avoid torch library overhead (since we are
        # already inside an attention custom op), pull out the forward
        # method from the rotary embedding and call it directly
        # TODO(lucas): we should probably find a cleaner way to do this
        self.rotary_emb = rotary_emb

        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        # self.flash_attn_varlen_func = flash_attn_varlen_func
        # if self.vllm_flash_attn_version is not None:
        #     self.flash_attn_varlen_func = \
        #         functools.partial(flash_attn_varlen_func,
        #                           fa_version=self.vllm_flash_attn_version)
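
    # The two helpers below implement the usual MLA "matrix absorption"
    # trick: instead of up-projecting the cached latent, W_UK is folded into
    # the query (giving ql_nope in the latent space) and W_UV is applied to
    # the attention output before o_proj. Shape symbols used in the comments:
    # B = tokens, N = num_heads, P = qk_nope_head_dim, L = kv_lora_rank,
    # V = v_head_dim.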
    def _v_up_proj_and_o_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return self.o_proj(x)[0]

    # Return `ql_nope`, `q_pe`
    def _q_proj_and_k_up_proj(self, x):
        q_nope, q_pe = self.q_proj(x)[0]\
            .view(-1, self.num_heads, self.qk_head_dim)\
            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        return ql_nope.transpose(0, 1), q_pe

    def process_weights_after_loading(self, act_dtype: torch.dtype):

        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)
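        # Illustrative shapes, assuming DeepSeek-V2-style dims (kv_lora_rank
        # L=512, qk_nope_head_dim P=128, v_head_dim V=128) and N heads on
        # this TP rank: kv_b_proj_weight is viewed as (512, N, 256), W_UK_T
        # becomes (N, 128, 512) and W_UV becomes (N, 512, 128). Actual values
        # depend on the model config.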

    def _forward_prefill(
        self,
        query: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        assert attn_metadata.prefill is not None

        num_tokens = query.size(0)
        attn_output = None
        # Only two input states are possible here: ChunkedPrefill or
        # PrefillOnly.
        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
            attn_output = torch.empty(num_tokens,
                                      self.num_heads * self.v_head_dim,
                                      dtype=query.dtype,
                                      device=query.device)
            # The current requests are chunked prefills, so flash attention
            # cannot be used; fall back to the vanilla chunked-prefill MLA
            # path instead.
            vanilla_chunked_prefill_mla(
                output=attn_output,
                query=query,
                kv_cache=kv_c_and_k_pe_cache,
                block_tables=attn_metadata.prefill.block_table,
                query_lens=attn_metadata.prefill.query_lens,
                context_lens=attn_metadata.prefill.context_lens,
                kv_b_proj=self.kv_b_proj,
                max_query_len=attn_metadata.prefill.max_query_len,
                max_context_len=attn_metadata.prefill.max_seq_lens,
                nope_dim=self.qk_nope_head_dim,
                rope_dim=self.qk_rope_head_dim,
                v_head_dim=self.v_head_dim,
                scale=self.scale,
                alibi_slopes=None,
                causal=True)
        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
            attn_output = torch.empty(num_tokens,
                                      self.num_heads,
                                      self.padding_head_dim,
                                      dtype=query.dtype,
                                      device=query.device)
            k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads,
                self.qk_nope_head_dim + self.v_head_dim).split(
                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                            dim=-1)
            pad_query = torch.nn.functional.pad(query, [
                0, self.padding_head_dim - self.qk_rope_head_dim -
                self.qk_nope_head_dim
            ],
                                                value=0)
            pad_key = torch.nn.functional.pad(key, [
                0, self.padding_head_dim - self.qk_rope_head_dim -
                self.qk_nope_head_dim
            ],
                                              value=0)
            pad_value = torch.nn.functional.pad(
                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
            torch_npu._npu_flash_attention(
                query=pad_query,
                key=pad_key,
                value=pad_value,
                mask=attn_metadata.attn_mask,
                seq_len=attn_metadata.prefill.context_lens,
                scale_value=self.scale,
                num_heads=self.num_heads,
                num_kv_heads=self.num_heads,
                out=attn_output)
            attn_output = attn_output.view(
                -1, self.num_heads,
                self.padding_head_dim)[:, :, :self.v_head_dim]
        else:
            raise RuntimeError(
                "Unexpected path reached: AscendMLAImpl should only handle "
                "the PrefillOnly and ChunkedPrefill states in forward "
                "prefill. Please file a bug to vllm-ascend!")
        attn_output = attn_output.reshape(
            [num_tokens, self.num_heads * self.v_head_dim])
        return self.o_proj(attn_output)[0]

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        num_tokens = q.size(0)
        attn_output = torch.empty(
            [num_tokens, self.num_heads, self.kv_lora_rank],
            dtype=q.dtype,
            device=q.device)
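        # q carries the latent query plus the rope part (head dim
        # kv_lora_rank + qk_rope_head_dim); with mla_vheadsize=kv_lora_rank
        # the kernel's output stays in the latent space and is up-projected
        # by _v_up_proj_and_o_proj below.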
        torch_npu._npu_paged_attention_mla(
            query=q,
            key_cache=kv_c_and_k_pe_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.decode.block_table,  # type:ignore
            context_lens=attn_metadata.decode.seq_lens,  # type:ignore
            mla_vheadsize=self.kv_lora_rank,
            out=attn_output)
        return self._v_up_proj_and_o_proj(attn_output)

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            return output
        num_actual_toks = attn_metadata.num_actual_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)
        assert attn_metadata.num_decodes is not None and \
            attn_metadata.num_prefills is not None and \
            attn_metadata.num_decode_tokens is not None
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens

        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
        decode_k_pe = k_pe[:num_decode_tokens]

        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
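        # The slices above rely on AscendMLAMetadataBuilder.reorder_batch
        # having placed all decode tokens before the prefill tokens in the
        # batch, so a single split at num_decode_tokens separates the two.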

        if has_decode:
            assert attn_metadata.decode is not None
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                attn_metadata.decode.input_positions,
                decode_q_pe.contiguous(),
                decode_k_pe,
                max_seq_len=attn_metadata.decode.max_seq_lens)

        if has_prefill:
            assert attn_metadata.prefill is not None
            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                attn_metadata.prefill.input_positions,
                prefill_q_pe.contiguous(),
                prefill_k_pe,
                max_seq_len=attn_metadata.prefill.max_seq_lens)

        if kv_cache.numel() > 0:
            key = torch.cat([
                k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
            ],
                            dim=2)
            torch_npu._npu_reshape_and_cache_siso(
                key=key,
                key_cache=kv_cache,
                slot_indices=attn_metadata.slot_mapping.flatten())
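            # Each cached entry is the normalized latent kv_c concatenated
            # with the rotary k_pe, stored per KV head (num_kv_heads is
            # typically 1 for MLA) and scattered into the paged cache via
            # slot_mapping.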

        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:
            output[:num_decode_tokens] = self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

        return output_padded