658 lines
28 KiB
Python
658 lines
28 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
import itertools
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch_br
|
|
|
|
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
|
is_quantized_kv_cache)
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
|
get_tp_group, tensor_model_parallel_all_reduce)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
LinearBase, ReplicatedLinear,
|
|
RowParallelLinear,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
|
|
from vllm.v1.attention.backends.mla.common import (MLACommonImpl,
|
|
MLACommonMetadataBuilder)
|
|
from vllm.v1.attention.backends.mla.flashmla import (FlashMLABackend,
|
|
FlashMLAMetadata)
|
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|
split_decodes_and_prefills)
|
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
from vllm_br import envs
|
|
from vllm_br.model_executor.layers.br_utils import _convert_to_numa_tensor
|
|
from vllm_br.utils import get_grandparent_pid
|
|
from vllm_br.v1.attention.backends.utils import SUPACommonAttentionMetadata
|
|
|
|
if TYPE_CHECKING:
|
|
pass
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class SupaFlashMLABackend(FlashMLABackend):
    """Backend descriptor wiring the SUPA FlashMLA classes into vLLM.

    Overrides the backend name, the metadata / builder / impl classes and
    the KV-cache shape helpers inherited from ``FlashMLABackend``.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "SUPAFLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["SupaFlashMLAMetadata"]:
        """Return the per-step attention metadata class."""
        return SupaFlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLAMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return SupaFlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLAImpl"]:
        """Return the attention implementation class."""
        return SupaFlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the packed ("usharp") KV-cache shape used on SUPA.

        ``th_gran`` consecutive blocks are merged along the token axis and
        the head / head-size axes are flattened, giving
        ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``.
        At least one packed block is always allocated.

        Raises:
            ValueError: if ``block_size`` is not positive, or is so large
                that the alignment granularity would be zero.
        """
        if block_size <= 0:
            raise ValueError(
                f"Block size must be positive, got {block_size}.")
        th_gran = SupaFlashMLABackend.get_kv_cache_usharp_alignment(block_size)
        if th_gran < 1:
            # Without this guard the ceil-division below would fail with an
            # uninformative ZeroDivisionError.
            raise ValueError(
                f"Block size {block_size} is too large for the usharp KV "
                "cache layout (alignment granularity would be 0).")
        # Ceil-divide the block count by the granularity, never below 1.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # return (2, n_block, th_gran * block_size, num_kv_heads * head_size)
        # Lazy %-style arguments: nothing is formatted unless DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s], "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        # TODO, shared kv only used in deepseek
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in the 2048 limit."""
        max_h_limit = 2048
        return max_h_limit // block_size
|
|
|
|
|
|
@dataclass
class SupaFlashMLAMetadata:
    """Per-step attention metadata consumed by the SUPA FlashMLA backend.

    Relationship between the length fields:

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ---------------------|
                                          |-- query_len ---|
    """

    # ---- generic attention inputs ----
    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # ---- BIREN attention parameters ----
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    # Set to False when attention split is used.
    do_cache: bool

    # ---- prefill / decode split ----
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_actual_reqs: torch.Tensor

    # ---- cascade attention ----
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # ---- optional ahead-of-time scheduling ----
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
|
|
|
|
|
class SupaFlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds a ``SupaFlashMLAMetadata`` for every scheduling step.

    AOT scheduling is always disabled for this backend and
    ``use_cascade_attention`` always returns ``False``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the builder and, if needed, the cudagraph buffers."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        self.vllm_config = vllm_config
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        # NOTE(review): device properties are queried through torch.cuda;
        # presumably torch_br redirects this to the SUPA device - confirm.
        device_properties = torch.cuda.get_device_properties(self.device)
        num_sms = device_properties.multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

        self.aot_schedule = False
        logger.warning(
            "AOT Schedule is disabled when using SUPAFlashAttention.")

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: SUPACommonAttentionMetadata,
              fast_build: bool = False):
        """Assemble the attention metadata for the current batch.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value greater than zero selects the cascade fields.
            common_attn_metadata: batch-level metadata from the model runner.
            fast_build: when True, AOT scheduling is skipped for this call.

        Returns:
            A populated ``SupaFlashMLAMetadata``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        num_actual_reqs = common_attn_metadata.num_actual_reqs
        max_seq_len = int(seq_lens_cpu[:num_reqs].max())

        aot_schedule = self.aot_schedule and not fast_build

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata; SUPA has none, so ``None``."""
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # This builder keeps its device in ``self.device`` (already used
            # in __init__) and receives host-side sequence lengths through
            # ``seq_lens_cpu``; it never sets a ``runner`` attribute.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                              common_prefix_len).to(self.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        seq_start_loc = common_attn_metadata.seq_start_loc
        if seq_start_loc is None:
            # Exclusive prefix sum of the sequence lengths. Batches longer
            # than 8 are copied to the host in one go first; shorter ones
            # are iterated in place.
            lens_for_sum = seq_lens.cpu() if len(seq_lens) > 8 else seq_lens
            seq_start_loc = torch.tensor(
                [0] + list(itertools.accumulate(lens_for_sum)),
                device=query_start_loc.device,
                dtype=torch.int32)

        context_lens = common_attn_metadata.context_lens
        if context_lens is None:
            # context_len = seq_len - query_len for every request.
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])

        max_decode_seq_len = common_attn_metadata.max_decode_seq_len
        if max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())

        attn_metadata = SupaFlashMLAMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs)

        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: SUPACommonAttentionMetadata) -> bool:
        """Report whether this batch may run in a full graph; always False."""
        return False

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is never selected by this builder."""
        return False
