Files
enginex-biren-vllm/vllm_br/v1/attention/backends/mla/indexer.py

141 lines
5.8 KiB
Python
Raw Normal View History

2026-03-10 13:31:25 +08:00
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Tuple
import torch
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
split_decodes_and_prefills)
logger = init_logger(__name__)
class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """Indexer backend that selects the SUPA metadata builder.

    Besides choosing the builder class, it exposes two static helpers that
    describe a regrouped KV-cache layout in which ``th_gran`` consecutive
    cache blocks are folded into a single row group.
    """

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped KV-cache shape for the given cache geometry.

        The result is ``(1, n_block, th_gran * block_size,
        num_kv_heads * head_size)`` where ``th_gran`` comes from
        ``get_kv_cache_usharp_alignment(block_size)`` and ``n_block`` is
        ``ceil(num_blocks / th_gran)``, but never less than 1.

        Args:
            num_blocks: Number of cache blocks in the original layout.
            block_size: Number of token slots per cache block. Must be in
                ``1..2048``; a larger value makes the alignment 0 and this
                function then raises ``ZeroDivisionError``.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            The 4-element shape tuple described above.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division; keep at least one row group even for 0 blocks.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        merged_rows = th_gran * block_size
        merged_cols = num_kv_heads * head_size
        # Lazy %-style arguments: the message is only rendered when DEBUG
        # logging is actually enabled.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            merged_rows,
            merged_cols,
        )
        return (1, n_block, merged_rows, merged_cols)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many cache blocks are grouped into one row group.

        Computed as ``2048 // block_size``: the number of whole blocks that
        fit under a fixed limit of 2048 rows. The result is 0 when
        ``block_size`` exceeds 2048, and ``block_size == 0`` raises
        ``ZeroDivisionError``.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Metadata builder used by ``SupaDeepseekV32IndexerBackend``.

    ``build`` assembles prefill and decode metadata for one batch. The
    paged-MQA schedule metadata is not computed by this builder; the decode
    metadata carries ``schedule_metadata=None``.
    """

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build the indexer attention metadata for the current batch.

        Args:
            common_prefix_len: Not used by this implementation.
            common_attn_metadata: Batch-level attention metadata (query start
                offsets, sequence lengths, block table, slot mapping, ...).
            fast_build: Not used by this implementation.

        Returns:
            A ``DeepseekV32IndexerMetadata`` whose ``prefill`` / ``decode``
            fields are ``None`` when the batch has no requests of that kind.
        """
        cam = common_attn_metadata
        cpu_query_start = cam.query_start_loc_cpu

        split = split_decodes_and_prefills(
            cam, decode_threshold=self.reorder_batch_threshold)
        (num_decodes, num_prefills,
         num_decode_tokens, num_prefill_tokens) = split

        # Sanity checks: the split must account for every request and token.
        assert num_decodes + num_prefills == cam.num_reqs
        assert num_decode_tokens + num_prefill_tokens == cam.num_actual_tokens

        # ---- prefill part -------------------------------------------------
        prefill_metadata = None
        if num_prefills > 0:
            chunk_bounds = split_prefill_chunks(
                cam.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            prefill_chunks = []
            for chunk_begin, chunk_end in chunk_bounds:
                prefill_chunks.append(
                    self.build_one_prefill_chunk(
                        chunk_begin,
                        chunk_end,
                        cpu_query_start,
                        cam.seq_lens_cpu,
                        cam.block_table_tensor,
                    ))
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=prefill_chunks)

        # ---- decode part --------------------------------------------------
        decode_metadata = None
        if num_decodes > 0:
            # Per-request query lengths, written into the preallocated buffer.
            decode_lens = self.decode_lens_buffer[:num_decodes]
            torch.diff(cam.query_start_loc[:num_decodes + 1], out=decode_lens)

            # Do the min/max comparison on the CPU copy so that no device
            # sync is forced (a sync would break async scheduling).
            cpu_lens = torch.diff(cpu_query_start[:num_decodes + 1])
            requires_padding = (cpu_lens.min() < cpu_lens.max()).item()

            # The get_paged_mqa_logits_metadata(...) schedule computation is
            # disabled in this builder; the buffer attribute is reset to None
            # and that None is what the decode metadata receives.
            self.scheduler_metadata_buffer = None

            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=cam.block_table_tensor[:num_decodes, ...],
                seq_lens=cam.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=cam.seq_lens,
            num_reqs=cam.num_reqs,
            max_query_len=cam.max_query_len,
            max_seq_len=cam.max_seq_len,
            num_actual_tokens=cam.num_actual_tokens,
            query_start_loc=cam.query_start_loc,
            slot_mapping=cam.slot_mapping,
            head_dim=128,  # fixed value in this builder
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )