################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Tuple

import torch
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
    DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
    DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              split_decodes_and_prefills)

logger = init_logger(__name__)


class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """DeepSeek-V3.2 indexer backend variant with SUPA-specific cache shape."""

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped KV-cache shape used on SUPA.

        ``th_gran`` consecutive cache blocks (see
        ``get_kv_cache_usharp_alignment``) are folded into one row, and the
        head and head-size dimensions are flattened together.

        Args:
            num_blocks: Number of KV-cache blocks in the original layout.
            block_size: Number of slots per block.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``
            where ``n_block`` is ``ceil(num_blocks / th_gran)``, at least 1.

        Raises:
            ValueError: If ``block_size`` is not positive, or is so large
                that not even one block fits in the alignment limit.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        if th_gran < 1:
            # Without this guard the ceiling division below would fail with
            # an uninformative ZeroDivisionError.
            raise ValueError(
                f"block_size {block_size} is too large for the usharp "
                "kv cache layout")
        # Ceiling division, but never fewer than one grouped block.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in the 2048 limit.

        Args:
            block_size: Number of slots per KV-cache block.

        Returns:
            ``2048 // block_size``; this is 0 when ``block_size`` exceeds
            2048.

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {block_size}")
        # NOTE(review): presumably a hardware limit on the grouped dimension
        # -- confirm against the SUPA kernel requirements.
        max_h_limit = 2048
        return max_h_limit // block_size


class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Metadata builder for the SUPA DeepSeek-V3.2 indexer backend."""

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build indexer attention metadata for one scheduled batch.

        The batch is split into decode and prefill requests. Prefill
        requests are chunked with ``split_prefill_chunks``; decode requests
        get their per-request query lengths and a padding flag.

        Args:
            common_prefix_len: Not used by this implementation.
            common_attn_metadata: Batch-level attention metadata to derive
                the indexer metadata from.
            fast_build: Not used by this implementation.

        Returns:
            The assembled ``DeepseekV32IndexerMetadata``; its ``prefill`` /
            ``decode`` fields are ``None`` when the batch has no requests of
            that kind.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)

        # Internal invariants: the split must account for every request and
        # every token in the batch.
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            chunk_seq_ids = split_prefill_chunks(
                common_attn_metadata.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    reqs_start, reqs_end, query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor)
                for reqs_start, reqs_end in chunk_seq_ids
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks, )

        decode_metadata = None
        if num_decodes > 0:
            # Per-request query lengths, written into the preallocated
            # buffer rather than a fresh tensor.
            torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                       out=self.decode_lens_buffer[:num_decodes])
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])

            # Use CPU to avoid GPU sync; breaking async scheduling
            requires_padding = (decode_lens_cpu.max()
                                > decode_lens_cpu.min()).item()

            # The paged-MQA scheduler metadata is not computed here; the
            # original call is kept for reference:
            # self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
            #     seq_lens, self.kv_cache_spec.block_size, self.num_sms)
            self.scheduler_metadata_buffer = None
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=common_attn_metadata.
                block_table_tensor[:num_decodes, ...],
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        # if get_tensor_model_parallel_rank() == 0:
        #     logger.info(f"attn_metadata: {attn_metadata}")
        return attn_metadata