[Feature] support deepseek v3/r1/v3.2 (#78)

* [Feature] support deepseek v3/r1/v3.2

* fix gpt_oss

* update readme

* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
baoqian426
2026-01-05 22:55:35 +08:00
committed by GitHub
parent 07bc24a555
commit ee0f50e68f
27 changed files with 5760 additions and 621 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,202 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
(num_sms, 8),
device=self.device,
dtype=torch.int32,
)
self.cg_buf_num_splits = torch.empty(
(vllm_config.scheduler_config.max_num_seqs + 1),
device=self.device,
dtype=torch.int32)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
)
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = \
self.cg_buf_tile_scheduler_metadata[:sm_parts]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
# Num splits needs to monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to monotonically increasing by 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
is_supported, reason = is_flashmla_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
def _forward_decode(
self,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# TODO: (zyongye) decode function for mla here
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
return o, lse

View File

@@ -0,0 +1,752 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
import numpy as np
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata)
from vllm.attention.backends.utils import get_mla_dims
from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
kunlun_flash_mla_with_kvcache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.distributed import get_tp_group
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
# Convert base-2 LSE to natural-log LSE
# Keep FP32 for numerical stability during the merge.
return (lse_base2.to(torch.float32) * math.log(2.0))
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
return FlashMLASparseMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
return FlashMLASparseMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if cache_dtype_str == "fp8_ds_mla":
# custom storage fromat is 656 bytes
# see FlashMLA readme.md for details
return (num_blocks, block_size, 656)
else:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass
class MLASparsePrefillMetadata:
# NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
# the kernel is not from flashmla
block_table: torch.Tensor = None
has_context: bool = False
context_lens: Optional[torch.Tensor] = None
# Sequence lengths (context + query) for prefill requests
# Shape: [num_prefill_reqs]
seq_lens: torch.Tensor = None
# Request ID for each token: -1 for decode tokens, request index
# (0, 1, 2, ...) for prefill tokens.
# Shape: [num_actual_tokens]
request_ids: torch.Tensor = None
query_start_loc: torch.Tensor = None
query_start_loc_cpu: torch.Tensor = None
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
scheduler_metadata: torch.Tensor = None
num_splits: torch.Tensor = None
cache_lens: torch.Tensor = None
prefill_context_lengths: Optional[torch.Tensor] = None
prefill_new_k_start_locs: Optional[torch.Tensor] = None
dummy_block_table: torch.Tensor = None
seq_lens: torch.Tensor = None
seq_lens_cpu: torch.Tensor = None
max_seq_len: int = -1 # needed for reshape in spec decode
def filter_prefill_indices(
self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.prefill_context_lengths is not None
prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
context_indices = torch.where(indices < prefill_context_lengths,
indices, -1)
new_token_indices = torch.where(indices >= prefill_context_lengths,
indices - prefill_context_lengths, -1)
return context_indices, new_token_indices
@dataclass
class FlashMLASparseMetadata:
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
req_id_per_token: torch.Tensor
block_size: int = 64
topk_tokens: int = 2048
num_prefills: int = 0
num_decodes: int = 0
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None
prefill_metadata: Optional[MLASparsePrefillMetadata] = None
@dataclass
class FP8KernelMetadata:
scheduler_metadata: Optional[torch.Tensor]
num_splits: torch.Tensor
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
fp8_extra_metadata: Optional[FP8KernelMetadata] = None
@triton.jit
def _convert_req_index_to_global_index_kernel(
req_id_ptr, # int32 [num_tokens]
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
# shapes (compile-time where possible)
max_num_blocks_per_req: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, # tile width along columns
# strides (in elements)
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
):
# program_id(0) -> token_id (row)
# program_id(1) -> tile index along columns
token_id = tl.program_id(0)
tile_id = tl.program_id(1)
# Each program covers BLOCK_N consecutive columns
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
# Load request id for this token (no mask: grid is exact)
req = tl.load(req_id_ptr + token_id)
# Load token indices for this tile
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
tok = tl.load(ti_ptr) # int32
# Only token == -1 should propagate as -1
is_invalid_tok = tok < 0
# Compute block id and in-block offset
block_id = tok // BLOCK_SIZE
inblock_off = tok % BLOCK_SIZE
# Guard block_table access
valid_block = block_id < max_num_blocks_per_req
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
base = tl.load(bt_ptr, mask=valid_block, other=0)
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
out_val = tl.where(is_invalid_tok | (~valid_block), -1,
base * BLOCK_SIZE + inblock_off)
# Store results
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
tl.store(out_ptr_ij, out_val)
def triton_convert_req_index_to_global_index(
req_id: torch.Tensor, # int32 [num_tokens]
block_table: torch.
Tensor, # int32 [num_requests, max_num_blocks_per_req]
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
BLOCK_N: int = 128, # tile width along columns
):
"""
out[token_id, indice_id] =
block_table[req_id[token_id],
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
+ token_indices[token_id, indice_id] % BLOCK_SIZE
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
"""
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
f"BLOCK_N ({BLOCK_N})"
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
# Ensure contiguous tensors on the same device
req_id_c = req_id.contiguous()
block_table_c = block_table.contiguous()
token_indices_c = token_indices.contiguous()
out = torch.empty_like(token_indices_c)
# Strides in elements
bt_stride0, bt_stride1 = block_table_c.stride()
ti_stride0, ti_stride1 = token_indices_c.stride()
out_stride0, out_stride1 = out.stride()
# Exact 2D grid: tokens × column tiles
grid = (num_tokens, tiles_per_row)
_convert_req_index_to_global_index_kernel[grid](
req_id_c,
block_table_c,
token_indices_c,
out,
# shapes / constexprs
max_num_blocks_per_req,
BLOCK_SIZE,
BLOCK_N,
# strides
bt_stride0,
bt_stride1,
ti_stride0,
ti_stride1,
out_stride0,
out_stride1,
)
return out
def kunlun_convert_req_index_to_global_index(
req_id: torch.Tensor, # int32 [num_tokens]
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
):
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
out = torch.zeros_like(token_indices)
# Compute block_id and inblock_off for all tokens at once
block_id = token_indices // BLOCK_SIZE
inblock_off = token_indices % BLOCK_SIZE
# Create mask for invalid tokens (tok < 0)
invalid_tok_mask = token_indices < 0
# Create mask for out-of-bounds block_id
oob_block_mask = block_id >= max_num_blocks_per_req
# Combine masks - output -1 for either condition
invalid_mask = invalid_tok_mask | oob_block_mask
# Get request IDs expanded to match token_indices shape
req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS)
# Gather base addresses from block_table
# Clamp block_id to avoid index errors (we'll mask these out anyway)
block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1)
# Use advanced indexing to get base addresses
base_addrs = block_table[req_ids_expanded, block_id_clamped]
# Compute the global indices
global_indices = base_addrs * BLOCK_SIZE + inblock_off
# Apply mask: set invalid positions to -1
out = torch.where(invalid_mask, torch.tensor(-1, dtype=torch.int32, device=token_indices.device), global_indices)
return out
def kunlun_concat_and_cache_mla(
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
k_pe: torch.Tensor, #[num_tokens, pe_dim]
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
kv_cache_dtype: str,
scale: torch.Tensor
):
num_tokens = slot_mapping.shape[0]
kv_lora_rank = kv_c.shape[1]
pe_dim = k_pe.shape[1]
block_size = kv_cache.shape[1]
def kunlun_fp8_ds_mla():
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
if slot < 0: continue
block_idx = slot // block_size
block_offset = slot % block_size
kv_c_i = kv_c[token_idx].view(4,kv_lora_rank//4).contiguous()
kv_c_i_int8 = torch.zeros(
kv_c_i.shape,
device=kv_c.device,
dtype=torch.int8,
)
kv_c_i_scale = torch.zeros(
[kv_c_i.shape[0], 1],
device=kv_c.device,
dtype=torch.float32,
)
torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True)
kv_c_i_scale /= 127
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank+16:] = k_pe[token_idx, :].view(torch.uint8).contiguous()
def kunlun_mla():
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
if slot < 0: continue
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous()
kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous()
if (kv_cache_dtype == "fp8_ds_mla"):
assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"
assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"
assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla"
assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla"
assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla"
kunlun_fp8_ds_mla()
else:
assert kv_cache.shape[2] == kv_lora_rank + pe_dim
kunlun_mla()
@dataclass
class FlashMLASparseMetadataBuilder(
AttentionMetadataBuilder[FlashMLASparseMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.layer_names = layer_names
cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
# 从最新版本vllm中引入的
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device)
sm_count = props.multi_processor_count
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
device=device,
dtype=torch.int32)
# self.max_model_len_tensor = torch.tensor(
# [self.model_config.max_model_len],
# device=device,
# dtype=torch.int32)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty((1, 1),
dtype=torch.int32,
device=self.device)
# Equation taken from FlashMLA/csrc/pybind.cpp
h_q, h_k = self.num_heads, 1
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
max_num_sm_parts = int(
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
if current_platform.is_device_capability(100):
max_num_sm_parts *= 2
self.tile_scheduler_metadata_buffer = torch.zeros(
# TileSchedulerMetaDataSize = 8
# see: FlashMLA/csrc/params.h
(max_num_sm_parts, 8),
dtype=torch.int32,
device=device)
self.num_splits_buffer = torch.zeros(
# We pack all the tokens into one batch for sparse attention.
# Otherwise, we can exceed the sm of `get_mla_metadata`.
(
2, ),
dtype=torch.int32,
device=device)
self.req_id_per_token_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_batched_tokens, ),
dtype=torch.int32,
device=device)
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
.copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata = None
if self.use_fp8_kv_cache:
cache_seqlens_cpu, cache_seqlens = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
)
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=None,
num_splits=None,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=cache_seqlens_cpu,
dummy_block_table=self.dummy_block_table)
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
prefill_metadata = None
if num_prefills > 0:
prefill_metadata = MLASparsePrefillMetadata(
query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离所以需要对q进行切分故需调整该值
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes],
)
decode_metadata = None
if num_decodes > 0:
max_seq_len = int(common_attn_metadata.seq_lens_cpu[:num_decodes].max())
decode_metadata = FlashMLASparseDecodeAndContextMetadata(
max_seq_len=max_seq_len,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes],
)
metadata = FlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
num_prefills=num_prefills,
num_decodes=num_decodes,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
decode_metadata=decode_metadata,
prefill_metadata=prefill_metadata
)
return metadata
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
topk_indice_buffer: Optional[torch.Tensor] = None,
indexer: Optional["Indexer"] = None,
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability(
100) else 64
def _forward_bf16_kv(
self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view(
-1, kv_c_and_k_pe_cache.shape[-1])
# num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decodes = attn_metadata.num_decodes
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
def _bf16_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
seq_len = q.shape[1]
# Reshape topk_indices: (num_decode_tokens, topk)
# -> (num_decodes, seq_len, topk)
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
decode_metadata = attn_metadata.decode_metadata
_attn_out, _, _ = kunlun_flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache,
head_dim_v=512,
cache_seqlens=decode_metadata.seq_lens,
cache_seqlens_cpu=decode_metadata.seq_lens_cpu,
is_fp8_kvcache=False,
indices=topk_indices,
softmax_scale=self.softmax_scale,
max_seq_kv=decode_metadata.max_seq_len
)
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
# -> (num_decode_tokens, num_heads, head_dim_v)
return reshape_attn_output_for_spec_decode(_attn_out)
def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
topk_indices = topk_indices.view(num_prefill_tokens, 1, -1)
# NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的
_attn_out = flash_mla_sparse_prefill(
q=q,
kv=kv_c_and_k_pe_cache,
indices=topk_indices,
sm_scale=self.softmax_scale,
q_lod_xpu=prefill_metadata.query_start_loc,
q_lod_cpu=prefill_metadata.query_start_loc_cpu
)[0]
return _attn_out
topk_indices_global = torch.ops.xspeedgate_ops.convert_req_index_to_global_index(
req_id=attn_metadata.req_id_per_token,
block_table=attn_metadata.block_table,
token_indices=topk_indices,
block_size=attn_metadata.block_size,
num_topk_tokens=attn_metadata.topk_tokens,
)
attn_out = torch.empty(
(num_tokens, self.num_heads, self.kv_lora_rank),
dtype=q.dtype,
device=q.device,
)
if has_prefill:
prefill_q = q[num_decode_tokens:]
prefill_topk_indices_global = topk_indices_global[num_decode_tokens:]
attn_out[num_decode_tokens:] = _bf16_prefill(prefill_q, prefill_topk_indices_global)
# 处理decode部分 - 需要正确的block table映射print
if has_decode:
decode_q = q[:num_decode_tokens]
decode_topk_indices_global = topk_indices_global[:num_decode_tokens]
attn_out[:num_decode_tokens] = _bf16_decode(decode_q, decode_topk_indices_global)
return attn_out
def _forward_fp8_kv(self, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
# TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this function.
assert attn_metadata.fp8_extra_metadata is not None
extra_metadata = attn_metadata.fp8_extra_metadata
_attn_out, _ = flash_mla_with_kvcache(
q=q.unsqueeze(0), # unsqueeze to add batch_dim
k_cache=kv_c_and_k_pe_cache,
block_table=extra_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=extra_metadata.cache_lens,
tile_scheduler_metadata=extra_metadata.scheduler_metadata, # None
num_splits=extra_metadata.num_splits, # None
is_fp8_kvcache=True,
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
softmax_scale=self.softmax_scale,
max_seq_kv=attn_metadata.max_seq_len
)
return _attn_out
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
torch.ops._C.concat_and_cache_mla(
kv_c=k_c_normed,
k_pe=k_pe.squeeze(1),
kv_cache=kv_cache,
slot_mapping=attn_metadata.slot_mapping.flatten(),
)
if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
attn_metadata)
else:
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
# attn_metadata)
raise NotImplementedError
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from dataclasses import dataclass
from typing import ClassVar, Optional
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata)
logger = init_logger(__name__)
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
@dataclass
class DeepseekV32IndexerMetadata:
# FIXME (zyongye)
# hacky way to access the data now, need to be in chunked meta
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
num_reqs: int
max_query_len: int
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
head_dim: int
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
self.max_prefill_buffer_size,
num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start, reqs_end, query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks, )
decode_metadata = None
if num_decodes > 0:
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes])
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max()
> decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.
block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build = kunlun_build

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import os
import torch
import torch.nn as nn
from packaging import version
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
def __init__(self, logprobs_mode):
super().__init__()
self.logprobs_mode = logprobs_mode
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_kunlun
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), None
return random_sample(probs, generators), logits_to_return
def forward_kunlun(
self,
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), None
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
if (k is None and p is None) or generators:
if generators:
logger.debug_once(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
@@ -196,6 +199,7 @@ def flashinfer_sample(
probs, top_k=k, deterministic=True)
else:
# Both top-k and top-p.
k = k.to(torch.int32)
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)

