################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""FlashMLA attention backend, metadata builder and MLA implementation for
Biren SUPA devices."""
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union

import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                              is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, ReplicatedLinear,
                                               RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
from vllm.v1.attention.backends.mla.common import (MLACommonImpl,
                                                   MLACommonMetadataBuilder)
from vllm.v1.attention.backends.mla.flashmla import (FlashMLABackend,
                                                     FlashMLAMetadata)
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_br import envs
from vllm_br.model_executor.layers.br_utils import _convert_to_numa_tensor
from vllm_br.utils import get_grandparent_pid
from vllm_br.v1.attention.backends.utils import SUPACommonAttentionMetadata

if TYPE_CHECKING:
    pass

logger = init_logger(__name__)


class SupaFlashMLABackend(FlashMLABackend):
    """FlashMLA backend variant for SUPA devices.

    Overrides the backend name and the metadata / builder / impl classes, and
    adds SUPA-specific KV-cache layout helpers.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "SUPAFLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["SupaFlashMLAMetadata"]:
        return SupaFlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLAMetadataBuilder"]:
        return SupaFlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLAImpl"]:
        return SupaFlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the generic (2, blocks, block_size, heads, head_size) shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the SUPA-optimised KV-cache shape.

        ``th_gran`` consecutive blocks are packed into one row of
        ``th_gran * block_size`` tokens (see
        ``get_kv_cache_usharp_alignment``). Heads and head_size are flattened
        into the last dim. The block count is rounded up to a whole number of
        rows (at least one).
        """
        th_gran = SupaFlashMLABackend.get_kv_cache_usharp_alignment(block_size)
        # Ceil-divide so every logical block fits into some packed row.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s], For SUPA Speed up, "
            "use [1, %s, %s, %s]", num_blocks, block_size, num_kv_heads,
            head_size, n_block, th_gran * block_size, num_kv_heads * head_size)
        # TODO, shared kv only used in deepseek -- the leading dim is 1
        # (a single shared tensor) rather than separate K and V planes.
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Number of blocks packed per row so a row spans up to 2048 tokens.

        NOTE(review): returns 0 for ``block_size > 2048``, which would make
        ``get_kv_cache_usharp_shape`` divide by zero.
        """
        max_h_limit = 2048
        return max_h_limit // block_size


@dataclass
class SupaFlashMLAMetadata:
    """Per-batch attention metadata consumed by ``SupaFlashMLAImpl``."""

    # Originally modelled on FlashMLAMetadata but kept standalone.
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # BIREN Attention Params
    seq_start_loc: torch.Tensor  # [0] + cumulative seq_lens
    context_lens: torch.Tensor  # seq_len - query_len per request
    max_decode_seq_len: int
    do_cache: bool  # when use attentionsplit, do cache = False

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_actual_reqs: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None


class SupaFlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``SupaFlashMLAMetadata`` from the common attention metadata."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        self.vllm_config = vllm_config
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        device_properties = torch.cuda.get_device_properties(self.device)
        num_sms = device_properties.multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

        self.aot_schedule = False
        logger.warning(
            "AOT Schedule is disabled when using SUPAFlashAttention.")
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: SUPACommonAttentionMetadata,
              fast_build: bool = False):
        """Build per-batch metadata for ``SupaFlashMLAImpl``.

        Args:
            common_prefix_len: Shared prefix length (> 0 enables cascade).
            common_attn_metadata: Batch-level metadata from the runner; its
                optional ``seq_start_loc`` / ``context_lens`` /
                ``max_decode_seq_len`` are derived here when ``None``.
            fast_build: When true, AOT scheduling is skipped.

        Returns:
            A populated ``SupaFlashMLAMetadata`` with ``do_cache=True``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        num_actual_reqs = common_attn_metadata.num_actual_reqs

        aot_schedule = self.aot_schedule and not fast_build

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call, once the layers are constructed (it cannot be
            # populated in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # AOT scheduling is not implemented for SUPA; with it disabled
            # there is never any scheduler metadata.
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                              common_prefix_len).to(self.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        if common_attn_metadata.seq_start_loc is None:
            # For larger batches copy seq_lens to host once before the
            # Python-level prefix sum instead of reading it element-wise.
            lens_for_cumsum = (seq_lens.cpu()
                               if len(seq_lens) > 8 else seq_lens)
            seq_start_loc = torch.tensor(
                [0] + list(itertools.accumulate(lens_for_cumsum)),
                device=query_start_loc.device,
                dtype=torch.int32)
        else:
            seq_start_loc = common_attn_metadata.seq_start_loc

        if common_attn_metadata.context_lens is None:
            # context_len = seq_len - query_len (see diagram on the metadata).
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])
        else:
            context_lens = common_attn_metadata.context_lens

        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        return SupaFlashMLAMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs)

    def can_run_in_cudagraph(
            self, common_attn_metadata: SUPACommonAttentionMetadata) -> bool:
        # This builder always reports that a batch cannot run inside a
        # full CUDA graph.
        return False

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        # Cascade attention is never requested by this backend.
        return False


class SupaFlashMLAImpl(MLACommonImpl[SupaFlashMLAMetadata]):
    """MLA attention implementation built on SUPA fused kernels.

    Requires ``q_lora_rank`` (DeepSeek-V3-style MLA). After weight loading,
    the projection weights are folded into three fused matrices:
    ``w_qkv_a``, ``w_qk_b`` and ``w_vo``.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            q_lora_rank: Optional[int],
            kv_lora_rank: int,
            qk_nope_head_dim: int,
            qk_rope_head_dim: int,
            qk_head_dim: int,
            v_head_dim: int,
            kv_b_proj: ColumnParallelLinear,
            rotary_emb: RotaryEmbedding,
            # q_proj should be q_b_proj if q_lora_rank is not None, but from an
            # attention backend perspective we rely on the layer to pass in
            # the correct matrix.
            q_proj: ColumnParallelLinear,  # q_b_proj
            o_proj: RowParallelLinear,
            kv_a_proj_with_mqa: ReplicatedLinear,
            kv_a_layernorm: Any,
            q_a_proj: ReplicatedLinear,
            q_a_layernorm: Any,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, q_lora_rank,
                         kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
                         qk_head_dim, v_head_dim, kv_b_proj, **mla_args)
        self.rotary_emb = rotary_emb
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.kv_a_proj_with_mqa = kv_a_proj_with_mqa
        self.kv_a_layernorm = kv_a_layernorm
        self.q_a_layernorm = q_a_layernorm
        self.q_a_proj = q_a_proj
        self.tp_size = get_tensor_model_parallel_world_size()

        cur_device = torch.supa.current_device()
        self.spc_num = torch_br.supa.get_device_properties(
            cur_device).max_compute_units
        # NOTE(review): forward() may also take the fused-allreduce path for
        # other (spc_num, tp_size) pairs; confirm whether p2p init is needed
        # there too.
        if (envs.VLLM_BR_USE_FUSED_ALLREDUCE and self.tp_size == 8
                and self.spc_num == 16):
            # Initialize the p2p info
            torch.supa.init_p2p_remote_id(cur_device)

        assert self.q_lora_rank is not None

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "SUPAFlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "SUPAFlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "SUPAFlashMLA V1 with FP8 KV cache not yet supported")

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Fold projection weights into SUPA fused matrices.

        Produces:
          * ``w_qkv_a``: ``[W_q_a | W_kv_a]``, duplicated along the last dim
            when ``VLLM_BR_DEVICE_SPC_NUM > 16``.
          * ``w_qk_b``: the per-head product ``W_q_b_nope @ W_k_b`` alongside
            ``W_q_b_rope``.
          * ``w_vo``: the per-head product ``W_v_b @ W_o``.

        The source linear weights are then released and the layernorm
        weights are cast to float32.

        Raises:
            NotImplementedError: when ``q_lora_rank`` is None.
        """

        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        if self.q_lora_rank is None:
            raise NotImplementedError

        # handle deepseek_v3 weight
        w_q_a = get_and_maybe_dequant_weights(self.q_a_proj).T
        w_kv_a = get_and_maybe_dequant_weights(self.kv_a_proj_with_mqa).T
        w_qkv_a = torch.cat([w_q_a, w_kv_a], dim=-1)
        # w_qkv_a must make two copies in br166
        align_size = 32
        die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        if die_spc_num > 16:
            w_qkv_a = torch.cat([w_qkv_a, w_qkv_a], dim=-1)
        self.w_qkv_a = _convert_to_numa_tensor(w_qkv_a, align_size,
                                               "colmajor", w_qkv_a.dtype)

        w_kv_b = get_and_maybe_dequant_weights(self.kv_b_proj).T
        w_k_b, w_v_b = w_kv_b.reshape(
            self.kv_lora_rank, -1,
            self.qk_nope_head_dim + self.v_head_dim).split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        w_k_b = w_k_b.permute(1, 2, 0).contiguous()
        w_v_b = w_v_b.permute(1, 0, 2).contiguous()

        # Absorb W_v_b into the output projection: w_vo = W_v_b @ W_o per head.
        w_o = get_and_maybe_dequant_weights(self.o_proj.to(w_v_b.device)).T
        hidden_dim = w_o.shape[-1]
        w_o = w_o.reshape(-1, self.v_head_dim, hidden_dim)
        w_vo = torch.bmm(w_v_b, w_o).reshape(-1, hidden_dim)
        self.w_vo = _convert_to_numa_tensor(w_vo,
                                            align_size,
                                            "colmajor",
                                            w_qkv_a.dtype,
                                            parallel_type="row_parallel")

        # replace q_b_proj as q_proj
        w_q_b = get_and_maybe_dequant_weights(self.q_proj).T
        w_q_b_nope, w_q_b_rope = w_q_b.reshape(
            self.q_lora_rank, -1, self.qk_head_dim).split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        w_q_b_nope = w_q_b_nope.permute(1, 0, 2).contiguous()
        w_q_b_rope = w_q_b_rope.reshape(self.q_lora_rank, -1)
        # Absorb W_k_b into the query projection (nope part only).
        w_qk_b_nope = torch.bmm(w_q_b_nope, w_k_b).permute(
            1, 0, 2).contiguous().reshape(self.q_lora_rank, -1)
        # w_qk_b_nope and w_q_b_rope are independent heads; when the weight
        # is split across two halves, interleave them like QKVParallelLinear.
        if die_spc_num > 16:
            qk_b_nope0, qk_b_nope1 = torch.chunk(w_qk_b_nope, 2, dim=-1)
            qk_b_rope0, qk_b_rope1 = torch.chunk(w_q_b_rope, 2, dim=-1)
            w_qk_b = torch.cat(
                [qk_b_nope0, qk_b_rope0, qk_b_nope1, qk_b_rope1], dim=-1)
        else:
            w_qk_b = torch.cat([w_qk_b_nope, w_q_b_rope], dim=-1)
        self.w_qk_b = _convert_to_numa_tensor(w_qk_b, align_size, "colmajor",
                                              w_qkv_a.dtype)

        # The fused matrices replace the original projection weights.
        self.q_a_proj.weight = None
        self.kv_a_proj_with_mqa.weight = None
        self.q_proj.weight = None
        self.kv_b_proj.weight = None
        self.o_proj.weight = None

        if self.kv_a_layernorm.weight.dtype != torch.float32:
            self.kv_a_layernorm.weight.data = self.kv_a_layernorm.weight.to(
                torch.float32)
        if self.q_a_layernorm.weight.dtype != torch.float32:
            self.q_a_layernorm.weight.data = self.q_a_layernorm.weight.to(
                torch.float32)

        torch.supa.empty_cache()

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Decode is handled inside forward(); the MLACommonImpl hook is unused.
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states: torch.Tensor,  # query in unified attn
        positions: torch.Tensor,  # reuse k_c_normed as position
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLAMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the full MLA block: fused q/kv projection, attention, o_proj.

        Args:
            layer: Calling attention layer (unused here).
            hidden_states: Layer input activations. Presumably 3-D
                ``[1, num_tokens, hidden]``, since dim -2 is treated as the
                token count (TODO confirm).
            positions: Token positions used for rotary embedding.
            k_pe: Unused by this implementation.
            kv_cache: Paged KV cache. Presumably laid out per
                ``SupaFlashMLABackend.get_kv_cache_usharp_shape`` (TODO
                confirm).
            attn_metadata: Batch metadata; ``None`` during profiling.
            output: Must be ``None``; output buffers are not accepted.

        Returns:
            The output-projected (and, for TP > 1, all-reduced) activations.
            During profiling (``attn_metadata is None``), returns
            ``hidden_states`` unchanged.
        """
        assert output is None, "Output tensor should not provided."
        if (envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0
                and not hasattr(self, "grandparent_pid")):
            self.grandparent_pid = get_grandparent_pid()

        # profile and warm up mla attention kernel
        if attn_metadata is None:
            return hidden_states

        # handle deepseek_v3 mla: v2 and v3 kernels share one signature and
        # are selected by hidden_states.shape[1].
        mla_prefix_infer = (torch_br.supa_mla_prefix_infer_v2
                            if hidden_states.shape[1] <= 512 else
                            torch_br.supa_mla_prefix_infer_v3)
        query, _ = mla_prefix_infer(
            hidden_states, self.w_qkv_a, self.w_qk_b,
            self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
            self.rotary_emb.sin_cache, self.rotary_emb.cos_cache, positions,
            kv_cache, attn_metadata.slot_mapping, self.num_heads,
            self.qk_head_dim, self.qk_nope_head_dim, self.qk_rope_head_dim,
            self.kv_lora_rank, self.v_head_dim, self.q_lora_rank,
            self.kv_a_layernorm.variance_epsilon)

        # Both attention kernels below expect a 3-D query.
        assert len(query.shape) == 3
        if attn_metadata.num_prefill_tokens > 0:
            attn_out = torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
                query,
                kv_cache,
                attn_metadata.query_start_loc,
                attn_metadata.seq_start_loc,
                attn_metadata.block_table,
                self.head_size,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
                num_reqs=attn_metadata.num_actual_reqs,
            )
        else:
            assert attn_metadata.num_prefills == 0
            attn_out = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
            )

        # Fused linear+allreduce is only used when enabled, the token count
        # is <= VLLM_BR_STATIC_MOE_DECODER_MAX_LEN, and (spc_num, tp_size)
        # is one of the supported combinations below.
        hidden_size = hidden_states.shape[-1]
        seq_len = hidden_states.shape[-2]
        tp_size = self.tp_size
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        fused_comm = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                      and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                      and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size)
                      in support_types)

        if fused_comm:
            tp_group = get_tp_group()
            tp_rank = tp_group.rank_in_group
            global_rank = tp_group.rank
            # The fused kernel assumes TP groups are contiguous global ranks.
            assert global_rank % tp_size == tp_rank
            return torch_br.supa_fused_linear_allreduce_opt(
                attn_out, self.w_vo, hidden_size, tp_rank, tp_size,
                global_rank, 0)

        # do o_proj
        output_parallel = torch_br.br_fused_mlp_infer(attn_out, [self.w_vo],
                                                      output_w=hidden_size)
        if tp_size > 1:
            return tensor_model_parallel_all_reduce(output_parallel)
        return output_parallel