141 lines
5.8 KiB
Python
141 lines
5.8 KiB
Python
|
|
################################################################################
|
||
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
################################################################################
|
||
|
|
|
||
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
from typing import Tuple
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.v1.attention.backends.mla.indexer import (
|
||
|
|
DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
|
||
|
|
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
|
||
|
|
DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
|
||
|
|
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||
|
|
split_decodes_and_prefills)
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """DeepSeek V3.2 indexer backend variant for SUPA.

    Pairs the backend with ``SupaDeepseekV32IndexerMetadataBuilder`` and
    adds static helpers that compute the "usharp" KV-cache shape, in which
    several cache blocks are folded together into one larger row group.
    """

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Compute the 4-D "usharp" KV-cache shape.

        ``th_gran`` blocks (see ``get_kv_cache_usharp_alignment``) are
        grouped together, and the head dimensions are flattened.

        Args:
            num_blocks: Number of KV-cache blocks in the original layout.
            block_size: Number of token slots per block.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``
            where ``n_block`` is ``ceil(num_blocks / th_gran)`` clamped to
            at least 1.

        Raises:
            ZeroDivisionError: If the alignment for ``block_size`` is 0,
                i.e. ``block_size`` is 0 or larger than the 2048 limit
                used by ``get_kv_cache_usharp_alignment``.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division; always keep at least one grouped block.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only formatted when DEBUG
        # logging is actually enabled.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s], "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in the 2048 limit.

        The result is ``2048 // block_size``; it is 0 when ``block_size``
        exceeds 2048.

        Raises:
            ZeroDivisionError: If ``block_size`` is 0.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
|
||
|
|
|
||
|
|
|
||
|
|
class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Metadata builder returned by ``SupaDeepseekV32IndexerBackend``.

    ``build`` splits the batch into decode and prefill requests and
    assembles the indexer metadata for each phase. No paged-MQA scheduler
    metadata is computed here: the decode metadata is created with
    ``schedule_metadata=None``.
    """

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build the indexer attention metadata for one batch.

        Args:
            common_prefix_len: Not used by this implementation.
            common_attn_metadata: Batch-level attention metadata (sequence
                lengths, query start offsets, block table, slot mapping).
            fast_build: Not used by this implementation.

        Returns:
            A ``DeepseekV32IndexerMetadata`` whose ``prefill`` field is
            ``None`` when the batch has no prefill requests and whose
            ``decode`` field is ``None`` when it has no decode requests.
        """
        meta = common_attn_metadata
        total_reqs = meta.num_reqs
        total_tokens = meta.num_actual_tokens
        cpu_query_start_loc = meta.query_start_loc_cpu

        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split_decodes_and_prefills(
             meta, decode_threshold=self.reorder_batch_threshold)

        # The split must account for every request and every token.
        assert num_decodes + num_prefills == total_reqs
        assert num_decode_tokens + num_prefill_tokens == total_tokens

        # ---- Prefill: one metadata entry per chunk of requests. ----
        prefill_metadata = None
        if num_prefills > 0:
            cpu_seq_lens = meta.seq_lens_cpu
            chunk_bounds = split_prefill_chunks(
                cpu_seq_lens,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            block_table = meta.block_table_tensor
            prefill_chunks = [
                self.build_one_prefill_chunk(req_begin, req_end,
                                             cpu_query_start_loc,
                                             cpu_seq_lens, block_table)
                for req_begin, req_end in chunk_bounds
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=prefill_chunks)

        # ---- Decode: the first ``num_decodes`` requests of the batch. ----
        decode_metadata = None
        if num_decodes > 0:
            boundary_count = num_decodes + 1
            # Per-request query lengths, written into the preallocated
            # buffer slice.
            torch.diff(meta.query_start_loc[:boundary_count],
                       out=self.decode_lens_buffer[:num_decodes])
            decode_lens = self.decode_lens_buffer[:num_decodes]

            # Compare lengths on the CPU copy so that no GPU sync is
            # triggered (a sync would break async scheduling).
            cpu_decode_lens = torch.diff(
                meta.query_start_loc_cpu[:boundary_count])
            longest = cpu_decode_lens.max()
            shortest = cpu_decode_lens.min()
            requires_padding = (longest > shortest).item()

            # NOTE: filling the scheduler metadata buffer through
            # get_paged_mqa_logits_metadata(...) is disabled in this
            # builder; the attribute is reset to None instead.
            self.scheduler_metadata_buffer = None
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=meta.block_table_tensor[:num_decodes, ...],
                seq_lens=meta.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=meta.seq_lens,
            num_reqs=meta.num_reqs,
            max_query_len=meta.max_query_len,
            max_seq_len=meta.max_seq_len,
            num_actual_tokens=meta.num_actual_tokens,
            query_start_loc=meta.query_start_loc,
            slot_mapping=meta.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
|