View File

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
def __init__(
self,
model_config: ModelConfig,
scheduler_config: SchedulerConfig,
mm_registry: MultiModalRegistry,
) -> None:
super().__init__()
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(
model_config, mm_registry)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
cache=cache)
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config,
cache=cache)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
max_tokens_by_modality,
)
self.encoder_compute_budget = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
max_items_per_prompt_by_modality = dict[str, int]()
max_items_per_batch_by_modality = dict[str, int]()
for modality, max_tokens in max_tokens_by_modality.items():
(
max_items_per_prompt,
max_items_per_batch,
) = self.get_max_items(modality, max_tokens)
max_items_per_prompt_by_modality[modality] = max_items_per_prompt
max_items_per_batch_by_modality[modality] = max_items_per_batch
self.max_tokens_by_modality = max_tokens_by_modality
self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
def get_modality_with_max_tokens(self) -> str:
max_tokens_by_modality = self.max_tokens_by_modality
modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
return modality
def get_encoder_budget(self) -> int:
return min(self.encoder_compute_budget, self.encoder_cache_size)
def get_max_items(
self,
modality: str,
max_tokens_per_item: int,
) -> tuple[int, int]:
if max_tokens_per_item == 0:
return 0, 0
# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = self.get_encoder_budget()
# TODO: handle encoder-decoder models once we support them.
if encoder_budget == 0:
return 0, 0
max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
# Check how many items of this modality can be supported by
# the decoder budget.
mm_limit = self.mm_limits[modality]
max_items_per_prompt = max(
1,
min(mm_limit, self.max_model_len // max_tokens_per_item),
)
scheduler_config = self.scheduler_config
max_num_reqs = self.max_num_reqs
if not scheduler_config.enable_chunked_prefill:
max_num_reqs = min(
max_num_reqs,
scheduler_config.max_num_batched_tokens // max_tokens_per_item,
)
max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
max_items_per_batch = max(
1,
min(max_encoder_items_per_batch, max_decoder_items_per_batch),
)
return max_items_per_prompt, max_items_per_batch
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
@staticmethod
def create_with_metadata_builders(
backend: type[AttentionBackend],
layer_names: list[str],
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
num_metadata_builders: int = 1,
) -> 'AttentionGroup':
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
device)
for _ in range(num_metadata_builders)
]
return AttentionGroup(backend, metadata_builders, layer_names,
kv_cache_spec)
def get_metadata_builder(self,
ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
) -> None:
"""
Perform sanity checks for the result of
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
f"or a single 3D tensor, but got {type(mm_embeddings)} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert len(mm_embeddings) == expected_num_items, (
"Expected number of multimodal embeddings to match number of "
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert all(e.ndim == 2 for e in mm_embeddings), (
"Expected multimodal embeddings to be a sequence of 2D tensors, "
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of [`scatter_mm_placeholders`]
[vllm.v1.worker.utils.scatter_mm_placeholders].
"""
if is_embed is None:
return placeholders
return placeholders[is_embed]
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
for layers that do not allocate its own KV cache, based on the mapping in
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
group, which is needed to ensure that attention metadata is assigned later.
Args:
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
If an Attention layer `layer_name` is in the keys of this dict, it
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
"""
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
num_attn_module: Optional[int] = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if current_platform.is_kunlun() or current_platform.is_cuda() or current_platform.is_xpu():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
num_input_tokens: int) -> bool:
"""Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when sequence
parallelism and tensor parallelism is enabled, and the number of
input tokens is one of the compilation sizes.
"""
if not vllm_config.compilation_config.pass_config.\
enable_sequence_parallelism:
return False
tp = vllm_config.parallel_config.tensor_parallel_size
if tp == 1:
return False
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
# Currently, SP is only enabled for static size fx graphs.
return (num_input_tokens in vllm_config.compilation_config.compile_sizes)