|
|
|
|
|
|
# class SupaFlashMLAImpl(FlashMLAImpl):
|
|
class SupaFlashMLAImpl(MLACommonImpl[SupaFlashMLAMetadata]):
    """MLA attention implementation backed by torch_br SUPA kernels.

    The projection weights handed to ``__init__`` are fused into three
    device tensors (``w_qkv_a``, ``w_qk_b``, ``w_vo``) by
    ``process_weights_after_loading``; ``forward`` then runs the fused
    prefix kernel, the attention kernel and the output projection.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            q_lora_rank: Optional[int],
            kv_lora_rank: int,
            qk_nope_head_dim: int,
            qk_rope_head_dim: int,
            qk_head_dim: int,
            v_head_dim: int,
            kv_b_proj: ColumnParallelLinear,
            rotary_emb: RotaryEmbedding,
            # q_proj should be q_b_proj if q_lora_rank is not None, but from
            # an attention backend perspective we rely on the layer to pass
            # in the correct matrix
            q_proj: ColumnParallelLinear,  # q_b_proj
            o_proj: RowParallelLinear,
            kv_a_proj_with_mqa: ReplicatedLinear,
            kv_a_layernorm: Any,
            q_a_proj: ReplicatedLinear,
            q_a_layernorm: Any,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Store the projection layers and validate the configuration.

        Raises:
            NotImplementedError: for alibi slopes, sliding window, logits
                soft cap, non-decoder attention types or a quantized KV
                cache.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, q_lora_rank,
                         kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
                         qk_head_dim, v_head_dim, kv_b_proj, **mla_args)

        self.rotary_emb = rotary_emb

        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.kv_a_proj_with_mqa = kv_a_proj_with_mqa
        self.kv_a_layernorm = kv_a_layernorm
        self.q_a_layernorm = q_a_layernorm
        self.q_a_proj = q_a_proj

        self.tp_size = get_tensor_model_parallel_world_size()
        cur_device = torch.supa.current_device()
        self.spc_num = torch_br.supa.get_device_properties(
            cur_device).max_compute_units

        # NOTE(review): P2P is initialised only for tp_size == 8 with 16
        # SPCs, while forward() accepts more (spc_num, tp_size) pairs for
        # the fused all-reduce - confirm this is intended.
        if (envs.VLLM_BR_USE_FUSED_ALLREDUCE and self.tp_size == 8
                and self.spc_num == 16):
            # Initialize the p2p info
            torch.supa.init_p2p_remote_id(cur_device)

        assert self.q_lora_rank is not None
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "SUPAFlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "SUPAFlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "SUPAFlashMLA V1 with FP8 KV cache not yet supported")

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Fuse the MLA projection weights into SUPA-layout tensors.

        Produces ``self.w_qkv_a`` (q_a and kv_a concatenated),
        ``self.w_qk_b`` (q_b absorbed with the K half of kv_b, plus the rope
        part of q_b) and ``self.w_vo`` (the V half of kv_b absorbed with
        o_proj). It then drops the ``weight`` attribute of the source layers
        and casts both layernorm weights to float32.

        Raises:
            NotImplementedError: if ``q_lora_rank`` is None.
        """

        def get_layer_weight(layer):
            """Return the first recognised weight attribute of ``layer``."""
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            """Return the layer weight as (output, input), dequantized."""
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                # Applying the layer to an identity matrix yields W^T.
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        if self.q_lora_rank is None:
            raise NotImplementedError

        # handle deepseek_v3 weight
        w_q_a = get_and_maybe_dequant_weights(self.q_a_proj).T
        w_kv_a = get_and_maybe_dequant_weights(self.kv_a_proj_with_mqa).T
        w_qkv_a = torch.cat([w_q_a, w_kv_a], dim=-1)
        # w_qkv_a must make two copies in br166
        align_size = 32
        die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        if die_spc_num > 16:
            w_qkv_a = torch.cat([w_qkv_a, w_qkv_a], dim=-1)
        self.w_qkv_a = _convert_to_numa_tensor(w_qkv_a, align_size,
                                               "colmajor", w_qkv_a.dtype)

        # Split kv_b into its per-head K (nope) and V parts.
        w_kv_b = get_and_maybe_dequant_weights(self.kv_b_proj).T
        w_k_b, w_v_b = w_kv_b.reshape(
            self.kv_lora_rank, -1,
            self.qk_nope_head_dim + self.v_head_dim).split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        w_k_b = w_k_b.permute(1, 2, 0).contiguous()
        w_v_b = w_v_b.permute(1, 0, 2).contiguous()

        # Absorb the V up-projection into the output projection.
        w_o = get_and_maybe_dequant_weights(self.o_proj.to(w_v_b.device)).T
        hidden_dim = w_o.shape[-1]
        w_o = w_o.reshape(-1, self.v_head_dim, hidden_dim)
        w_vo = torch.bmm(w_v_b, w_o).reshape(-1, hidden_dim)
        self.w_vo = _convert_to_numa_tensor(w_vo,
                                            align_size,
                                            "colmajor",
                                            w_qkv_a.dtype,
                                            parallel_type="row_parallel")

        # replace q_b_proj as q_proj
        w_q_b = get_and_maybe_dequant_weights(self.q_proj).T
        w_q_b_nope, w_q_b_rope = w_q_b.reshape(
            self.q_lora_rank, -1, self.qk_head_dim).split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        w_q_b_nope = w_q_b_nope.permute(1, 0, 2).contiguous()
        w_q_b_rope = w_q_b_rope.reshape(self.q_lora_rank, -1)

        # Absorb the K up-projection into the nope part of q_b.
        w_qk_b_nope = torch.bmm(w_q_b_nope, w_k_b).permute(
            1, 0, 2).contiguous().reshape(self.q_lora_rank, -1)
        # w_qk_b_nope / w_q_b_rope are independent per head, so they are
        # separated like QKVParallelLinear.
        if die_spc_num > 16:
            qk_b_nope0, qk_b_nope1 = torch.chunk(w_qk_b_nope, 2, dim=-1)
            qk_b_rope0, qk_b_rope1 = torch.chunk(w_q_b_rope, 2, dim=-1)
            w_qk_b = torch.cat(
                [qk_b_nope0, qk_b_rope0, qk_b_nope1, qk_b_rope1], dim=-1)
        else:
            w_qk_b = torch.cat([w_qk_b_nope, w_q_b_rope], dim=-1)
        self.w_qk_b = _convert_to_numa_tensor(w_qk_b, align_size,
                                              "colmajor", w_qkv_a.dtype)

        # Drop the references to the original weights; only the fused
        # tensors above are used from here on.
        self.q_a_proj.weight = None
        self.kv_a_proj_with_mqa.weight = None
        self.q_proj.weight = None
        self.kv_b_proj.weight = None
        self.o_proj.weight = None

        if self.kv_a_layernorm.weight.dtype != torch.float32:
            self.kv_a_layernorm.weight.data = self.kv_a_layernorm.weight.to(
                torch.float32)
        if self.q_a_layernorm.weight.dtype != torch.float32:
            self.q_a_layernorm.weight.data = self.q_a_layernorm.weight.to(
                torch.float32)

        torch.supa.empty_cache()

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Not used by this backend; ``forward`` handles decode itself."""
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states: torch.Tensor,  # query in unified attn
        positions: torch.Tensor,  # reuse k_c_normed as position
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLAMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention and the output projection on SUPA.

        Args:
            layer: the attention layer (not read here).
            hidden_states: layer input, passed in the "query" slot of the
                unified attention call. The token count is read from
                ``shape[1]`` / ``shape[-2]`` and the hidden size from
                ``shape[-1]``.
            positions: token positions, passed in the "key" slot.
            k_pe: passed in the "value" slot; not read here.
            kv_cache: paged KV cache, written by the prefix kernel through
                ``attn_metadata.slot_mapping``.
            attn_metadata: metadata for this step; ``None`` during
                profiling / warm-up.
            output: must be ``None``; output buffers are not supported.

        Returns:
            The attention output after the output projection (all-reduced
            across tensor-parallel ranks when ``tp_size > 1``), or
            ``hidden_states`` unchanged when ``attn_metadata`` is ``None``.
        """
        assert output is None, "Output tensor should not be provided."
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()

        # profile and warm up mla attention kernel
        if attn_metadata is None:
            return hidden_states

        # handle deepseek_v3 mla: v2 kernel for up to 512 tokens, v3 beyond.
        if hidden_states.shape[1] <= 512:
            prefix_infer = torch_br.supa_mla_prefix_infer_v2
        else:
            prefix_infer = torch_br.supa_mla_prefix_infer_v3
        query, _ = prefix_infer(
            hidden_states, self.w_qkv_a, self.w_qk_b,
            self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
            self.rotary_emb.sin_cache, self.rotary_emb.cos_cache, positions,
            kv_cache, attn_metadata.slot_mapping, self.num_heads,
            self.qk_head_dim, self.qk_nope_head_dim, self.qk_rope_head_dim,
            self.kv_lora_rank, self.v_head_dim, self.q_lora_rank,
            self.kv_a_layernorm.variance_epsilon)

        # No output buffer is pre-allocated: both kernels below return a
        # fresh tensor.
        assert len(query.shape) == 3
        if attn_metadata.num_prefill_tokens > 0:
            attn_out = torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
                query,
                kv_cache,
                attn_metadata.query_start_loc,
                attn_metadata.seq_start_loc,
                attn_metadata.block_table,
                self.head_size,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
                num_reqs=attn_metadata.num_actual_reqs,
            )
        else:
            assert attn_metadata.num_prefills == 0
            attn_out = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
            )

        # now linear+allreduce only support M <= 512 and
        # tp_size == 4 | 8 and spc_num == 16
        seq_len = hidden_states.shape[-2]
        tp_size = self.tp_size
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        fused_comm = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                      and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                      and (envs.VLLM_BR_DEVICE_SPC_NUM,
                           tp_size) in support_types)

        if fused_comm:
            tp_group = get_tp_group()
            tp_rank = tp_group.rank_in_group
            global_rank = tp_group.rank
            assert global_rank % tp_size == tp_rank
            return torch_br.supa_fused_linear_allreduce_opt(
                attn_out, self.w_vo, hidden_states.shape[-1], tp_rank,
                tp_size, global_rank, 0)

        # do o_proj
        output_parallel = torch_br.br_fused_mlp_infer(
            attn_out, [self.w_vo], output_w=hidden_states.shape[-1])
        if tp_size > 1:
            return tensor_model_parallel_all_reduce(output_parallel)
        return output_parallel
|