451 lines
19 KiB
Python
451 lines
19 KiB
Python
|
|
################################################################################
|
|||
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
#
|
|||
|
|
################################################################################
|
|||
|
|
import itertools
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
import torch_br
|
|||
|
|
|
|||
|
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
|
|||
|
|
from vllm.attention.ops.flashmla import get_mla_metadata
|
|||
|
|
from vllm.config import VllmConfig
|
|||
|
|
from vllm.logger import init_logger
|
|||
|
|
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
|||
|
|
FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
|
|||
|
|
FlashMLASparseMetadataBuilder)
|
|||
|
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|||
|
|
CommonAttentionMetadata,
|
|||
|
|
split_decodes_and_prefills)
|
|||
|
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from vllm.model_executor.models.deepseek_v2 import Indexer
|
|||
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|||
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|||
|
|
|
|||
|
|
logger = init_logger(__name__)
|
|||
|
|
|
|||
|
|
_NO_DEFAULT = object()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
    """Sparse FlashMLA metadata extended with BIREN attention parameters.

    The extra fields are logically required. They carry placeholder
    defaults (``_NO_DEFAULT`` / ``-1``) only so that they can be declared
    after the fields inherited from the parent dataclass;
    ``__post_init__`` rejects any field left at its placeholder.
    """

    # BIREN Attention Params
    # The builder in this file fills this with [0, cumulative sum of
    # seq_lens] unless the common metadata already provides it.
    seq_start_loc: torch.Tensor = _NO_DEFAULT
    # The builder in this file fills this with seq_lens minus the per-request
    # query length unless the common metadata already provides it.
    context_lens: torch.Tensor = _NO_DEFAULT
    max_decode_seq_len: int = -1
    num_prefills: int = -1
    num_decodes: int = -1
    num_prefill_tokens: int = -1
    num_decode_tokens: int = -1

    def __post_init__(self):
        """Validate that every BIREN-specific field was supplied.

        Raises:
            TypeError: if any field still holds its placeholder default.
                The message lists the offending field names.
        """
        tensor_fields = ("seq_start_loc", "context_lens")
        count_fields = ("max_decode_seq_len", "num_prefills", "num_decodes",
                        "num_prefill_tokens", "num_decode_tokens")

        missing = [
            name for name in tensor_fields
            if getattr(self, name) is _NO_DEFAULT
        ]
        missing.extend(name for name in count_fields
                       if getattr(self, name) == -1)

        if missing:
            raise TypeError("__init__ missing required argument: " +
                            ", ".join(missing))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
    """Builds ``SupaFlashMLASparseMetadata`` for SUPA devices.

    Differs from the parent builder in two ways visible here:

    * ``reorder_batch`` places prefill requests at the front of the batch
      and decode requests at the back.
    * ``build`` adds the BIREN-specific attention parameters
      (``seq_start_loc``, ``context_lens``, ``max_decode_seq_len`` and the
      prefill/decode counts) to the metadata.
    """

    reorder_batch_threshold: int = 1
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the parent builder and account for speculative tokens.

        Args:
            kv_cache_spec: KV-cache specification forwarded to the parent.
            layer_names: attention layer names forwarded to the parent.
            vllm_config: global vLLM configuration; also kept on ``self``.
            device: target device forwarded to the parent.
        """
        super().__init__(
            kv_cache_spec=kv_cache_spec,
            layer_names=layer_names,
            vllm_config=vllm_config,
            device=device,
        )
        self.vllm_config = vllm_config
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config else 0)
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2,
        # so the threshold is raised by at most one.
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so prefills are at the front, decodes at the back.

        A request counts as a "decode" when it has exactly one scheduled
        token once its scheduled speculative tokens are subtracted; every
        other request counts as a "prefill".

        Args:
            input_batch: persistent batch whose request slots are swapped
                in place through ``swap_states``.
            scheduler_output: supplies the per-request scheduled token
                counts and speculative tokens.

        Returns:
            True if at least one pair of requests was swapped.
        """
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            # For now a request with one non-speculative scheduled token is
            # treated as "decode" even if it is not.
            if num_tokens - num_spec_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # `decodes` and `prefills` are in ascending index order because of
        # the loop above. After reordering, prefills must occupy slots
        # [0, num_prefills). Walk the prefills from the back of the batch
        # and swap each one that lies outside that region with the decode
        # closest to the front. As soon as a prefill is already inside the
        # region, all remaining (smaller-index) prefills are too, so stop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            prefills_idx = prefills[num_prefills - i]
            if prefills_idx < num_prefills:
                break

            input_batch.swap_states(decodes[i - 1], prefills_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SupaFlashMLASparseMetadata:
        """Assemble the attention metadata for one scheduler step.

        Args:
            common_prefix_len: not used by this implementation.
            common_attn_metadata: backend-agnostic metadata for the step.
            fast_build: not used by this implementation.

        Returns:
            A ``SupaFlashMLASparseMetadata`` holding the common fields, the
            optional FP8 kernel metadata and the BIREN attention parameters.
        """
        num_tokens = common_attn_metadata.num_actual_tokens
        num_reqs = common_attn_metadata.num_reqs

        # Map every token to the index of the request it belongs to.
        starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                            dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
            .copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata = None
        if self.use_fp8_kv_cache:
            tile_scheduler_metadata, num_splits = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor,
                num_q_tokens_per_head_k=num_tokens * self.num_heads,
                topk=self.topk_tokens,
                num_heads_q=self.num_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            num_sm_parts = tile_scheduler_metadata.size(0)
            # Copy to persistent buffer for full-CG support
            tile_scheduler_metadata_buffer = \
                self.tile_scheduler_metadata_buffer[:num_sm_parts]
            tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
            self.num_splits_buffer.copy_(num_splits)

            fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=tile_scheduler_metadata_buffer,
                num_splits=self.num_splits_buffer,
                # cache_lens and block_table are basically unused in sparse
                # case but the decode kernel will treat -1 and indices >=
                # cache_lens as invalid so we make sure cache_lens is large
                # enough to not accidentally mark indices invalid, we will
                # use -1 exclusively to mark invalid indices
                cache_lens=self.max_model_len_tensor,
                dummy_block_table=self.dummy_block_table)

        # Add biren attention params
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        if common_attn_metadata.seq_start_loc is None:
            # Exclusive prefix sum of the sequence lengths. For more than 8
            # requests the lengths are moved to CPU in one transfer before
            # being accumulated (presumably to avoid per-element device
            # reads -- TODO confirm the threshold choice).
            lens_to_scan = seq_lens.cpu() if len(seq_lens) > 8 else seq_lens
            seq_start_loc = torch.tensor(
                [0] + list(itertools.accumulate(lens_to_scan)),
                device=query_start_loc.device,
                dtype=torch.int32)
        else:
            seq_start_loc = common_attn_metadata.seq_start_loc

        if common_attn_metadata.context_lens is None:
            # Context length = total sequence length minus the number of
            # query tokens scheduled for the request in this step.
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])
        else:
            context_lens = common_attn_metadata.context_lens

        if common_attn_metadata.max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        # NOTE(review): reorder_batch places prefills before decodes;
        # confirm that split_decodes_and_prefills handles that ordering on
        # this platform.
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        metadata = SupaFlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
        )
        return metadata
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SupaFlashMLASparseBackend(FlashMLASparseBackend):
    """Attention backend entry point for sparse FlashMLA on SUPA devices.

    Exposes the SUPA-specific metadata, builder and implementation classes
    and the KV-cache layout helpers used on this platform.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPA_FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the metadata class produced by this backend's builder."""
        return SupaFlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return SupaFlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
        """Return the attention implementation class for this backend."""
        return SupaFlashMLASparseImpl

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Compute the KV-cache shape used on SUPA.

        ``th_gran`` logical blocks are packed into each physical block, so
        the block count is divided by ``th_gran`` (rounded up, minimum 1)
        and the heads and head size are fused into one trailing dimension.

        Args:
            num_blocks: number of logical KV-cache blocks.
            block_size: number of tokens per logical block.
            num_kv_heads: number of KV heads.
            head_size: size of each head.

        Returns:
            ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)``.
        """
        th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
            block_size)
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only formatted when debug
        # logging is enabled.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s], "
            "For SUPA Speed up, use [2, %s, %s, %s]",
            num_blocks, block_size, num_kv_heads, head_size,
            n_block, th_gran * block_size, num_kv_heads * head_size)
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many logical blocks are packed per physical block.

        Computed as ``2048 // block_size``.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SupaFlashMLASparseImpl(FlashMLASparseImpl):
    """Sparse FlashMLA attention implementation backed by torch_br kernels.

    Only the non-FP8 KV-cache path is supported; ``forward`` raises for the
    ``"fp8_ds_mla"`` cache dtype.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        """Forward every argument unchanged to the parent implementation."""
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, topk_indice_buffer,
                         indexer, **mla_args)

    @staticmethod
    def _build_topk_index_mask(topk_indices: torch.Tensor, seq_len_q: int,
                               device: torch.device) -> torch.Tensor:
        """Build the ``[1, seq_len_q, seq_len_q]`` int32 sparse mask.

        Entry ``[0, q, k]`` is 0 when ``k`` appears in ``topk_indices[q]``
        and lies in ``[0, seq_len_q)``; every other entry is 1. Indices
        outside that range (including -1) are ignored.

        The mask is assembled on the CPU with a single scatter and moved to
        ``device`` in one transfer, instead of issuing one tensor read and
        one tensor write per (query, top-k) pair.
        """
        indices = topk_indices[:seq_len_q].to(device="cpu",
                                              dtype=torch.int64)
        in_range = (indices >= 0) & (indices < seq_len_q)
        # Redirect ignored indices to an extra trailing "sink" column that
        # is sliced away afterwards, so they never clear a real entry.
        indices = indices.masked_fill(~in_range, seq_len_q)
        mask = torch.ones((seq_len_q, seq_len_q + 1), dtype=torch.int32)
        mask.scatter_(-1, indices, 0)
        return mask[:, :seq_len_q].unsqueeze(0).contiguous().to(device)

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
        """Run sparse attention over the non-FP8 KV cache.

        Args:
            q: query of shape ``[seq_len, num_heads, head_dim]``.
            kv_c_and_k_pe_cache: KV cache; only its first slice along
                dim 0 is passed to the kernel.
            topk_indices: per-token selected indices, indexed as
                ``[token, k]``.
            attn_metadata: metadata for the current step.

        Returns:
            Tensor of shape ``[seq_len, num_heads, kv_lora_rank]``.
        """
        seq_len_q, num_heads, _ = q.shape

        # NOTE(review): the mask spans seq_len_q x seq_len_q, i.e. indices
        # are compared against the query length rather than the KV length;
        # see the "handle index / kv_cache correctly" TODO in forward().
        index_mask = self._build_topk_index_mask(topk_indices, seq_len_q,
                                                 q.device)

        query = q.transpose(0, 1).contiguous()  # [num_heads, seq_len, head_dim]
        # The kernel output is reshaped below to
        # [seq_len, num_heads, kv_lora_rank].
        output = torch_br.supa_flash_attn_cache_infer(
            query,
            kv_c_and_k_pe_cache[:1],  # [1, num_blocks, block_size, head_size]
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            softmax_scale=self.softmax_scale,
            v_head_size=self.kv_lora_rank,
            mask=index_mask)

        output = output.reshape(seq_len_q, num_heads,
                                self.kv_lora_rank).contiguous()
        return output

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute sparse MLA attention and write the result into ``output``.

        Args:
            layer: the attention layer (not used in this method).
            q: query tensor; may be padded past the actual token count.
            k_c_normed: normalised latent KV ("key" in unified attention).
            k_pe: rotary key part ("value" in unified attention).
            kv_cache: KV cache to update; skipped when it has no elements.
            attn_metadata: metadata for this step, or None.
            output: preallocated output tensor; required.
            output_scale: must be None (fused quantization unsupported).
            output_block_scale: must be None (fused quantization
                unsupported).

        Returns:
            ``output``, zero-filled when ``attn_metadata`` is None.

        Raises:
            NotImplementedError: if an output scale is supplied.
            RuntimeError: if the KV-cache dtype is ``"fp8_ds_mla"``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # TODO: handle index / kv_cache correctly (the request-local top-k
        # indices are not yet converted to global cache indices).

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            # Unpacking also checks that the cache is 4-dimensional.
            _, _, _, head_size = kv_cache.shape
            k_pe_tmp = k_pe.squeeze(1).unsqueeze(0)
            key_supa = torch.cat([k_c_normed, k_pe_tmp], dim=2)
            # The same tensor is passed as both key and value.
            torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
                                                 attn_metadata.slot_mapping,
                                                 head_size)

        if self.kv_cache_dtype == "fp8_ds_mla":
            raise RuntimeError("Not support fp8 on br.")

        attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                         attn_metadata)

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output
|