[Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2 * fix gpt_oss * update readme * update readme --------- Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
0
vllm_kunlun/v1/attention/backends/mla/__init__.py
Normal file
0
vllm_kunlun/v1/attention/backends/mla/__init__.py
Normal file
1867
vllm_kunlun/v1/attention/backends/mla/common.py
Normal file
1867
vllm_kunlun/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
202
vllm_kunlun/v1/attention/backends/mla/flashmla.py
Normal file
202
vllm_kunlun/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm_kunlun.ops.attention.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm_kunlun.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
    """Backend descriptor that wires the FlashMLA classes together.

    Only the backend name and the metadata / builder / implementation
    classes are overridden here; everything else is inherited from
    ``MLACommonBackend``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the attention-metadata class used by this backend."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds ``FlashMLAMetadata`` objects."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the class implementing the attention computation."""
        return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the two outputs of ``get_mla_metadata``.

    Both tensors are produced in ``FlashMLAMetadataBuilder._build_decode`` and
    forwarded unchanged to ``flash_mla_with_kvcache``.
    """

    # First value returned by get_mla_metadata().
    tile_scheduler_metadata: torch.Tensor
    # Second value returned by get_mla_metadata().
    num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """``MLACommonMetadata`` parameterised with ``FlashMLADecodeMetadata``.

    Adds no fields of its own.
    """
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, adding the FlashMLA scheduling tensors.

    When the compilation config reports full cudagraphs, the scheduling
    tensors are copied into buffers allocated once in ``__init__`` so that the
    same storage is reused on every call to ``_build_decode``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the common builder and allocate the persistent buffers.

        Args:
            kv_cache_spec: Forwarded to the base-class constructor.
            layer_names: Forwarded to the base-class constructor.
            vllm_config: Source of the model, parallel and scheduler settings.
            device: Forwarded to the base-class constructor.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        # Persistent buffers; they stay None unless full cudagraphs are on.
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        # Queried unconditionally, exactly as before the restyle.
        num_sms = torch.cuda.get_device_properties(
            self.device).multi_processor_count

        if not self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            return

        # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize), where
        # TileSchedulerMetaDataSize = 8.
        self.cg_buf_tile_scheduler_metadata = torch.zeros(
            (num_sms, 8),
            device=self.device,
            dtype=torch.int32,
        )
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.cg_buf_num_splits = torch.empty(max_num_seqs + 1,
                                             device=self.device,
                                             dtype=torch.int32)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
        """Create the decode metadata for the current batch.

        Only ``block_table_tensor`` and ``seq_lens_device`` are read here; the
        remaining parameters are accepted but not used by this override.

        Returns:
            A ``FlashMLADecodeMetadata`` holding the block table, the device
            sequence lengths and the two ``get_mla_metadata`` outputs (views
            into the persistent buffers when full cudagraphs are enabled).
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens_device,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode
        # here so we can only use the persistent buffer if a cudagraph is
        # actually being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            tile_buf = self.cg_buf_tile_scheduler_metadata
            splits_buf = self.cg_buf_num_splits
            assert tile_buf is not None
            assert splits_buf is not None

            # Metadata is per-SM, upper bound on size
            # (<= #SMs, TileMetadataSize).
            sm_parts = tile_scheduler_metadata.size(0)
            assert sm_parts <= tile_buf.size(0)
            tile_view = tile_buf[:sm_parts]
            tile_view.copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = tile_view

            # Num splits is per-batch, varying size (batch_size,); make sure
            # the static buffer is large enough.
            n = num_splits.size(0)
            assert n <= splits_buf.size(0)
            splits_view = splits_buf[:n]
            splits_view.copy_(num_splits)
            # Num splits needs to be monotonically increasing
            # (with https://github.com/vllm-project/FlashMLA/pull/3; otherwise
            # it needs to be monotonically increasing by 1), so the unused
            # tail repeats the last value.
            splits_buf[n:].fill_(num_splits[-1])
            num_splits = splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA implementation whose decode step calls ``flash_mla_with_kvcache``."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Forward all arguments to ``MLACommonImpl`` and validate support.

        Raises:
            AssertionError: if ``is_flashmla_supported()`` reports no support;
                the reason it returns is used as the message.
            NotImplementedError: if ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is truthy, or if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        supported, why_not = is_flashmla_supported()
        assert supported, why_not

        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run decode attention against the paged latent KV cache.

        Args:
            q: Query tensor, or a tuple of tensors that is concatenated along
                the last dimension before use.
            kv_c_and_k_pe_cache: Non-empty KV cache; a head dimension of size
                1 is inserted at position -2 before the kernel call.
            attn_metadata: Must carry a non-``None`` ``decode`` entry.
            layer: Provides ``_q_scale`` and ``_k_scale`` for de-scaling.

        Returns:
            The ``(output, lse)`` pair returned by ``flash_mla_with_kvcache``.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        if type(q) is tuple:
            q = torch.cat(q, dim=-1)
        assert isinstance(q, torch.Tensor)

        out, softmax_lse = flash_mla_with_kvcache(
            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_table,
            cache_seqlens=decode_meta.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.tile_scheduler_metadata,
            num_splits=decode_meta.num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )
        return out, softmax_lse
|
||||
752
vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
Normal file
752
vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,752 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm_kunlun.ops.attention.flashmla import (flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
kunlun_flash_mla_with_kvcache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.distributed import get_tp_group
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
|
||||
# Convert base-2 LSE to natural-log LSE
|
||||
# Keep FP32 for numerical stability during the merge.
|
||||
return (lse_base2.to(torch.float32) * math.log(2.0))
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor for the sparse FlashMLA attention path."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the attention-metadata class used by this backend."""
        return FlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        """Return the class that builds ``FlashMLASparseMetadata`` objects."""
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        """Return the class implementing the attention computation."""
        return FlashMLASparseImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, last_dim)``.

        ``last_dim`` is 656 for the ``"fp8_ds_mla"`` cache dtype and
        ``head_size`` for every other dtype string. ``num_kv_heads`` does not
        influence the result.
        """
        # The custom fp8 storage format is 656 bytes per token;
        # see the FlashMLA readme.md for details.
        last_dim = 656 if cache_dtype_str == "fp8_ds_mla" else head_size
        return (num_blocks, block_size, last_dim)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts (576 only)."""
        return [576]
|
||||
|
||||
|
||||
@dataclass
class MLASparsePrefillMetadata:
    """Per-batch metadata for the prefill part of sparse MLA attention.

    Every field is optional and defaults to ``None`` (or ``False`` for
    ``has_context``), so the tensor fields are annotated as ``Optional``.
    """

    # NOTE(Chen): not called "FlashMLASparsePrefillMetadata" because
    # the kernel is not from flashmla.
    block_table: Optional[torch.Tensor] = None
    has_context: bool = False
    context_lens: Optional[torch.Tensor] = None

    # Sequence lengths (context + query) for prefill requests.
    # Shape: [num_prefill_reqs]
    seq_lens: Optional[torch.Tensor] = None

    # Request ID for each token: -1 for decode tokens, request index
    # (0, 1, 2, ...) for prefill tokens.
    # Shape: [num_actual_tokens]
    request_ids: Optional[torch.Tensor] = None
    # Query start offsets on device and on host — presumably rebased so the
    # first prefill request starts at 0; confirm against the builder.
    query_start_loc: Optional[torch.Tensor] = None
    query_start_loc_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
    """Per-batch metadata for the decode part of sparse MLA attention.

    Every tensor field defaults to ``None`` and is therefore annotated as
    ``Optional``.
    """

    scheduler_metadata: Optional[torch.Tensor] = None
    num_splits: Optional[torch.Tensor] = None
    cache_lens: Optional[torch.Tensor] = None
    prefill_context_lengths: Optional[torch.Tensor] = None
    prefill_new_k_start_locs: Optional[torch.Tensor] = None
    dummy_block_table: Optional[torch.Tensor] = None

    seq_lens: Optional[torch.Tensor] = None
    seq_lens_cpu: Optional[torch.Tensor] = None
    max_seq_len: int = -1  # needed for reshape in spec decode

    def filter_prefill_indices(
            self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``indices`` into context indices and new-token indices.

        ``prefill_context_lengths`` is broadcast against ``indices`` with a
        trailing dimension of 1, so each row of ``indices`` is compared with
        one context length.

        Args:
            indices: Index tensor whose leading dimensions match
                ``prefill_context_lengths``.

        Returns:
            ``(context_indices, new_token_indices)``. Entries below the row's
            context length are kept in the first tensor and set to -1 in the
            second. Entries at or above it are set to -1 in the first tensor
            and rebased by the context length in the second.

        Raises:
            AssertionError: if ``prefill_context_lengths`` is ``None``.
        """
        assert self.prefill_context_lengths is not None
        context_len = self.prefill_context_lengths.unsqueeze(-1)
        in_context = indices < context_len
        context_indices = torch.where(in_context, indices, -1)
        new_token_indices = torch.where(indices >= context_len,
                                        indices - context_len, -1)
        return context_indices, new_token_indices
|
||||
|
||||
|
||||
@dataclass
class FlashMLASparseMetadata:
    """Attention metadata produced by ``FlashMLASparseMetadataBuilder.build``."""

    @dataclass
    class FP8KernelMetadata:
        """Extra kernel inputs that are only built for the fp8 KV cache."""

        scheduler_metadata: Optional[torch.Tensor]
        # NOTE(review): the builder in this file passes None here even though
        # the annotation is not Optional.
        num_splits: torch.Tensor
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    # Request index of every token in the batch.
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    # Decode / prefill split of the batch.
    num_prefills: int = 0
    num_decodes: int = 0
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0

    decode_metadata: Optional[FlashMLASparseDecodeAndContextMetadata] = None
    prefill_metadata: Optional[MLASparsePrefillMetadata] = None

    fp8_extra_metadata: Optional[FP8KernelMetadata] = None
|
||||
|
||||
|
||||
@triton.jit
def _convert_req_index_to_global_index_kernel(
        req_id_ptr,  # int32 [num_tokens]
        block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
        token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        # shapes (compile-time where possible)
        max_num_blocks_per_req: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        BLOCK_N: tl.constexpr,  # tile width along columns
        # strides (in elements)
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
):
    # Translate per-request token indices into global cache slot indices.
    # Grid axis 0 selects the token (row); grid axis 1 selects a tile of
    # BLOCK_N consecutive columns within that row.
    row = tl.program_id(0)
    col_tile = tl.program_id(1)
    cols = col_tile * BLOCK_N + tl.arange(0, BLOCK_N)

    # Request id of this token (no mask: the grid is exact).
    req = tl.load(req_id_ptr + row)

    # Token indices covered by this tile (int32).
    src_ptrs = token_indices_ptr + row * ti_stride0 + cols * ti_stride1
    tok = tl.load(src_ptrs)

    # Negative token indices are propagated as -1.
    is_invalid_tok = tok < 0

    # Split each token index into a block number and an in-block offset.
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard the block_table access against block ids past the last column.
    valid_block = block_id < max_num_blocks_per_req
    bt_ptrs = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    base = tl.load(bt_ptrs, mask=valid_block, other=0)

    # -1 for an invalid token or an out-of-bounds block id,
    # otherwise base * BLOCK_SIZE + offset.
    out_val = tl.where(is_invalid_tok | (~valid_block), -1,
                       base * BLOCK_SIZE + inblock_off)

    dst_ptrs = out_ptr + row * out_stride0 + cols * out_stride1
    tl.store(dst_ptrs, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
        req_id: torch.Tensor,  # int32 [num_tokens]
        block_table: torch.
    Tensor,  # int32 [num_requests, max_num_blocks_per_req]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        BLOCK_SIZE: int = 64,
        NUM_TOPK_TOKENS: int = 2048,
        BLOCK_N: int = 128,  # tile width along columns
):
    """Map per-request token indices to global KV-cache slot indices.

    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Negative entries of ``token_indices`` produce -1. For safety, -1 is also
    produced when the derived block id would be out of bounds.

    Args:
        req_id: int32 request index per token.
        block_table: int32 2-D block table.
        token_indices: int32 tensor with ``NUM_TOPK_TOKENS`` columns.
        BLOCK_SIZE: Number of tokens per cache block.
        NUM_TOPK_TOKENS: Expected width of ``token_indices``.
        BLOCK_N: Columns handled per kernel program; must divide
            ``NUM_TOPK_TOKENS``.

    Returns:
        A new int32 tensor shaped like ``token_indices``.

    Raises:
        AssertionError: on a wrong dtype, a wrong column count, or when
            ``NUM_TOPK_TOKENS`` is not a multiple of ``BLOCK_N``.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    # The two message fragments are joined with an explicit space; without
    # it the text read "divisible byBLOCK_N".
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by " \
        f"BLOCK_N ({BLOCK_N})"

    num_tokens = req_id.shape[0]
    # The unpack also enforces that block_table is 2-D.
    _, max_num_blocks_per_req = block_table.shape
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Exact 2D grid: tokens x column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out
|
||||
|
||||
def kunlun_convert_req_index_to_global_index(
        req_id: torch.Tensor,  # int32 [num_tokens]
        block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        BLOCK_SIZE: int = 64,
        NUM_TOPK_TOKENS: int = 2048,
):
    """Map per-request token indices to global KV-cache slot indices.

    Pure-PyTorch computation of
    ``block_table[req, tok // BLOCK_SIZE] * BLOCK_SIZE + tok % BLOCK_SIZE``
    for every entry ``tok`` of ``token_indices``, where ``req`` is the request
    id of the entry's row.

    Args:
        req_id: int32 request index per token.
        block_table: int32 2-D block table.
        token_indices: int32 tensor with ``NUM_TOPK_TOKENS`` columns.
        BLOCK_SIZE: Number of tokens per cache block.
        NUM_TOPK_TOKENS: Expected width of ``token_indices``.

    Returns:
        A new int32 tensor shaped like ``token_indices``. Entries are -1 where
        the token index is negative or its block id is past the last column
        of ``block_table``.

    Raises:
        AssertionError: on a wrong dtype or a wrong column count.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS

    # The unpack also enforces that block_table is 2-D.
    _, max_num_blocks_per_req = block_table.shape

    # Block number and in-block offset of every token index at once.
    block_id = token_indices // BLOCK_SIZE
    inblock_off = token_indices % BLOCK_SIZE

    # Positions that must yield -1: negative token or out-of-bounds block.
    invalid_mask = (token_indices < 0) | (block_id >= max_num_blocks_per_req)

    # Row-wise request ids, broadcast to the shape of token_indices.
    req_ids_expanded = req_id.unsqueeze(1).expand(-1, NUM_TOPK_TOKENS)

    # Clamp so the gather never indexes out of range; the clamped positions
    # are exactly the ones overwritten with -1 below.
    block_id_clamped = torch.clamp(block_id, 0, max_num_blocks_per_req - 1)
    base_addrs = block_table[req_ids_expanded, block_id_clamped]

    global_indices = base_addrs * BLOCK_SIZE + inblock_off

    # masked_fill avoids building a scalar tensor on every call and keeps the
    # int32 dtype of global_indices.
    return global_indices.masked_fill(invalid_mask, -1)
|
||||
|
||||
def kunlun_concat_and_cache_mla(
        kv_c: torch.Tensor,  # [num_tokens, kv_lora_rank]
        k_pe: torch.Tensor,  # [num_tokens, pe_dim]
        kv_cache: torch.Tensor,  # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
        slot_mapping: torch.Tensor,  # [num_tokens] or [num_actual_tokens]
        kv_cache_dtype: str,
        scale: torch.Tensor
):
    """Write ``kv_c`` and ``k_pe`` of each token into its paged cache slot.

    Token ``i`` is stored at block ``slot // block_size``, offset
    ``slot % block_size`` with ``slot = slot_mapping[i]``; tokens whose slot
    is negative are skipped. Only the first ``len(slot_mapping)`` rows of
    ``kv_c`` / ``k_pe`` are read.

    For ``kv_cache_dtype == "fp8_ds_mla"`` each entry is laid out as
    quantised ``kv_c`` bytes (4 groups, one float32 scale per group), then the
    16 scale bytes, then the raw bytes of ``k_pe``. For any other dtype string
    the entry is ``kv_c`` followed by ``k_pe``.

    Args:
        kv_c: Compressed KV values, one row per token.
        k_pe: Positional key part, one row per token.
        kv_cache: Destination cache, modified in place.
        slot_mapping: Flat destination slot per token.
        kv_cache_dtype: Cache layout selector.
        scale: Accepted for interface compatibility; not used here.

    Raises:
        AssertionError: if the tensor sizes do not match the selected layout.
    """
    kv_lora_rank = kv_c.shape[1]
    pe_dim = k_pe.shape[1]
    block_size = kv_cache.shape[1]

    use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
    if use_fp8_ds_mla:
        assert kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"
        assert pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"
        assert kv_cache.shape[2] == 656 // kv_cache.element_size(), "kv_cache.shape[2] must be 656 bytes for fp8_ds_mla"
        assert kv_c.element_size() == 2, "kv_c.element_size() must be 2 for fp8_ds_mla"
        assert k_pe.element_size() == 2, "k_pe.element_size() must be 2 for fp8_ds_mla"
    else:
        assert kv_cache.shape[2] == kv_lora_rank + pe_dim

    # Fetch all slots with a single host transfer instead of one blocking
    # `.item()` call per token.
    slots = slot_mapping.tolist()

    for token_idx, slot in enumerate(slots):
        if slot < 0:
            continue
        block_idx = slot // block_size
        block_offset = slot % block_size

        if not use_fp8_ds_mla:
            kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx, :].contiguous()
            kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx, :].contiguous()
            continue

        # Quantise the token in 4 groups, one float32 scale per group.
        kv_c_i = kv_c[token_idx].view(4, kv_lora_rank // 4).contiguous()
        kv_c_i_int8 = torch.zeros(
            kv_c_i.shape,
            device=kv_c.device,
            dtype=torch.int8,
        )
        kv_c_i_scale = torch.zeros(
            [kv_c_i.shape[0], 1],
            device=kv_c.device,
            dtype=torch.float32,
        )
        torch.ops._C.quant2d(kv_c_i, kv_c_i_int8, kv_c_i_scale, force_sdnn=True)
        # presumably quant2d writes the per-group max magnitude, which the
        # division turns into a per-unit int8 scale — TODO confirm.
        kv_c_i_scale /= 127
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_i_int8.view(-1).view(torch.uint8).contiguous()
        kv_cache[block_idx, block_offset, kv_lora_rank:kv_lora_rank + 16] = kv_c_i_scale.view(-1).view(torch.uint8).contiguous()
        kv_cache[block_idx, block_offset, kv_lora_rank + 16:] = k_pe[token_idx, :].view(torch.uint8).contiguous()
|
||||
|
||||
|
||||
@dataclass
class FlashMLASparseMetadataBuilder(
        AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds ``FlashMLASparseMetadata`` for every scheduled batch."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Store the configuration and allocate the reusable buffers.

        Args:
            kv_cache_spec: KV-cache spec; its ``block_size`` is read in
                ``build``.
            layer_names: Names of the attention layers served by the builder.
            vllm_config: Source of the model, cache, parallel and scheduler
                settings.
            device: Device on which the buffers are allocated.
        """
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        self.device = device
        cache_config = vllm_config.cache_config
        parallel_config = vllm_config.parallel_config

        # Treat requests with query length <= 1 as decodes to match the
        # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports
        # next_n <= 2). This call was brought in from the latest vLLM version.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        sm_count = torch.cuda.get_device_properties(
            device).multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"

        self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
                                               device=device,
                                               dtype=torch.int32)

        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty((1, 1),
                                             dtype=torch.int32,
                                             device=self.device)

        # Equation taken from FlashMLA/csrc/pybind.cpp
        h_q, h_k = self.num_heads, 1
        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
        max_num_sm_parts = int(
            max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
        if current_platform.is_device_capability(100):
            max_num_sm_parts *= 2

        # TileSchedulerMetaDataSize = 8, see: FlashMLA/csrc/params.h
        self.tile_scheduler_metadata_buffer = torch.zeros(
            (max_num_sm_parts, 8), dtype=torch.int32, device=device)
        # We pack all the tokens into one batch for sparse attention.
        # Otherwise, we can exceed the sm of `get_mla_metadata`.
        self.num_splits_buffer = torch.zeros((2, ),
                                             dtype=torch.int32,
                                             device=device)
        self.req_id_per_token_buffer = torch.zeros(
            (vllm_config.scheduler_config.max_num_batched_tokens, ),
            dtype=torch.int32,
            device=device)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashMLASparseMetadata:
        """Assemble the sparse-MLA metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted but not used.

        Args:
            common_prefix_len: Unused.
            common_attn_metadata: Batch description shared by all backends.
            fast_build: Unused.

        Returns:
            A ``FlashMLASparseMetadata``. ``decode_metadata`` /
            ``prefill_metadata`` are ``None`` when the batch has no decode /
            prefill requests, and ``fp8_extra_metadata`` is ``None`` unless
            the fp8 KV cache is in use.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Expand the per-request query boundaries into one request index per
        # token.
        query_starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                                  dtype=np.int32)
        tokens_per_req = np.diff(query_starts)
        req_ids_np = np.repeat(
            np.arange(tokens_per_req.shape[0], dtype=np.int32),
            tokens_per_req)
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[:req_ids_np.shape[0]].copy_(
            torch.from_numpy(req_ids_np), non_blocking=True)
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        fp8_extra_metadata = None
        if self.use_fp8_kv_cache:
            cache_seqlens_cpu, _ = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor, )
            fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=None,
                num_splits=None,
                # cache_lens and block_table are basically unused in sparse
                # case but the decode kernel will treat -1 and indices >=
                # cache_lens as invalid so we make sure cache_lens is large
                # enough to not accidentally mark indices invalid, we will
                # use -1 exclusively to mark invalid indices
                cache_lens=cache_seqlens_cpu,
                dummy_block_table=self.dummy_block_table)

        split = split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold or 1,
            require_uniform=True,
        )
        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split

        prefill_metadata = None
        if num_prefills > 0:
            # Prefill and decode requests are handled separately, so the
            # query is split and the prefill offsets are rebased to start
            # at the first prefill request.
            qsl = common_attn_metadata.query_start_loc
            qsl_cpu = common_attn_metadata.query_start_loc_cpu
            prefill_metadata = MLASparsePrefillMetadata(
                query_start_loc=qsl[num_decodes:] - qsl[num_decodes],
                query_start_loc_cpu=(qsl_cpu[num_decodes:] -
                                     qsl_cpu[num_decodes]),
            )

        decode_metadata = None
        if num_decodes > 0:
            decode_seq_lens_cpu = common_attn_metadata.seq_lens_cpu
            max_seq_len = int(decode_seq_lens_cpu[:num_decodes].max())
            decode_metadata = FlashMLASparseDecodeAndContextMetadata(
                max_seq_len=max_seq_len,
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                seq_lens_cpu=common_attn_metadata.seq_lens_cpu[:num_decodes],
            )

        return FlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            decode_metadata=decode_metadata,
            prefill_metadata=prefill_metadata,
        )
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
    """Sparse (top-k) MLA attention implementation using the Kunlun kernels.

    ``forward`` absorbs ``W_UK_T`` into the query, writes the new latent and
    rope entries into the KV cache, runs attention restricted to the top-k
    token indices read from the indexer's buffer, and finally applies the
    value up-projection into ``output``.

    Only non-fp8 KV caches are supported at the moment; the ``fp8_ds_mla``
    cache dtype is rejected in ``forward``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        """Initialise the common MLA state and bind the indexer's buffer.

        Args:
            num_heads, head_size, scale, num_kv_heads, alibi_slopes,
            sliding_window, kv_cache_dtype, logits_soft_cap, attn_type,
            kv_sharing_target_layer_name: forwarded unchanged to the base
                class constructor.
            topk_indice_buffer: accepted for interface compatibility; it is
                not read here, the buffer is taken from ``indexer`` instead.
            indexer: must not be ``None``; its ``topk_indices_buffer`` is
                the source of the per-token top-k indices used in
                ``forward``.
            **mla_args: extra MLA arguments forwarded to the base class.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        # 128 when the platform reports device capability 100, 64 otherwise.
        # NOTE(review): not read anywhere inside this class.
        self.padding = 128 if current_platform.is_device_capability(
            100) else 64

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
        """Run sparse attention against a non-fp8 KV cache.

        The batch is laid out as decode tokens first, prefill tokens after
        them; each part is dispatched to its own kernel and the results are
        written into one output tensor.

        Args:
            q: query with the token dimension first.
            kv_c_and_k_pe_cache: KV cache; it is flattened to two
                dimensions (keeping the last one) before use.
            topk_indices: per-token, request-local top-k indices.
            attn_metadata: metadata carrying the decode/prefill split, the
                block table and the per-phase metadata.

        Returns:
            Tensor of shape ``(num_tokens, num_heads, kv_lora_rank)``.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.contiguous().view(
            -1, kv_c_and_k_pe_cache.shape[-1])

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens

        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0

        def _bf16_decode(q: torch.Tensor,
                         topk_indices: torch.Tensor) -> torch.Tensor:
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            decode_metadata = attn_metadata.decode_metadata
            _attn_out, _, _ = kunlun_flash_mla_with_kvcache(
                q=q,
                k_cache=kv_c_and_k_pe_cache,
                head_dim_v=512,
                cache_seqlens=decode_metadata.seq_lens,
                cache_seqlens_cpu=decode_metadata.seq_lens_cpu,
                is_fp8_kvcache=False,
                indices=topk_indices,
                softmax_scale=self.softmax_scale,
                max_seq_kv=decode_metadata.max_seq_len
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(_attn_out)

        def _bf16_prefill(q: torch.Tensor,
                          topk_indices: torch.Tensor) -> torch.Tensor:
            prefill_metadata = attn_metadata.prefill_metadata
            topk_indices = topk_indices.view(num_prefill_tokens, 1, -1)
            # NOTE: only the prefill-phase query_start_loc matches what the
            # klx operator expects, hence prefill_metadata is used here
            # rather than the top-level attn_metadata.
            _attn_out = flash_mla_sparse_prefill(
                q=q,
                kv=kv_c_and_k_pe_cache,
                indices=topk_indices,
                sm_scale=self.softmax_scale,
                q_lod_xpu=prefill_metadata.query_start_loc,
                q_lod_cpu=prefill_metadata.query_start_loc_cpu
            )[0]
            return _attn_out

        # Translate request-local top-k indices into global indices using
        # the block table, once for the whole batch.
        topk_indices_global = \
            torch.ops.xspeedgate_ops.convert_req_index_to_global_index(
                req_id=attn_metadata.req_id_per_token,
                block_table=attn_metadata.block_table,
                token_indices=topk_indices,
                block_size=attn_metadata.block_size,
                num_topk_tokens=attn_metadata.topk_tokens,
            )

        attn_out = torch.empty(
            (num_tokens, self.num_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )

        if has_prefill:
            prefill_q = q[num_decode_tokens:]
            prefill_topk_indices_global = topk_indices_global[
                num_decode_tokens:]
            attn_out[num_decode_tokens:] = _bf16_prefill(
                prefill_q, prefill_topk_indices_global)

        # Decode part: uses the globally mapped indices computed above.
        if has_decode:
            decode_q = q[:num_decode_tokens]
            decode_topk_indices_global = topk_indices_global[
                :num_decode_tokens]
            attn_out[:num_decode_tokens] = _bf16_decode(
                decode_q, decode_topk_indices_global)

        return attn_out

    def _forward_fp8_kv(self, q: torch.Tensor,
                        kv_c_and_k_pe_cache: torch.Tensor,
                        topk_indices: torch.Tensor,
                        attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
        """Run sparse attention against an fp8 KV cache.

        Not called by ``forward`` at the moment, which rejects the fp8 cache
        dtype.
        """
        # TODO: When fwd_kvcache_mla supports uint8 kv cache, execute this
        # function.
        assert attn_metadata.fp8_extra_metadata is not None
        extra_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = flash_mla_with_kvcache(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim
            k_cache=kv_c_and_k_pe_cache,
            block_table=extra_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=extra_metadata.cache_lens,
            tile_scheduler_metadata=extra_metadata.scheduler_metadata,  # None
            num_splits=extra_metadata.num_splits,  # None
            is_fp8_kvcache=True,
            indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
            softmax_scale=self.softmax_scale,
            max_seq_kv=attn_metadata.max_seq_len
        )

        return _attn_out

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute sparse MLA attention and write the result into ``output``.

        Args:
            layer: the attention layer (not read by this method).
            q: query; split on the last dim into nope and rope parts.
            k_c_normed: latent KV to be written to the cache.
            k_pe: rope key to be written to the cache.
            kv_cache: KV cache tensor; the write is skipped when it is empty.
            attn_metadata: batch metadata, or ``None``, in which case
                ``output`` is zero-filled and returned.
            output: pre-allocated output tensor (required).
            output_scale, output_block_scale: must both be ``None``.

        Returns:
            ``output`` with its first ``num_actual_tokens`` rows filled.

        Raises:
            NotImplementedError: if fused output quantization is requested,
                or if the KV cache dtype is ``fp8_ds_mla``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        # Fail fast: reject the unsupported cache dtype before doing any
        # projection work or mutating the KV cache.
        if self.kv_cache_dtype == "fp8_ds_mla":
            raise NotImplementedError(
                "FlashMLASparseImpl does not support the 'fp8_ds_mla' "
                "KV cache dtype yet")

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            torch.ops._C.concat_and_cache_mla(
                kv_c=k_c_normed,
                k_pe=k_pe.squeeze(1),
                kv_cache=kv_cache,
                slot_mapping=attn_metadata.slot_mapping.flatten(),
            )

        attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                         attn_metadata)

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output
|
||||
133
vllm_kunlun/v1/attention/backends/mla/indexer.py
Normal file
133
vllm_kunlun/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    """Decode-phase metadata bundle for the DeepSeek V3.2 indexer.

    Instances are populated by the metadata builder with slices covering
    only the decode requests of a batch.
    """

    # Block-table rows of the decode requests.
    block_table: torch.Tensor
    # Sequence lengths of the decode requests.
    seq_lens: torch.Tensor
    # Host-side copy of ``seq_lens``.
    seq_lens_cpu: torch.Tensor
    # Number of query tokens per decode request.
    decode_lens: torch.Tensor
    # True when decode requests do not all have the same query length.
    requires_padding: bool
    # Scheduler metadata buffer owned by the builder.
    schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class DeepseekV32IndexerMetadata:
    """Per-batch attention metadata for the DeepSeek V3.2 indexer."""

    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    # Number of tokens excluding padding.
    num_actual_tokens: int
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Phase-specific metadata; each stays None unless explicitly provided.
    decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
    prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
def kunlun_build(self,
                 common_prefix_len: int,
                 common_attn_metadata: CommonAttentionMetadata,
                 fast_build: bool = False) -> DeepseekV32IndexerMetadata:
    """Build the indexer attention metadata for one batch.

    Replacement for ``DeepseekV32IndexerMetadataBuilder.build`` (installed
    by the assignment below this function). The batch is split into decode
    and prefill requests; decode requests come first.

    Args:
        self: the metadata builder instance this function is bound to.
        common_prefix_len: not used by this implementation.
        common_attn_metadata: backend-independent metadata of the batch.
        fast_build: not used by this implementation.

    Returns:
        A ``DeepseekV32IndexerMetadata`` whose ``prefill`` / ``decode``
        members are set only when the batch contains such requests.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    # Reuse the host-side sequence lengths everywhere instead of calling
    # ``.cpu()`` on the device tensor, which would force a device sync.
    seq_lens_cpu = common_attn_metadata.seq_lens_cpu

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold)

    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    prefill_metadata = None
    if num_prefills > 0:
        chunk_seq_ids = split_prefill_chunks(
            seq_lens_cpu,
            self.max_prefill_buffer_size,
            num_decodes,
        )
        chunks = [
            self.build_one_prefill_chunk(
                reqs_start, reqs_end, query_start_loc_cpu,
                seq_lens_cpu,
                common_attn_metadata.block_table_tensor)
            for reqs_start, reqs_end in chunk_seq_ids
        ]
        prefill_metadata = DeepseekV32IndexerPrefillMetadata(
            chunks=chunks, )

    decode_metadata = None
    if num_decodes > 0:
        # Per-request decode lengths, written into the builder's buffer.
        torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                   out=self.decode_lens_buffer[:num_decodes])
        decode_lens = self.decode_lens_buffer[:num_decodes]
        decode_lens_cpu = torch.diff(query_start_loc_cpu[:num_decodes + 1])

        # Use CPU to avoid GPU sync; breaking async scheduling
        requires_padding = (decode_lens_cpu.max()
                            > decode_lens_cpu.min()).item()

        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.
            block_table_tensor[:num_decodes, ...],
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
            seq_lens_cpu=seq_lens_cpu[:num_decodes],
            decode_lens=decode_lens,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
        )

    attn_metadata = DeepseekV32IndexerMetadata(
        seq_lens=common_attn_metadata.seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        max_query_len=common_attn_metadata.max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        query_start_loc=common_attn_metadata.query_start_loc,
        slot_mapping=common_attn_metadata.slot_mapping,
        head_dim=128,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        prefill=prefill_metadata,
        decode=decode_metadata,
    )

    return attn_metadata


# Install the Kunlun-specific build on the upstream builder class.
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
@@ -24,6 +24,7 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
def __init__(self, logprobs_mode):
|
||||
super().__init__()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
logger.info_once(
|
||||
"Using FlashInfer for top-p & top-k sampling.")
|
||||
self.forward = self.forward_kunlun
|
||||
@@ -40,9 +41,14 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), None
|
||||
return random_sample(probs, generators), logits_to_return
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
@@ -52,16 +58,13 @@ class TopKTopPSampler(nn.Module):
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
if k is None and p is None:
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), None
|
||||
if generators:
|
||||
logger.warning_once("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
if (k is None and p is None) or generators:
|
||||
if generators:
|
||||
logger.debug_once(
|
||||
"FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
@@ -196,6 +199,7 @@ def flashinfer_sample(
|
||||
probs, top_k=k, deterministic=True)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
k = k.to(torch.int32)
|
||||
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
|
||||
probs, top_k=k, top_p=p, deterministic=True)
|
||||
|
||||
|
||||
344
vllm_kunlun/v1/worker/utils.py
Normal file
344
vllm_kunlun/v1/worker/utils.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
class MultiModalBudget:
    """Computes and stores per-modality item budgets for multi-modal models."""

    def __init__(
        self,
        model_config: ModelConfig,
        scheduler_config: SchedulerConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        """Query the registry and derive the per-modality item limits."""
        super().__init__()

        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.mm_registry = mm_registry

        cache = processor_only_cache_from_config(model_config, mm_registry)
        self.cache = cache

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
                                                              cache=cache)

        tokens_per_item = \
            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                model_config, cache=cache)

        (
            self.encoder_compute_budget,
            self.encoder_cache_size,
        ) = compute_mm_encoder_budget(scheduler_config, tokens_per_item)

        # get_max_items() relies on the budgets and limits assigned above.
        per_prompt: dict[str, int] = {}
        per_batch: dict[str, int] = {}
        for modality, max_tokens in tokens_per_item.items():
            per_prompt[modality], per_batch[modality] = self.get_max_items(
                modality, max_tokens)

        self.max_tokens_by_modality = tokens_per_item
        self.max_items_per_prompt_by_modality = per_prompt
        self.max_items_per_batch_by_modality = per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item needs the most tokens."""
        tokens_per_item = self.max_tokens_by_modality
        return max(tokens_per_item, key=tokens_per_item.__getitem__)

    def get_encoder_budget(self) -> int:
        """Return the smaller of the encoder compute budget and cache size."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max_items_per_prompt, max_items_per_batch)``.

        Both values are 0 when an item needs no tokens or when the encoder
        budget is 0; otherwise both are at least 1.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        # How many items of this modality the encoder budget can hold.
        budget = self.get_encoder_budget()

        # TODO: handle encoder-decoder models once we support them.
        if budget == 0:
            return 0, 0

        encoder_items = budget // max_tokens_per_item

        # How many items of this modality the decoder budget can hold.
        prompt_cap = min(self.mm_limits[modality],
                         self.max_model_len // max_tokens_per_item)
        max_items_per_prompt = max(1, prompt_cap)

        sched = self.scheduler_config
        num_reqs = self.max_num_reqs
        if not sched.enable_chunked_prefill:
            num_reqs = min(
                num_reqs,
                sched.max_num_batched_tokens // max_tokens_per_item,
            )

        decoder_items = num_reqs * max_items_per_prompt
        max_items_per_batch = max(1, min(encoder_items, decoder_items))

        return max_items_per_prompt, max_items_per_batch
|
||||
|
||||
|
||||
@dataclass
class AttentionGroup:
    """Layers sharing one attention backend, KV cache spec and builders."""

    backend: type[AttentionBackend]
    # When ubatching is enabled there is one metadata builder per ubatch, so
    # builders that keep internal persistent buffers for cudagraphs do not
    # have to worry about conflicting with the other ubatches.
    metadata_builders: list[AttentionMetadataBuilder]
    layer_names: list[str]
    kv_cache_spec: KVCacheSpec

    @staticmethod
    def create_with_metadata_builders(
        backend: type[AttentionBackend],
        layer_names: list[str],
        kv_cache_spec: KVCacheSpec,
        vllm_config: VllmConfig,
        device: torch.device,
        num_metadata_builders: int = 1,
    ) -> 'AttentionGroup':
        """Create a group with ``num_metadata_builders`` fresh builders.

        Each builder is constructed from the backend's builder class with
        ``(kv_cache_spec, layer_names, vllm_config, device)``.
        """
        builders = []
        for _ in range(num_metadata_builders):
            builder_cls = backend.get_builder_cls()
            builders.append(
                builder_cls(kv_cache_spec, layer_names, vllm_config, device))
        return AttentionGroup(backend, builders, layer_names, kv_cache_spec)

    def get_metadata_builder(self,
                             ubatch_id: int = 0) -> AttentionMetadataBuilder:
        """Return the metadata builder assigned to ubatch ``ubatch_id``."""
        assert len(self.metadata_builders) > ubatch_id
        return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,
    expected_num_items: int,
) -> None:
    """
    Validate the value returned by
    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].

    Asserts that `mm_embeddings` is a list, tuple or tensor, that it holds
    exactly `expected_num_items` entries, and that every entry is 2D.
    """
    # Shared tail of all three failure messages.
    blame = ("instead. This is most likely due to incorrect implementation "
             "of the model's `get_multimodal_embeddings` method.")

    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} " + blame)

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        + blame)

    assert all(item.ndim == 2 for item in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        + blame)
|
||||
|
||||
|
||||
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Spread the multimodal embeddings over a tensor with one row per
    placeholder token; rows not selected by `is_embed` are filled with NaN.

    See [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask over the placeholder positions marking the
            rows that receive an embedding. Its first dimension gives the
            number of placeholder rows. When `None`, `embeds` is returned
            unchanged.
    """
    if is_embed is not None:
        num_rows = is_embed.shape[0]
        width = embeds.shape[-1]
        scattered = embeds.new_full((num_rows, width), fill_value=torch.nan)
        scattered[is_embed] = embeds
        return scattered

    return embeds
|
||||
|
||||
|
||||
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Recover the embeddings from the placeholder tokens.

    This is the inverse operation of [`scatter_mm_placeholders`]
    [vllm.v1.worker.utils.scatter_mm_placeholders]: the rows selected by
    `is_embed` are returned, or `placeholders` itself when `is_embed` is
    `None`.
    """
    return placeholders if is_embed is None else placeholders[is_embed]
|
||||
|
||||
|
||||
def add_kv_sharing_layers_to_kv_cache_groups(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    """
    Register layers that reuse another layer's KV cache with the KV cache
    group of the layer they share from, so that attention metadata is
    assigned to them later.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model. The
            `layer_names` list of the matching group is extended in place.
        runner_only_attn_layers: Optional set that, when given, receives the
            name of every sharing layer.
    """
    # Lookup table built once, before any group is modified.
    group_by_layer: dict[str, KVCacheGroupSpec] = {
        name: group
        for group in kv_cache_groups
        for name in group.layer_names
    }

    for sharing_layer, source_layer in shared_kv_cache_layers.items():
        group_by_layer[source_layer].layer_names.append(sharing_layer)

        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(sharing_layer)
|
||||
|
||||
|
||||
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Attach the allocated KV caches to the ModelRunner and to the forward
    context so the forward pass can use them.

    Two things happen:
      1) `runner_kv_caches` (which must be empty on entry) is filled with
         one KV cache per layer index, in ascending layer-index order.
      2) Every attention layer in `forward_context` gets its KV cache from
         `kv_caches` assigned to its `kv_cache` attribute.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
        num_attn_module: Passed through to `extract_layer_index`.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Group layer names by layer index so the runner list can be ordered.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        idx = extract_layer_index(name, num_attn_module)
        names_by_index[idx].append(name)

    for idx in sorted(names_by_index.keys()):
        names = names_by_index[idx]
        if len(names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.

            # TODO - analyze where runner_kv_caches is used and the right
            # way to ensure it properly reflects multiple attention layers
            # in the same decoder block.

            # We know that the GPU runner is not impacted by this
            # case. Some test code depends on runner_kv_caches, but
            # not in a way that's impacted by ignoring this.
            tolerated = (current_platform.is_kunlun()
                         or current_platform.is_cuda()
                         or current_platform.is_xpu())
            if not tolerated:
                raise NotImplementedError
        # Only the first layer of each index is exposed to the runner.
        runner_kv_caches.append(kv_caches[names[0]])

    # Bind kv_caches to forward context
    for name, cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[name].kv_cache = [cache]
|
||||
|
||||
|
||||
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
                                 num_input_tokens: int) -> bool:
    """Tell whether the residual tensor is scattered for sequence parallelism.

    This is the case only when sequence parallelism is enabled in the
    compilation pass config, the tensor parallel size is not 1, and
    `num_input_tokens` is one of the compilation sizes.
    """
    compilation_config = vllm_config.compilation_config
    if not compilation_config.pass_config.enable_sequence_parallelism:
        return False

    tp_size = vllm_config.parallel_config.tensor_parallel_size
    if tp_size == 1:
        return False

    # When sequence parallelism is enabled, we always pad num_input_tokens
    # to be a multiple of tensor_parallel_size (tp) earlier.
    assert num_input_tokens % tp_size == 0

    # Currently, SP is only enabled for static size fx graphs.
    return num_input_tokens in compilation_config.compile_sizes
|
||||
Reference in New Issue
Block a user