Upgrade to vllm 0.17.0 corex v4.1 overlay

commit 938d0854a5 (parent 8fac6062e4), 2026-04-29 19:38:22 +08:00
430 changed files with 35969 additions and 14511 deletions

View File

@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
) -> tuple[int, ...]:
raise NotImplementedError
@classmethod
def get_kv_cache_block_dim(
cls,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> int:
"""Discover which tensor dim is the block index, since different
backends lay out dims differently."""
_S = 1234567
shape = cls.get_kv_cache_shape(
_S,
block_size,
num_kv_heads,
head_size,
cache_dtype_str=cache_dtype_str,
)
return shape.index(_S)
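# Hedged usage sketch (backend and shape assumed for illustration): if a
# backend's get_kv_cache_shape(...) returns
# (2, num_blocks, num_kv_heads, block_size, head_size), the sentinel _S
# passed in the num_blocks position lands at index 1, so
# SomeBackend.get_kv_cache_block_dim(16, 8, 128) -> 1.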
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
@@ -301,10 +321,13 @@ class CommonAttentionMetadata:
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
key_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in key/valye Tensor(none-crossattention)"""
seq_lens: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
seq_lens_np: np.ndarray
"""(batch_size,), seq_lens mirrored as a NumPy array for host-side computation"""
num_reqs: int
"""Number of requests"""
@@ -394,7 +417,9 @@ class CommonAttentionMetadata:
return CommonAttentionMetadata(
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
key_start_loc=self.key_start_loc[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
seq_lens_np=self.seq_lens_np[:num_actual_reqs],
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
if self._seq_lens_cpu is not None
else None,
@@ -811,6 +836,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MQA-style decode forward pass."""
raise NotImplementedError
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops
ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)
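# Shape note (assumption, based on the call above): kv_c_normed is
# [num_tokens, kv_lora_rank] and k_pe is [num_tokens, 1, rope_dim]; the
# squeeze(1) drops the singleton head dim before concat_and_cache_mla
# scatters the concatenated rows into kv_cache at the flattened
# slot_mapping positions. The numel() == 0 early return covers profiling
# runs where no cache has been allocated.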
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +903,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MQA-style decode forward pass."""
raise NotImplementedError
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops
ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
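# Example (illustrative): is_quantized_kv_cache("fp8_e4m3") -> True,
# while is_quantized_kv_cache("auto") -> False.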

View File

@@ -15,13 +15,11 @@ logger = init_logger(__name__)
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash
# from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
# flash_attn_varlen_func,
# get_scheduler_metadata,
# )
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func
elif current_platform.is_xpu():
from vllm import _custom_ops as ops
from vllm._xpu_ops import xpu_ops
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
return 3
def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
if current_platform.is_xpu():
return 2
if current_platform.is_rocm():
# ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
return None
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason,
is_fa_version_supported,
)
return None
# try:
# from vllm.vllm_flash_attn.flash_attn_interface import (
# fa_version_unsupported_reason,
# is_fa_version_supported,
# )
device_capability = current_platform.get_device_capability()
# device_capability = current_platform.get_device_capability()
assert device_capability is not None
# assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# # 1. default version depending on platform
# if device_capability.major == 9 and is_fa_version_supported(3):
# # Hopper (SM90): prefer FA3
# fa_version = 3
# elif device_capability.major == 10 and is_fa_version_supported(4):
# # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
# fa_version = 4
# else:
# # Fallback to FA2
# fa_version = 2
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
# # 2. override if passed by environment or config
# from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# vllm_config = get_current_vllm_config_or_none()
# if (
# vllm_config is not None
# and vllm_config.attention_config.flash_attn_version is not None
# ):
# fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
)
fa_version = 2
# # 3. fallback for unsupported combinations
# if device_capability.major >= 10 and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 on Blackwell platform, "
# "defaulting to FA version 4 if supported, otherwise FA2."
# )
# fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# if requires_alibi and fa_version == 3:
# logger.warning_once(
# "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
fa_version,
fa_version_unsupported_reason(fa_version),
)
# if requires_alibi and fa_version == 4:
# logger.warning_once(
# "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
# )
# fa_version = 2
assert is_fa_version_supported(fa_version)
return fa_version
except (ImportError, AssertionError):
return None
# # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# # supported head dimensions.
# # See: https://github.com/Dao-AILab/flash-attention/issues/1959
# if (
# fa_version == 4
# and device_capability.major >= 10
# and head_size is not None
# and head_size > 128
# ):
# logger.warning_once(
# "FA4 on Blackwell does not support head_size=%d due to TMEM "
# "capacity limits, defaulting to FA version 2.",
# head_size,
# )
# fa_version = 2
# if not is_fa_version_supported(fa_version):
# logger.error(
# "Cannot use FA version %d is not supported due to %s",
# fa_version,
# fa_version_unsupported_reason(fa_version),
# )
# assert is_fa_version_supported(fa_version)
# return fa_version
# except (ImportError, AssertionError):
# return None
def flash_attn_supports_fp8() -> bool:
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:
def flash_attn_supports_sinks() -> bool:
if current_platform.is_xpu():
return True
else:
return get_flash_attn_version() == 3
return True
def flash_attn_supports_mla():
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError):
pass
return False

View File

@@ -4,7 +4,7 @@
import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List
import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
flash_attn_with_kvcache,
# get_scheduler_metadata,
reshape_and_cache_flash,
flash_attn_varlen_int8_func
)
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout,
split_decodes_and_prefills,
split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops
logger = init_logger(__name__)
@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
return [MultipleOf(16)]
return [16, 32, 64]
forward_includes_kv_cache_update: bool = False
@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
# return (2, num_blocks, block_size, num_kv_heads, head_size)
if envs.VLLM_ATTN_OPT_LEVEL == 2:
return (3, num_blocks, num_kv_heads, block_size, head_size)
return (2, num_blocks, num_kv_heads, block_size, head_size)
@staticmethod
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
stride_order = (0, 1, 2, 3, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
if has_sink and device_capability < DeviceCapability(9, 0):
return "sink not supported on compute capability < 9.0"
return None
@dataclass
class FlashAttentionPrefillMetadata:
"""Prefill Specific Metadata"""
""" Prefill Specific Metadata """
block_table: torch.Tensor
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_query_len: int
@dataclass
class FlashAttentionDecodeMetadata:
block_table: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
max_query_len: int
max_decode_seq_len: int
use_graph: bool
@dataclass
class FlashAttentionMetadata:
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
key_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_decodes: int
num_decode_tokens: int
num_prefills: int
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
cu_prefix_kv_lens: torch.Tensor | None
cu_suffix_kv_lens: torch.Tensor | None
@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0
prefill: FlashAttentionPrefillMetadata | None = None
decode: FlashAttentionDecodeMetadata | None = None
@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
else AttentionCGSupport.UNIFORM_BATCH
)
supports_update_block_table: bool = True
reorder_batch_threshold: ClassVar[int] = 1
@classmethod
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
)
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.block_size = kv_cache_spec.block_size
self.max_num_splits = 0 # No upper bound on the number of splits.
# self.aot_schedule = get_flash_attn_version() == 3
self.aot_schedule = False
try:
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
# Align decode/prefill split threshold with speculative decode query length
# when backend supports treating spec requests as decode.
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
key_start_loc = common_attn_metadata.key_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_lens_np = common_attn_metadata.seq_lens_np
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata)
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
dcp_context_kv_lens = None
cu_prefix_query_lens = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
cu_prefix_kv_lens = None
cu_suffix_kv_lens = None
if self.dcp_world_size > 1:
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_kv_lens = torch.tensor(
[common_prefix_len], dtype=torch.int32, device=self.device
)
suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
dtype=torch.int32,
device=self.device)
cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
dtype=torch.int32,
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
causal=causal,
)
# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes
prefill_query_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
prefill_key_start_loc = (
query_start_loc[reqs_start:] - query_start_loc[reqs_start]
)
reqs_start = num_decodes # prefill_start
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_key_start_loc = key_start_loc[
reqs_start:] - key_start_loc[reqs_start]
prefill_metadata = FlashAttentionPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
key_start_loc=prefill_key_start_loc,
max_query_len=max_query_len,
)
decode_metadata = None
if num_decodes > 0:
reqs_start = num_decodes
reqs_start = num_decodes # prefill_start
decode_query_start_loc = query_start_loc[: reqs_start + 1]
decode_query_lens = (
decode_query_start_loc[1:] - decode_query_start_loc[:-1]
)
decode_metadata = FlashAttentionDecodeMetadata(
block_table=block_table_tensor[:reqs_start, ...],
query_start_loc=decode_query_start_loc,
seq_lens=seq_lens[:reqs_start],
max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
max_query_len=decode_query_lens.max().item(),
max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
use_graph=(num_prefills == 0) and self.decode_use_graph,
)
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
key_start_loc=key_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
max_dcp_context_kv_len=max_dcp_context_kv_len,
dcp_context_kv_lens=dcp_context_kv_lens,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
scheduler_metadata=scheduler_metadata,
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
causal=causal,
prefill=prefill_metadata,
decode=decode_metadata,
)
return attn_metadata
@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
# Keep nested prefill/decode block tables in sync. Decode path consumes
# `attn_metadata.decode.block_table`, so updating only the top-level
# `block_table` is insufficient when metadata is reused across groups.
if metadata.decode is not None:
new_decode = copy.copy(metadata.decode)
reqs_start = metadata.num_decodes
new_decode.block_table = blk_table[:reqs_start, ...]
new_metadata.decode = new_decode
if metadata.prefill is not None:
new_prefill = copy.copy(metadata.prefill)
reqs_start = metadata.num_decodes
new_prefill.block_table = blk_table[reqs_start:, ...]
new_metadata.prefill = new_prefill
return new_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
)
self.sinks = sinks
if self.sinks is not None:
assert flash_attn_supports_sinks(), (
"Sinks are only supported in FlashAttention 3"
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], (
    "VLLM_ATTN_OPT_LEVEL only supports 0 (non-quant), 1 (I8Q_I8K_I8V) "
    f"and 2 (I8Q_I8K_F16V), but got {envs.VLLM_ATTN_OPT_LEVEL}"
)
'''
quant_type = 0:
    attention: f16 q/k/v
    cache: f16 kv cache
quant_type = 1:
    attention: int8 q, int8 k, int8 v
    cache:
        int8 k cache && fp32 k cache scale
        int8 v cache && fp32 v cache scale (loaded from file, not updated)
quant_type = 2:
    attention: int8 q, int8 k, fp16 v
    cache:
        int8 k cache && fp32 k cache scale
        fp16 v cache
'''
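# Illustrative cache shapes (N=num_blocks, H=num_kv_heads, B=block_size,
# D=head_size; assumed from get_kv_cache_shape and the asserts in
# forward()):
#   quant_type == 1: int8 kv_cache (2, N, H, B, D);
#       kv_cache_scale = [key scales (N, H, B) fp32,
#                         value scales (H, D) fp32]
#   quant_type == 2: kv_cache (3, N, H, B, D) int8, where plane 0 is the
#       int8 key cache and planes 1:3 are reinterpreted as one f16 value
#       cache; kv_cache_scale holds the fp32 key scales (N, H, B).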
self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)
def forward(
self,
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
sqrt_alibi: bool = False,
kv_cache_scale: torch.Tensor | None = None,
kv_cache_scale: torch.Tensor | list[torch.Tensor] | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
kv_cache_scale: for quant_type == 1, [key scales (num_blocks, num_kv_heads, block_size), value scales (num_kv_heads, head_size)]; for quant_type == 2, fp32 key scales (num_blocks, num_kv_heads, block_size)
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: FP8 quantization, flash-attn expect the size of
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):
if attn_metadata is None:
# Profiling run.
# return output.fill_(0)
return output.fill_(0).view(-1, self.num_heads * self.head_size)
return output.view(-1, self.num_heads * self.head_size)
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes: torch.Tensor | None = self.alibi_slopes
logits_soft_cap: float | None = self.logits_soft_cap
alibi_slopes: torch.Tensor = self.alibi_slopes
logits_soft_cap: float = self.logits_soft_cap
attn_type = self.attn_type
@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
output[:num_actual_tokens],
attn_metadata,
layer,
)
).view(-1, self.num_heads * self.head_size)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
decode_only = has_decode and not has_prefill
num_decode_tokens = attn_metadata.num_decode_tokens
if self.quant_type == 0:
key_cache, value_cache = kv_cache.unbind(0)
elif self.quant_type == 1:
i8_key_cache, i8_value_cache = kv_cache.unbind(0)
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
key_scale_cache, value_scale_cache = kv_cache_scale
assert (
    key_scale_cache.shape == (num_blocks, num_kv_heads, block_size)
    and key_scale_cache.dtype == torch.float32
), (
    f"expected key_scale_cache with shape (num_blocks, num_kv_heads, "
    f"block_size) and dtype float32, got {key_scale_cache.shape} / "
    f"{key_scale_cache.dtype}"
)
assert (
    value_scale_cache.shape == (num_kv_heads, head_size)
    and value_scale_cache.dtype == torch.float32
), (
    f"expected value_scale_cache with shape (num_kv_heads, head_size) and "
    f"dtype float32, got {value_scale_cache.shape} / {value_scale_cache.dtype}"
)
value_cache_info = (i8_value_cache, value_scale_cache)
elif self.quant_type == 2:
# quant_type == 2: kv_cache[0] is the int8 key cache; kv_cache[1:] is reinterpreted as the f16 value cache
i8_key_cache = kv_cache[0]
num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
key_scale_cache = kv_cache_scale
value_cache_info = (value_cache, None)
decode_q = query[:num_decode_tokens]
prefill_q = query[num_decode_tokens:]
prefill_output = output[num_decode_tokens:]
decode_output = output[:num_decode_tokens]
if self.quant_type == 1:
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False, scale=value_cache_info[1]
)
elif self.quant_type == 2:
'''
original key cache:
    (num_blocks, num_kv_heads, block_size, head_size) f16
reformatted key cache:
    key_cache_i8:    (num_blocks, num_kv_heads, block_size, head_size) int8
    key_scale_cache: (num_blocks, num_kv_heads, block_size) fp32
'''
if decode_only:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=False
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
query, 2, transpose_scale=True
)
i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
key, 2, transpose_scale=False
)
else:
if layer.quant_manager is not None and layer.quant_manager.check_enable():
i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
value, 0, transpose_scale=False
)
layer.quant_manager.update_data(value_scale)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if self.quant_type == 1:
if has_prefill:
ixf_ops.reshape_and_cache_flash_int8(
key=i8_key,
value=i8_value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
elif self.quant_type == 2:
if has_prefill:
ixf_ops.reshape_and_cache_flash_mix(
key=i8_key,
value=value,
k_scale=key_scale,
key_cache=i8_key_cache,
value_cache=value_cache_info[0],
key_scale_cache=key_scale_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype="",
)
else:
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
value_cache = value_cache.view(dtype)
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output.view(-1, self.num_heads * self.head_size)
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
if has_prefill:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
# key = key[num_decode_tokens:]
# value = value[num_decode_tokens:]
# int8 attn
if self.quant_type > 0:
flash_attn_varlen_int8_func(
q=int8_query[num_decode_tokens:],
k=i8_key_cache,
v=value_cache_info[0],
q_scale=query_scale[:, num_decode_tokens:],
k_scale=key_scale_cache,
v_scale=value_cache_info[1],
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
output_dtype=query.dtype
)
else:
flash_attn_varlen_func(
q=prefill_q,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.key_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=prefill_output,
block_table=attn_metadata.prefill.block_table,
)
if has_decode:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache.contiguous(),
v_cache=value_cache.contiguous(),
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
out=decode_output.unsqueeze(1),
max_context_len=attn_metadata.decode.max_decode_seq_len,
# sinks=self.sinks,
)
# for mtp + cuda graph
max_q_len = (attn_metadata.decode.max_query_len
             if attn_metadata.decode is not None
             else attn_metadata.max_query_len)
max_ct_len = (attn_metadata.decode.max_decode_seq_len
              if attn_metadata.decode is not None
              else attn_metadata.max_seq_len)
if self.quant_type in [1, 2]:
para_dict = dict(
output=decode_output,
query=int8_query[:num_decode_tokens],
key_cache=i8_key_cache,
query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
key_scale_cache=key_scale_cache,
num_kv_heads=self.num_kv_heads,
scale=softmax_scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
block_size=i8_key_cache.shape[-2],
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
causal=True,
window_left=window_size[0],
window_right=window_size[1],
use_sqrt_alibi=sqrt_alibi,
use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
max_context_len=max_ct_len,
# mtp
cu_query_lens=attn_metadata.decode.query_start_loc,
max_query_len=max_q_len,
)
if self.quant_type == 1:
para_dict.update(
dict(
value_cache=value_cache_info[0],
value_scale_cache=value_cache_info[1],
)
)
# for kv + k_scale write fusion
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=i8_value[:num_decode_tokens],
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_int8(**para_dict)
elif self.quant_type == 2:
para_dict.update(
dict(
value_cache=value_cache,
)
)
if decode_only:
para_dict.update(
dict(
save_key=i8_key[:num_decode_tokens],
save_value=value[:num_decode_tokens].contiguous(),
save_key_scale=key_scale[:num_decode_tokens],
)
)
ixf_ops.vllm_paged_attention_mix(
**para_dict
)
else:
flash_attn_with_kvcache(
q=decode_q.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
max_query_len=max_q_len,
cu_query_lens=attn_metadata.decode.query_start_loc,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
use_sqrt_alibi=sqrt_alibi,
sinks=self.sinks,
out=decode_output.unsqueeze(1),
use_cuda_graph=attn_metadata.decode.use_graph,
max_context_len=max_ct_len
)
# Compute attention and update output up to `num_actual_tokens`.
return output.view(-1, self.num_heads * self.head_size)
# flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=attn_metadata.causal,
# alibi_slopes=self.alibi_slopes,
# window_size=sliding_window_size,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
# )
# return output
# Cascade attention (rare case).
cascade_attention(
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale,
s_aux=self.sinks,
)
# return output
return (
output[:num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output.view(-1, self.num_heads * self.head_size)
def do_kv_cache_update(
self,
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):
query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
cu_dcp_kv_lens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
new_tensor = torch.tensor([0],
device=attn_metadata.dcp_context_kv_lens.device,
dtype=attn_metadata.dcp_context_kv_lens.dtype)
cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_lens])
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
cu_seqlens_k = torch.cat(
[
torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
attn_metadata.dcp_context_kv_lens.cumsum(
dim=0, dtype=cu_seqlens_q.dtype
),
],
dim=0,
)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
# seqused_k=attn_metadata.dcp_context_kv_lens,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=attn_metadata.scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
context_lse_cor,
query_attn_out,
query_lse,
output,
)
def _forward_encoder_attention(
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# assert self.vllm_flash_attn_version is not None, (
# "FlashAttention version not detected."
# )
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=1 if self.batch_invariant_enabled else 0,
)
return (
output[: attn_metadata.num_actual_tokens]
.contiguous()
.view(-1, self.num_heads * self.head_size)
)
return output
def use_cascade_attention(
@@ -1203,8 +1389,6 @@ def cascade_attention(
cu_prefix_query_lens: torch.Tensor,
cu_prefix_kv_lens: torch.Tensor,
cu_suffix_kv_lens: torch.Tensor,
# prefix_kv_lens: torch.Tensor,
# suffix_kv_lens: torch.Tensor,
max_kv_len: int,
softmax_scale: float,
alibi_slopes: torch.Tensor | None,
@@ -1228,12 +1412,13 @@ def cascade_attention(
)
num_tokens = query.shape[0]
# block_size = key_cache.shape[-3]
block_size = key_cache.shape[-2]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
assert q_descale is None or q_descale == 1, f"q_descale must be None or 1, got: {q_descale}"
assert k_descale is None or k_descale == 1, f"k_descale must be None or 1, got: {k_descale}"
assert v_descale is None or v_descale == 1, f"v_descale must be None or 1, got: {v_descale}"
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1241,7 +1426,6 @@ def cascade_attention(
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
# seqused_k=prefix_kv_lens,
cu_seqlens_k=cu_prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
@@ -1251,26 +1435,14 @@ def cascade_attention(
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# s_aux=s_aux,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
# seqused_k=suffix_kv_lens,
cu_seqlens_k=cu_suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
@@ -1280,14 +1452,6 @@ def cascade_attention(
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
# scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.
# merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)
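# Note: the ixformer merge_attn_states takes the destination tensor as the
# last positional argument (prefix_out, prefix_lse, suffix_out, suffix_lse,
# out), whereas the upstream vLLM op shown in the commented call above
# takes it first.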

View File

@@ -13,7 +13,7 @@ from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
):
"""Plan the prefill operation with given parameters."""
self._context.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size,
num_kv_heads,
head_dim,
page_size,
qo_indptr=qo_indptr_cpu,
paged_kv_indptr=paged_kv_indptr_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len_cpu,
num_qo_heads=num_qo_heads * dcp_world_size,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim,
page_size=page_size,
causal=False, # This is context run
sm_scale=sm_scale,
window_left=window_left,
@@ -374,13 +374,13 @@ class FlashInferBackend(AttentionBackend):
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability is not None and capability.major == 10:
return "HND"
return None
forward_includes_kv_cache_update: bool = False
@dataclass
class FIPrefill:
@@ -573,20 +573,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
# TRTLLM attention requires strictly contiguous KV cache tensors.
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
# permuted into non-contiguous views, which causes assertion failures.
self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
if can_use_trtllm and self._kv_transfer_enabled:
logger.info_once(
"TRTLLM attention is disabled because KV transfer "
"(P/D disaggregation) is enabled. TRTLLM attention requires "
"strictly contiguous KV cache tensors which may not be "
"guaranteed with KV transfer."
)
can_use_trtllm = False
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -816,6 +802,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
page_size,
paged_kv_last_page_len_np,
)
self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
)
return paged_kv_indices
def build(
@@ -860,9 +849,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
# KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
if self._kv_transfer_enabled:
prefill_use_trtllm = False
decode_use_trtllm = (
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)
@@ -997,14 +983,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu],
[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
[shared_kv_page_indices_cpu, paged_kv_indices],
[shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
paged_kv_last_page_len=[
shared_kv_last_page_len_cpu,
paged_kv_last_page_len_cpu,
],
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
@@ -1082,14 +1071,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
BatchPrefillWithPagedKVCacheWrapper,
)
prefill_wrapper.plan(
qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu,
paged_kv_indices,
paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
qo_indptr=qo_indptr_prefill_cpu,
paged_kv_indptr=paged_kv_indptr_prefill_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim_qk=self.head_dim,
page_size=self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
@@ -1130,14 +1119,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# in atten_metadata when using cudagraph.
fast_plan_decode(
decode_wrapper,
self.paged_kv_indptr.cpu[: num_input_tokens + 1],
paged_kv_indices,
self.paged_kv_last_page_len.cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads * self.dcp_world_size,
self.num_kv_heads,
self.head_dim,
self.page_size,
indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
indices=paged_kv_indices,
last_page_len_cpu=self.paged_kv_last_page_len.cpu[
:num_input_tokens
],
num_qo_heads=self.num_qo_heads * self.dcp_world_size,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.sm_scale,
@@ -1330,32 +1320,15 @@ class FlashInferImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
"fp8"
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
kv_cache = kv_cache.view(torch_dtype)
kv_cache = kv_cache.view(torch_dtype)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
@@ -1599,13 +1572,39 @@ class FlashInferImpl(AttentionImpl):
)
return output_padded
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
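# Layout note (assumption): kv_cache here is
# (num_blocks, 2, block_size, num_kv_heads, head_size), so kv_cache[:, 0]
# selects the key planes and kv_cache[:, 1] the value planes for
# reshape_and_cache_flash.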
def fast_plan_decode(
self, # decode wrapper
indptr_cpu: torch.Tensor,
indices: torch.Tensor,
last_page_len_cpu: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
@@ -1642,110 +1641,56 @@ def fast_plan_decode(
# this warm up is to generate the _cached_module for the decode wrapper.
if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
self.plan(
indptr_cpu,
indices,
last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
pos_encoding_mode,
window_left,
logits_soft_cap,
q_data_type,
kv_data_type,
o_data_type,
data_type,
sm_scale,
rope_scale,
rope_theta,
non_blocking,
None, # block_tables
None, # seq_lens
fixed_split_size,
disable_split_kv,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
o_data_type=o_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
block_tables=None,
seq_lens=None,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
self.vllm_first_call = False
return
assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
batch_size = len(last_page_len_cpu)
if logits_soft_cap is None:
logits_soft_cap = 0.0
# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"
if kv_data_type is None:
kv_data_type = q_data_type
q_data_type = (
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
fast_decode_plan(
self,
indptr=indptr_cpu,
indices=indices,
last_page_len=last_page_len_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
pos_encoding_mode=pos_encoding_mode,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_data_type,
data_type=data_type,
sm_scale=sm_scale,
rope_scale=rope_scale,
rope_theta=rope_theta,
non_blocking=non_blocking,
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
kv_data_type = (
getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
)
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime "
"batch size {} mismatches the batch size set during "
"initialization {}".format(batch_size, self._fixed_batch_size)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# host-to-device copy for the indptr buffer
self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
# host-to-device copy for the last_page_len buffer
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
# Make sure we pass exactly 19 arguments for tensor core version
args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_cpu,
seq_lens_cpu,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
window_left,
]
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
@triton.jit

View File

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
metadata_cls = Mamba1AttentionMetadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
**kwargs: Any,
) -> Mamba1AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata)
if (
common.num_prefills > 0
and self.vllm_config.cache_config.mamba_cache_mode == "all"
):
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.kv_cache_spec.block_size,
common,
common_attn_metadata,
)
)
return replace(
common,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
)
return common
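# Note: dataclasses.replace returns a copy of `common` with only the chunk
# fields overridden, leaving the rest of the shared metadata untouched.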

View File

@@ -7,7 +7,6 @@ from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
)
self.chunk_size: int = chunk_size
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
else False
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.chunk_size,
common,
common_attn_metadata,
)
)
return replace(

View File

@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)
def _compute_chunk_metadata(
self,
chunk_size: int,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba models.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
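# Illustrative walk-through of the loop above (a sketch; numbers chosen
# for exposition): with chunk_size=4, a prefill request that has
# this_num_computed=6 and this_new_tokens=7 first emits a 2-token chunk
# to realign to a chunk_size boundary (cdiv(6, 4) * 4 - 6 = 2), then
# chunks of 4 and 1 tokens for the remaining 5, so every state snapshot
# still lands on a multiple of chunk_size.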
def _build_chunk_metadata_tensors(
self,
chunk_size: int,
common: M,
common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute chunk metadata and return as device tensors.
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
"""
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
chunk_size,
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
device = common_attn_metadata.query_start_loc.device
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=device,
dtype=torch.int32,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=device,
dtype=torch.int32,
)
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@@ -191,6 +191,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
@@ -239,12 +241,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return metadata

View File

@@ -156,6 +156,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max, but all values should be the same due to the uniform-length requirement
@@ -179,8 +181,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
scheduler_metadata=scheduler_metadata,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
scheduler_metadata=scheduler_metadata,
)

View File

@@ -13,6 +13,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_sparse_fwd,
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
import functools
from vllm import envs
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
import ixformer.inference.functions as ixf_ops
import numpy as np
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
@@ -74,7 +83,15 @@ structured as:
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
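# Minimal usage sketch for the helper above (`x` stands for any fp16/bf16
# activation tensor; the names here are illustrative):
#   x_fp8, x_scale = dynamic_per_batched_tensor_quant(x)
#   x_deq = x_fp8.to(x.dtype) * x_scale  # approximate round-trip
# Note that the returned scale is the dequantization scale
# (1 / quantization scale), so multiplying the fp8 payload by it
# recovers the original value range.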
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
self.qk_head_dim = mla_args["qk_head_dim"]
self.v_head_dim = mla_args["v_head_dim"]
self.kv_b_proj = mla_args["kv_b_proj"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# We currently do not have quantized bmm kernels, which would be needed
# for `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
# the bmms in 16-bit; the extra memory overhead of this is fairly low.
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
self.W_UV = W_UV
self.W_UK = W_UK
# self.W_UK_T = W_UK.permute(1, 2, 0)
def _v_up_proj(self, x: torch.Tensor):
return torch.einsum("bnl,lnv->bnv", x, self.W_UV)
def _k_up_proj(self, q_nope):
return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)
def _forward_bf16_kv(
self,
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
topk_indices = ops.dsa_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
attn_metadata.block_size,
)
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)
# NOTE(Chen): the kernel requires num_local_head to be a multiple of
# 64 on Hopper and 128 on Blackwell
if self.num_heads % self.prefill_padding != 0:
assert self.prefill_padding % self.num_heads == 0
logger.warning_once(
f"Padding num_heads from {self.num_heads} to "
f"{self.prefill_padding} for BF16 sparse prefill kernel"
)
q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
q_padded[:, : self.num_heads, :] = q
q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_fwd(
output = flash_mla_sparse_prefill(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
)
output = output[:, : self.num_heads, :]
return output
@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
return attn_out, None
return attn_out

View File

@@ -8,7 +8,11 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
cu_seqlens_q: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
max_context_len: int
max_q_len: int # Maximum query length for dsa_indexer_mqa_logits_with_blocks
max_kv_len: int # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks
@dataclass
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
seq_lens: torch.Tensor
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
# schedule_metadata: torch.Tensor
use_large_context_topk: bool
offsets: torch.Tensor | None # Precomputed offsets for speculative decoding
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seqlens_kv: torch.Tensor
cu_seqlens_q: torch.Tensor
max_context_len: int
max_q_len: int
max_kv_len: int
@dataclass
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
)
# See: DeepGMM/csrc/apis/attention.hpp
@@ -260,18 +295,88 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
.to(torch.int32)
.to(self.device)
)
cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
# max_q_len is the maximum query length among all requests in this chunk;
# prefill_query_start_loc is the cumsum of query lengths, with shape [batch + 1]
max_q_len = (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
cu_seqlens_q=cu_seqlens_q,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_context_len,
)
def build_decode_metadata(
self, common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
):
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
assert (
decode_lens_cpu.max().item()
== decode_lens_cpu.min().item()
== 1
), "Only support single token decode in dsa_indexer backend"
# Calculate decode metadata parameters
seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
max_context_len = seq_lens_decode.max().item()
max_kv_len = max_context_len
max_q_len = 1 # Single token decode
# Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
cu_seqlens_q = torch.arange(
num_decodes + 1, dtype=torch.int32, device=self.device
)
# Create cu_seqlens_kv and related tensors using kv_spans_from_batches
decode_query_start_loc = torch.arange(
num_decodes + 1, dtype=torch.long
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
decode_query_start_loc, seq_lens_decode, self.device
)
cu_seqlens_kv = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=self.device),
torch.cumsum(seq_lens_decode.to(self.device), dim=0)
.to(torch.int32),
]
)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[
:num_decodes, ...
],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=(
decode_lens_cpu.max() > decode_lens_cpu.min()
).item(),
use_large_context_topk=use_large_context_topk,
offsets=offsets,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q=cu_seqlens_q,
max_context_len=max_context_len,
max_q_len=max_q_len,
max_kv_len=max_kv_len,
# schedule_metadata=self.scheduler_metadata_buffer,
)
return decode_metadata
def build(
self,
common_prefix_len: int,
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU values to avoid a GPU sync that would break async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
# Decide which top-k kernel to use based on batch size and sequence length
batch_size = num_decodes
_is_large_context = common_attn_metadata.max_seq_len > 8192
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
use_large_context_topk = batch_size <= 128 and _is_large_context
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
else:
offsets = None
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk,
offsets=offsets,
max_decode_len = int(decode_lens_cpu.max().item())
if max_decode_len > 1:
# Flatten multi-token decode requests into single-token
# batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens
if next_n > 1:
offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else:
offsets = None
batch_size = num_decodes
# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
self.num_sms,
)
# Decide which top-k kernel to use based on batch size and sequence length
# Decision logic based on micro-benchmark results:
# - large_context_topk wins for batch <= 128 and seq_len > 8K
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
_is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
# decode_metadata = DeepSeekV32IndexerDecodeMetadata(
# block_table=block_table,
# seq_lens=seq_lens,
# decode_lens=decode_lens,
# requires_padding=False,
# # schedule_metadata=self.scheduler_metadata_buffer,
# use_large_context_topk=use_large_context_topk,
# offsets=offsets,
# )
decode_metadata = self.build_decode_metadata(
common_attn_metadata, num_decodes, decode_lens, use_large_context_topk, offsets
)
attn_metadata = DeepseekV32IndexerMetadata(

View File

@@ -115,6 +115,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
max_decode_seq_len: int = 0,
use_cuda_graph: bool = False,
) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device
@@ -170,11 +172,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype,
)

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.distributed.parallel_state import get_dcp_group
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
# supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
# "auto",
# "bfloat16",
# ]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_name() -> str:
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
# layer: AttentionLayer,
k_c_normed: torch.Tensor |None = None,
k_pe: torch.Tensor |None = None,
kv_c_and_k_pe_cache_scale: torch.Tensor |None = None,
k_c_normed: torch.Tensor | None,
k_pe: torch.Tensor | None,
kv_c_and_k_pe_cache_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
B = q_nope.shape[0]
if self.dcp_world_size > 1:
q = torch.cat([q_nope, q_pe], dim=-1)
q = get_dcp_group().all_gather(q, dim=1)
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q_int8, q_scale = ops.quant_kv(q)
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
o,
q_int8,
q_scale,
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
return_softmax_lse=True
)
else:
attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
max_context_len=decode_meta.max_decode_seq_len,
return_softmax_lse=True)
return attn_out, softmax_lse
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q_nope.dtype,
device=q_nope.device)
device=q_nope.device)
if envs.VLLM_USE_INT8_MLA:
q = torch.cat([q_nope, q_pe], dim=-1)
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.use_cuda_graph
)
else:
# fused q concat & cache write
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope,
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph
)
if k_c_normed is None:
q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
ixf_ops.vllm_paged_attention_mla(
output=o,
query=q,
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
use_cuda_graph=decode_meta.use_cuda_graph,
)
else:
ixf_ops.vllm_paged_attention_mla_fused(
output=o,
q_nope=q_nope.contiguous(),
q_pe=q_pe.contiguous(),
kv_cache=kv_c_and_k_pe_cache,
scale=self.scale,
block_tables=attn_metadata.decode.block_table,
context_lens=attn_metadata.decode.seq_lens,
max_context_len=decode_meta.max_decode_seq_len,
k_c_normed=k_c_normed,
k_pe=k_pe,
use_cuda_graph=decode_meta.use_cuda_graph,
)
return self._v_up_proj(o), None

View File

@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAiterUnifiedAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True

View File

@@ -182,7 +182,7 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
return [32, 64, 80, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
def get_impl_cls() -> type["RocmAttentionImpl"]:
return RocmAttentionImpl
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAttention supports all attention types."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):
RocmAttentionBackend.validate_head_size(head_size)
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}."
)
def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# FP8 KV cache quantization is not supported for encoder attention
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
# Use encoder-specific metadata for sequence information
query_start_loc = attn_metadata.query_start_loc
seq_lens = attn_metadata.seq_lens
max_query_len = attn_metadata.max_query_len
# Call flash attention directly on Q, K, V tensors
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
context_attention_fwd(
q=query,
k=key,
v=value,
o=output,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_input_len=max_query_len,
is_causal=False,
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1],
)
return output
def forward(
self,
layer: torch.nn.Module,
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads, # type: ignore[attr-defined]

View File

@@ -6,6 +6,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -151,7 +152,34 @@ def flash_mla_with_kvcache_fp8(
descale_k,
)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
For the precise definitions of output,
max_logits, and lse, refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = ops.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
return results
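# Minimal call sketch (shapes follow the docstring above; assumes
# DeepSeek-style head dims, i.e. d_qk = 512 + 64 with d_v = 512):
#   out, max_logits, lse = flash_mla_sparse_prefill(
#       q,        # [s_q, h_q, 576], bfloat16
#       kv,       # [s_kv, h_kv, 576], bfloat16
#       indices,  # [s_q, h_kv, topk], int32; -1 marks invalid slots
#       sm_scale,
#   )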
#
# TODO: Add fake functions

View File

@@ -37,8 +37,8 @@ def flash_attn_maxseqlen_wrapper(
else:
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
# if not current_platform.is_rocm() and fa_version is not None:
# kwargs["fa_version"] = fa_version
if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version
q_len = q.size(1)
if cu_seqlens is None:
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
return torch.ops.vllm.torch_sdpa_wrapper(
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
)
def flashinfer_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float,
workspace_buffer: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
is_reshaped = q.dim() == 4
if is_reshaped:
reshape_batch_size = q.shape[0]
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
# cuDNN <= 9.10.2.21 requires q, k to be contiguous
# this comes with no cost for ViTs with RoPE because
# RoPE has already made q and k contiguous.
q, k = q.contiguous(), k.contiguous()
assert len(cu_seqlens) % 2 == 0, "len(cu_seqlens) must be divisible by 2"
cu_seqlength = len(cu_seqlens) // 2
batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
max_seqlen = max_seqlen.item()
output, _ = cudnn_batch_prefill_with_kv_cache(
q,
k,
v,
scale,
workspace_buffer,
max_token_per_sequence=max_seqlen,
max_sequence_kv=max_seqlen,
actual_seq_lens_q=sequence_lengths,
actual_seq_lens_kv=sequence_lengths,
causal=False,
return_lse=False,
batch_offsets_q=batch_offsets_qko,
batch_offsets_k=batch_offsets_qko,
batch_offsets_v=batch_offsets_v,
batch_offsets_o=batch_offsets_qko,
)
if is_reshaped:
output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)
return output
def vit_flashinfer_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float,
workspace_buffer: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)
direct_register_custom_op(
op_name="flashinfer_wrapper",
op_func=flashinfer_wrapper,
fake_impl=vit_flashinfer_wrapper_fake,
)
def vit_flashinfer_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float,
workspace_buffer: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.flashinfer_wrapper(
q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
)
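# Note on the `cu_seqlens` layout consumed by the wrapper above: the
# tensor packs two equal-length offset arrays back to back -- the first
# half provides the per-batch offsets for q/k/o, the second half the
# offsets for v -- which is why its length must be divisible by 2.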

View File

@@ -456,6 +456,37 @@ class KVCacheManager:
"""
return self.coordinator.get_num_common_prefix_blocks(running_request_id)
def get_num_free_blocks(self) -> int:
"""Get the number of free blocks in the pool."""
return self.block_pool.get_num_free_blocks()
def get_num_blocks_needed_for_tokens(
self,
request_id: str,
num_tokens_need_slot: int,
new_computed_blocks: KVCacheBlocks | None = None,
num_encoder_tokens: int = 0,
total_computed_tokens: int = 0,
num_tokens_main_model: int = 0,
) -> int:
"""Estimate number of blocks needed for a request (no allocation).
Used e.g. to check that enough KV cache exists for a full chunked
prefill before admitting a request to the running queue.
"""
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = self.empty_kv_cache_blocks.blocks
return self.coordinator.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
total_computed_tokens=total_computed_tokens,
num_tokens_main_model=num_tokens_main_model,
)
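# Admission-check sketch (mirrors how the scheduler change in this
# commit uses this pair of helpers):
#   needed = kv_cache_manager.get_num_blocks_needed_for_tokens(
#       request.request_id, num_tokens_need_slot)
#   if kv_cache_manager.get_num_free_blocks() < needed:
#       ...  # defer the request instead of allocating slots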
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
@@ -489,6 +520,13 @@ class KVCacheManager:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return new attention block IDs for zeroing."""
ids: list[int] = []
for mgr in self.coordinator.single_type_managers:
ids.extend(mgr.take_new_block_ids())
return ids
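# The scheduler drains this once per step and forwards the IDs to the
# worker via SchedulerOutput.new_block_ids_to_zero, so each freshly
# allocated block can be zeroed before its first use.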
def new_step_starts(self) -> None:
"""Called when a new step is started."""
self.coordinator.new_step_starts()

View File

@@ -10,9 +10,8 @@ from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
import vllm.envs as envs
from vllm import envs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.hashing import sha256_cbor, xxhash_cbor
@@ -835,7 +834,7 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
def get_num_blocks(
vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int
vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int, scale_page_size: int
) -> int:
"""
Get the number of kv cache blocks.
@@ -846,7 +845,7 @@ def get_num_blocks(
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = int(available_memory // (page_size + scale_page_size) // num_layers)
num_blocks = max(num_blocks, 0)
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
return num_blocks
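# Worked example (illustrative numbers): with 8 GiB available, 32
# layers, a 2 MiB page and a 16 KiB scale page, this gives
#   8 * 2**30 // (2 * 2**20 + 16 * 2**10) // 32 = 127 blocks.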
@@ -857,9 +856,14 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
Get the page size of the KV cache.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
scale_page_sizes = {layer.scale_page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1
return page_sizes.pop()
if envs.VLLM_ATTN_OPT_LEVEL == 1:
v_cache_scale_sizes = set(layer.v_cache_scale_size_bytes for layer in kv_cache_specs)
assert len(v_cache_scale_sizes) == 1
return page_sizes.pop(), scale_page_sizes.pop(), v_cache_scale_sizes.pop()
else:
return page_sizes.pop(), scale_page_sizes.pop()
def _get_kv_cache_groups_uniform_spec(
kv_cache_specs: dict[str, KVCacheSpec],
@@ -955,6 +959,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
@@ -1015,6 +1020,7 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
@@ -1058,19 +1064,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100,
)
num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
# To support multi-layer MTP, we need to
# put all MTP layers in the same group.
if (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i : i + group_size])
else:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
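# e.g. with group_size=2 and layers [mtp.0, mtp.1, mtp.2, mtp.3], the
# MTP branch yields the contiguous groups (mtp.0, mtp.1), (mtp.2, mtp.3),
# while the strided branch would scatter them as (mtp.0, mtp.2),
# (mtp.1, mtp.3).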
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
@@ -1096,9 +1111,9 @@ def get_kv_cache_config_from_groups(
return KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_scale_tensors=[],
kv_cache_groups=kv_cache_groups,
)
# Determine how model runners should initialize the KV cache tensors.
if len(kv_cache_groups) == 1 and isinstance(
kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
@@ -1118,6 +1133,12 @@ def get_kv_cache_config_from_groups(
)
for layer_name in kv_cache_groups[0].layer_names
]
kv_cache_scale_tensors = [
KVCacheTensor(size=per_layer_specs[layer_name].scale_page_size_bytes *
num_blocks,
shared_by=[layer_name])
for layer_name in kv_cache_groups[0].layer_names
]
else:
# General case:
# We will have group_size memory pools, each is shared by one layer from
@@ -1129,55 +1150,39 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(
if envs.VLLM_ATTN_OPT_LEVEL == 1:
page_size, scale_page_size, v_cache_scale_size = get_uniform_page_size([group.kv_cache_spec for group in kv_cache_groups])
else:
page_size, scale_page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
v_cache_scale_size = 0
assert group_size > 0, "group_size must be greater than 0"
# num_blocks = get_num_blocks(
# vllm_config, group_size, available_memory, page_size
# )
num_blocks = get_num_blocks(
vllm_config,
group_size,
available_memory,
page_size,
scale_page_size,
)
kv_cache_tensors = []
# TODO: add scale tensors here as well?
kv_cache_scale_tensors = []
if envs.VLLM_KV_DISABLE_CROSS_GROUP_SHARE:
total_layers = sum(len(group.layer_names) for group in kv_cache_groups)
num_blocks = get_num_blocks(
vllm_config,
total_layers,
available_memory,
page_size,
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
)
for group in kv_cache_groups:
for layer_name in group.layer_names:
kv_cache_tensors.append(
KVCacheTensor(size=page_size * num_blocks, shared_by=[layer_name])
)
logger.warning(
"VLLM_KV_DISABLE_CROSS_GROUP_SHARE=1: using dedicated KV tensors per layer "
"(groups=%d, tensors=%d, num_blocks=%d)",
len(kv_cache_groups),
len(kv_cache_tensors),
num_blocks,
kv_cache_scale_tensors.append(
KVCacheTensor(size=scale_page_size * num_blocks, shared_by=shared_by, size_scale=v_cache_scale_size)
)
else:
num_blocks = get_num_blocks(
vllm_config,
group_size,
available_memory,
page_size,
)
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_scale_tensors=kv_cache_scale_tensors,
kv_cache_groups=kv_cache_groups,
)
@@ -1284,7 +1289,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config(
@@ -1381,11 +1388,17 @@ def _max_memory_usage_bytes_from_groups(
# General case: group_size pools, each shared by one layer per group
# Memory = group_size * page_size * blocks_for_max_len
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
if envs.VLLM_ATTN_OPT_LEVEL == 1:
page_size, scale_page_size, v_cache_scale_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
else:
page_size, scale_page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
v_cache_scale_size = 0
any_spec = kv_cache_groups[0].kv_cache_spec
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), (page_size + scale_page_size + v_cache_scale_size))
return group_size * page_size * blocks_needed
@@ -1633,6 +1646,10 @@ def get_kv_cache_configs(
for tensor in kv_cache_config.kv_cache_tensors:
assert tensor.size % num_blocks_old == 0
tensor.size = tensor.size // num_blocks_old * min_num_blocks
for tensor in kv_cache_config.kv_cache_scale_tensors:
assert tensor.size % num_blocks_old == 0
tensor.size = tensor.size // num_blocks_old * min_num_blocks
if len(kv_cache_config.kv_cache_groups) > 0:
_report_kv_cache_config(vllm_config, kv_cache_config)

View File

@@ -5,8 +5,6 @@ from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING
from vllm._bc_linter import bc_linter_include
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
@@ -29,7 +27,6 @@ else:
Request = object
@bc_linter_include
@dataclass
class NewRequestData:
req_id: str
@@ -109,7 +106,6 @@ class NewRequestData:
)
@bc_linter_include
@dataclass
class CachedRequestData:
req_ids: list[str]
@@ -179,7 +175,6 @@ class CachedRequestData:
)
@bc_linter_include
@dataclass
class SchedulerOutput:
# list of the requests that are scheduled for the first time.
@@ -217,6 +212,9 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Request IDs that are resumed from preemption in this step.
scheduled_resumed_reqs: list[str] | None = None
# Request IDs that are preempted in this step.
# Only used for v2 model runner.
preempted_req_ids: set[str] | None = None
@@ -238,6 +236,11 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
# Block IDs freshly allocated from the pool during this scheduling step.
# The worker zeros the corresponding GPU memory before the blocks are used,
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None
@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(

View File

@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any(
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.has_mamba_layers = kv_cache_config.has_mamba_layers
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
@@ -320,6 +315,9 @@ class Scheduler(SchedulerInterface):
return num_new_tokens
def schedule(self) -> SchedulerOutput:
if envs.VLLM_ENABLE_PP_MIX_ILU_SCHEDULING:
return self.schedule_opt()
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
@@ -413,7 +411,7 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens
)
if num_new_tokens == 0:
if num_new_tokens <= 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when
@@ -425,6 +423,8 @@ class Scheduler(SchedulerInterface):
# 3. The encoder cache is exhausted.
# 4. Insufficient budget for a block-aligned chunk in hybrid
# models with mamba cache mode "align".
# 5. num_computed_tokens > num_tokens_with_spec due to PP
# timing: schedule() runs before update_from_output().
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
@@ -670,7 +670,7 @@ class Scheduler(SchedulerInterface):
# If chunked_prefill is disabled,
# we can stop the scheduling here.
break
temp_num_new_tokens = num_new_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -688,7 +688,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if num_new_tokens == 0 or num_new_tokens < temp_num_new_tokens:
if num_new_tokens == 0:
# The request cannot be scheduled.
break
@@ -723,6 +723,35 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule
)
if not load_kv_async:
enable_chunked = self.scheduler_config.enable_chunked_prefill
tokens_still_to_compute = (
request.num_tokens - num_computed_tokens
)
is_chunked = (
enable_chunked
and tokens_still_to_compute > num_new_tokens
)
if is_chunked:
assert (
request.num_tokens <= self.max_model_len
), "request.num_tokens must not exceed max_model_len"
num_tokens_need_slot = min(
request.num_tokens + effective_lookahead_tokens,
self.max_model_len,
)
blocks_needed = (
self.kv_cache_manager.get_num_blocks_needed_for_tokens(
request.request_id,
num_tokens_need_slot,
new_computed_blocks,
num_encoder_tokens,
)
)
num_free = self.kv_cache_manager.get_num_free_blocks()
if num_free < blocks_needed:
break
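# NOTE: breaking here leaves the request in the waiting queue; a
# chunked-prefill request is only admitted once its full token range
# (plus lookahead) provably fits in the currently free KV blocks.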
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
@@ -871,6 +900,12 @@ class Scheduler(SchedulerInterface):
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -886,6 +921,7 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
@@ -909,6 +945,527 @@ class Scheduler(SchedulerInterface):
self._update_after_schedule(scheduler_output)
return scheduler_output
def schedule_opt(self) -> SchedulerOutput:
"""PP mix ILU scheduling variant of schedule()."""
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
token_budget = 0
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
self.kv_cache_manager.new_step_starts()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if (
request.num_output_placeholders > 0
and request.num_computed_tokens + 2 - request.num_output_placeholders
>= request.num_prompt_tokens + request.max_tokens
):
req_index += 1
continue
num_new_tokens = (
request.num_tokens_with_spec
+ request.num_output_placeholders
- request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
num_new_tokens = min(
num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
request.num_computed_tokens,
num_new_tokens,
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request, num_new_tokens
)
if num_new_tokens <= 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when
# (1) PP>1 and we have already scheduled all prompt tokens
# but they are not finished yet.
# (2) Async scheduling and the request has reached either
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# 4. num_computed_tokens > num_tokens_with_spec due to PP
# timing: schedule() runs before update_from_output().
req_index += 1
continue
# Schedule newly needed KV blocks for the request.
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is not None:
break
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
preempted_req_id = preempted_req.request_id
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens.pop(preempted_req_id)
req_to_new_blocks.pop(preempted_req_id)
scheduled_spec_decode_tokens.pop(preempted_req_id, None)
preempted_encoder_inputs = scheduled_encoder_inputs.pop(
preempted_req_id, None
)
if preempted_encoder_inputs:
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs
)
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
self._preempt_request(preempted_req, scheduled_timestamp)
preempted_reqs.append(preempted_req)
if preempted_req == request:
break
if new_blocks is None:
break
scheduled_running_reqs.append(request)
request_id = request.request_id
req_to_new_blocks[request_id] = new_blocks
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
if request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
spec_token_ids = request.spec_token_ids
if len(spec_token_ids) > num_scheduled_spec_tokens:
spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
request.spec_token_ids = []
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id
for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0
)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
skipped_waiting_requests = create_request_queue(self.policy)
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
request_id = request.request_id
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
if request.num_preemptions:
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
if (
self.lora_config
and request.lora_request
and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras
)
):
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
if request.num_computed_tokens == 0:
new_computed_blocks, num_new_local_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request)
)
if self.connector is not None:
ext_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens
)
)
if ext_tokens is None:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens
connector_prefix_cache_queries = (
request.num_tokens - num_new_local_computed_tokens
)
connector_prefix_cache_hits = num_external_computed_tokens
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
else:
new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
else:
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
num_new_tokens = threshold
if (
not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget
):
break
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
num_new_tokens,
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if num_new_tokens == 0:
break
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request,
num_new_tokens,
num_new_local_computed_tokens,
num_external_computed_tokens,
)
if num_new_tokens == 0:
break
effective_lookahead_tokens = (
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
)
num_encoder_tokens = 0
if (
self.is_encoder_decoder
and request.has_encoder_inputs
and encoder_inputs_to_schedule
):
num_encoder_tokens = sum(
request.get_num_encoder_embeds(i)
for i in encoder_inputs_to_schedule
)
if not load_kv_async:
enable_chunked = self.scheduler_config.enable_chunked_prefill
tokens_still_to_compute = (
request.num_tokens - num_computed_tokens
)
is_chunked = (
enable_chunked
and tokens_still_to_compute > num_new_tokens
)
if is_chunked:
assert (
request.num_tokens <= self.max_model_len
), "request.num_tokens must not exceed max_model_len"
num_tokens_need_slot = min(
request.num_tokens + effective_lookahead_tokens,
self.max_model_len,
)
blocks_needed = (
self.kv_cache_manager.get_num_blocks_needed_for_tokens(
request.request_id,
num_tokens_need_slot,
new_computed_blocks,
num_encoder_tokens,
)
)
num_free = self.kv_cache_manager.get_num_free_blocks()
if num_free < blocks_needed:
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_new_computed_tokens=num_new_local_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
num_external_computed_tokens=num_external_computed_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
break
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens,
)
if (
self.connector_prefix_cache_stats is not None
and connector_prefix_cache_queries != 0
):
self.connector_prefix_cache_stats.record(
num_tokens=connector_prefix_cache_queries,
num_hits=connector_prefix_cache_hits,
preempted=request.num_preemptions > 0,
)
request = self.waiting.pop_request()
if load_kv_async:
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
self.running.append(request)
if self.log_stats:
request.record_event(
EngineCoreEventType.SCHEDULED, scheduled_timestamp
)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(
request_id
)
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs
) <= len(self.running)
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request_id = self.running[0].request_id
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
)
if self.use_v2_model_runner:
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
scheduled_resumed_reqs = []
new_reqs_data = [
NewRequestData.from_request(
req,
req_to_new_blocks[req.request_id].get_block_ids(),
req._all_token_ids,
)
for req in scheduled_new_reqs
]
else:
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids()
)
for req in scheduled_new_reqs
]
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
scheduled_resumed_reqs=[r.request_id for r in scheduled_resumed_reqs],
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id for req in preempted_reqs},
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
)
if self.connector is not None:
meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output
)
scheduler_output.kv_connector_metadata = meta
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output
)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
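# NOTE: The allocate-or-preempt loop above (allocate_slots -> on failure,
# evict the worst-priority running request and retry) reduces to the
# standalone toy below. This is an illustrative sketch, not vLLM code;
# Req and ToyPool are hypothetical stand-ins for Request and the KV cache
# manager.
from dataclasses import dataclass

@dataclass
class Req:
    name: str
    priority: int  # larger value = worse priority, evicted first
    arrival_time: float
    blocks_needed: int

class ToyPool:
    def __init__(self, num_free_blocks: int):
        self.free = num_free_blocks
        self.held: dict[str, int] = {}

    def allocate(self, r: Req) -> list[int] | None:
        if self.free < r.blocks_needed:
            return None  # mirrors allocate_slots() returning None
        self.free -= r.blocks_needed
        self.held[r.name] = r.blocks_needed
        return list(range(r.blocks_needed))

    def release(self, r: Req) -> None:
        self.free += self.held.pop(r.name, 0)

def allocate_with_preemption(request: Req, running: list[Req], pool: ToyPool):
    while True:
        blocks = pool.allocate(request)
        if blocks is not None:
            return blocks
        victim = max(running, key=lambda r: (r.priority, r.arrival_time))
        running.remove(victim)
        pool.release(victim)
        if victim is request:
            return None  # the request itself was preempted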
def _preempt_request(self, request: Request, timestamp: float) -> None:
"""Preempt a request and put it back to the waiting queue.
@@ -1193,7 +1750,6 @@ class Scheduler(SchedulerInterface):
# available. In this case, we can't schedule any token for
# the request in this step.
num_new_tokens = 0
break
# Calculate the number of embeddings to schedule in the current range
@@ -1508,6 +2064,9 @@ class Scheduler(SchedulerInterface):
# outputs this step.
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
if model_runner_output.draft_token_ids is not None:
self.update_draft_token_ids(model_runner_output.draft_token_ids)
return engine_core_outputs

View File

@@ -1,10 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Sequence
from vllm.sampling_params import RepetitionDetectionParams
from vllm.v1.request import Request, RequestStatus
def _has_repeating_pattern(
token_ids: Sequence[int],
pattern_len: int,
repetition_min_count: int,
) -> bool:
"""Check if the tail of token_ids contains a repeating pattern.
Compares the last pattern_len tokens against the preceding
(repetition_min_count - 1) repetitions of the same length.
"""
for n in range(1, pattern_len + 1):
target_token = token_ids[-n]
for m in range(1, repetition_min_count):
if token_ids[-(pattern_len * m + n)] != target_token:
return False
return True
def check_sequence_repetition(
token_ids: Sequence[int],
params: RepetitionDetectionParams,
) -> bool:
"""Check if a sequence of token IDs has a repetition pattern.
Args:
token_ids: List of token IDs
params: Repetition detection parameters.
Returns:
True if a repetition pattern is found, False otherwise.
"""
max_pattern_size = params.max_pattern_size
min_pattern_size = params.min_pattern_size
min_count = params.min_count
if min_pattern_size <= 0:
min_pattern_size = 1
if max_pattern_size <= 0 or min_count < 2 or min_pattern_size > max_pattern_size:
return False
for pattern_len in range(
min_pattern_size,
max_pattern_size + 1,
):
if pattern_len * min_count > len(token_ids):
return False
if _has_repeating_pattern(token_ids, pattern_len, min_count):
return True
return False
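# Quick usage sketch for the two helpers above. The exact
# RepetitionDetectionParams constructor is not shown here, so a
# SimpleNamespace carrying the three attributes the function reads is
# used as a stand-in.
from types import SimpleNamespace

params = SimpleNamespace(min_pattern_size=1, max_pattern_size=3, min_count=3)
# The tail "7 8 9" repeats three times -> detected at pattern_len == 3.
assert check_sequence_repetition([1, 2, 7, 8, 9, 7, 8, 9, 7, 8, 9], params)
# No repeating tail -> not detected.
assert not check_sequence_repetition([1, 2, 3, 4, 5], params)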
def remove_all(lst: list, items_to_remove: set) -> list:
"""Remove all items from a list that are in the items_to_remove set.
@@ -61,4 +115,16 @@ def check_stop(request: Request, max_model_len: int) -> bool:
):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
repetition_detection = sampling_params.repetition_detection
if repetition_detection is not None and (
check_sequence_repetition(
request.output_token_ids,
repetition_detection,
)
):
request.status = RequestStatus.FINISHED_REPETITION
request.stop_reason = "repetition_detected"
return True
return False

View File

@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return block IDs allocated since the last call."""
ids = self.new_block_ids
self.new_block_ids = []
return ids
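# The drain pattern above returns only the block IDs allocated since the
# previous call, so an immediate second call yields []. A minimal
# illustration (hypothetical manager instance `mgr`):
#
#     mgr.new_block_ids.extend([3, 7])
#     assert mgr.take_new_block_ids() == [3, 7]
#     assert mgr.take_new_block_ids() == []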
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product
@@ -136,7 +137,7 @@ class CudagraphDispatcher:
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False
@@ -232,8 +233,9 @@ class CudagraphDispatcher:
num_tokens: int,
uniform_decode: bool = False,
has_lora: bool = False,
disable_full: bool = False,
num_active_loras: int = 0,
valid_modes: AbstractSet[CUDAGraphMode] | None = None,
invalid_modes: AbstractSet[CUDAGraphMode] | None = None,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions (e.g., the batch descriptor and whether piecewise-only is used),
@@ -246,15 +248,29 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
num_active_loras: Number of distinct active LoRA adapters.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
invalid_modes: Set of cudagraph modes to exclude. Subtracted from
valid_modes to compute allowed modes. (e.g., {FULL} for
features like cascade attention not supported by full
cudagraphs). None means no modes are excluded.
"""
allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if invalid_modes:
allowed_modes -= invalid_modes
assert len(allowed_modes) >= 1, (
f"No allowed cudagraph modes: valid_modes={valid_modes}, "
f"invalid_modes={invalid_modes}"
)
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
or allowed_modes <= {CUDAGraphMode.NONE}
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
@@ -281,24 +297,26 @@ class CudagraphDispatcher:
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
return CUDAGraphMode.FULL, batch_desc_to_check
if CUDAGraphMode.FULL in allowed_modes:
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc_to_check
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
if CUDAGraphMode.PIECEWISE in allowed_modes:
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
# finally, just return no cudagraphs and a trivial batch descriptor
assert CUDAGraphMode.NONE in allowed_modes, (
f"No matching cudagraph found and NONE is not in "
f"allowed_modes={allowed_modes}"
)
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
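# Sketch of the valid/invalid mode filtering introduced above, using a
# stand-in enum (the real members live on vLLM's CUDAGraphMode):
import enum

class Mode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def allowed_modes(valid: set[Mode] | None, invalid: set[Mode] | None) -> set[Mode]:
    modes = valid or {Mode.NONE, Mode.PIECEWISE, Mode.FULL}
    if invalid:
        modes = modes - invalid  # new set, so the caller's set is untouched
    assert modes, "no allowed cudagraph modes"
    return modes

# e.g. cascade attention excludes FULL:
assert allowed_modes(None, {Mode.FULL}) == {Mode.NONE, Mode.PIECEWISE}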
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:

View File

@@ -27,12 +27,21 @@ PauseMode = Literal["abort", "wait", "keep"]
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error", "repetition")
EEP_NOTIFICATION_CALL_ID = -1
class EEPNotificationType(enum.Enum):
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, abort, or error.
Reason a request finished - stop, length, abort, error, or repetition.
Int rather than Str for more compact serialization.
@@ -41,6 +50,7 @@ class FinishReason(enum.IntEnum):
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.
repetition - repetitive token pattern detected (hallucination)
"""
@@ -48,6 +58,7 @@ class FinishReason(enum.IntEnum):
LENGTH = 1
ABORT = 2
ERROR = 3
REPETITION = 4
def __str__(self):
return FINISH_REASON_STRINGS[self.value]
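# The new member indexes directly into the extended tuple above, so the
# wire format stays a compact int while the string form is positional:
#
#     str(FinishReason.REPETITION)  # -> "repetition" (value 4)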
@@ -235,6 +246,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int
new_data_parallel_master_ip: str
new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum):

View File

@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -134,6 +135,7 @@ class AsyncLLM(EngineClient):
self.renderer = renderer = renderer_from_config(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.renderer,
self.model_config.io_processor_plugin,
)
@@ -647,7 +649,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
# We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
@@ -691,8 +697,8 @@ class AsyncLLM(EngineClient):
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if logger_manager:
logger_manager.record(
if logger_ref[0]:
logger_ref[0].record(
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
@@ -976,17 +982,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size,
)
return
logger.info(
"Waiting for requests to drain before scaling up to %s engines...",
new_data_parallel_size,
)
await self.wait_for_requests_to_drain(drain_timeout)
logger.info(
"Requests have been drained, proceeding with scale to %s engines",
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
logger.info(
"VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
"waiting for requests to drain before scaling"
)
await self.wait_for_requests_to_drain(drain_timeout)
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
@@ -999,6 +1001,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
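# A minimal sketch of the mutable-reference pattern used by output_handler
# above: the closure captures a one-element list rather than `self`, so the
# logger can be hot-swapped during scaling without rebinding the closure
# and without the closure keeping `self` alive. Names are hypothetical.
class ToyLogger:
    def __init__(self, tag: str):
        self.tag = tag

    def record(self, msg: str) -> None:
        print(f"[{self.tag}] {msg}")

logger_ref = [ToyLogger("old")]

def handler(msg: str) -> None:
    if logger_ref[0]:
        logger_ref[0].record(msg)

handler("a")                       # -> [old] a
logger_ref[0] = ToyLogger("new")   # swap without touching the closure
handler("b")                       # -> [new] b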
@property
def is_running(self) -> bool:

View File

@@ -71,6 +71,9 @@ class DPCoordinator:
)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events)
wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send the READY message here instead of waiting for
# a SCALE_ELASTIC_EP notification from the engine core client,
# because SCALE_ELASTIC_EP is only sent once the new engines
# have finished initialization. The subscription message, on the
# other hand, is sent by each engine during initialization.
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events:
buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"):
@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave
# we note that 0 is the wave number for the new
# engine
engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s engines",
current_count,

View File

@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec
import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput,
EngineCoreOutputs,
EngineCoreRequest,
@@ -72,7 +75,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5
_R = TypeVar("_R") # Return type for collective_rpc
@@ -111,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
@@ -180,13 +185,55 @@ class EngineCore:
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
base_batch_queue_size = self.model_executor.max_concurrent_batches
if envs.VLLM_ENABLE_PP_ILU_OPT:
self.batch_queue_size = envs.VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE
if self.batch_queue_size <= 0:
self.batch_queue_size = base_batch_queue_size * 2
self._use_batch_queue_ilu_opt = True
logger.info(
"PP ILU opt is enabled: batch_queue_size=%d (base=%d)",
self.batch_queue_size,
base_batch_queue_size,
)
else:
self.batch_queue_size = base_batch_queue_size
self._use_batch_queue_ilu_opt = False
self.batch_queue: (
deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
) = None
if self.batch_queue_size > 1:
logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
logger.info(
"Batch queue is enabled with size %d (ilu_opt=%s)",
self.batch_queue_size,
self._use_batch_queue_ilu_opt,
)
self.batch_queue = deque(maxlen=self.batch_queue_size)
if self._use_batch_queue_ilu_opt:
self.engine_core_input_queue: queue.Queue[
tuple[Future[ModelRunnerOutput], SchedulerOutput]
] = queue.Queue(maxsize=self.batch_queue_size)
self.engine_core_output_queue: queue.Queue[
tuple[SchedulerOutput, ModelRunnerOutput, bool]
] = queue.Queue(maxsize=self.batch_queue_size)
self._batch_queue_loop_thread = threading.Thread(
target=self._process_batch_queue_loop,
daemon=True,
)
self._batch_queue_loop_thread.start()
# When PP mix ILU scheduling or PP ILU opt is enabled with a KV
# connector, only NixlConnector is supported.
if vllm_config.kv_transfer_config is not None and (
envs.VLLM_ENABLE_PP_MIX_ILU_SCHEDULING or envs.VLLM_ENABLE_PP_ILU_OPT
):
kv_connector_name = vllm_config.kv_transfer_config.kv_connector
if kv_connector_name != "NixlConnector":
raise ValueError(
"When VLLM_ENABLE_PP_MIX_ILU_SCHEDULING or VLLM_ENABLE_PP_ILU_OPT "
"is enabled with a KV connector, only NixlConnector is supported; "
f"current kv_connector is {kv_connector_name!r}."
)
self.is_ec_producer = (
vllm_config.ec_transfer_config is not None
@@ -209,6 +256,10 @@ class EngineCore:
self.step if self.batch_queue is None else self.step_with_batch_queue
)
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
self.draft_in_model_output = (
self.batch_queue is not None and self.use_spec_decode
)
self.aborts_queue = queue.Queue[list[str]]()
@@ -234,12 +285,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
dp_group = getattr(self, "dp_group", None)
assert dp_group is not None
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
# NOTE(yongji): should already be set
# during _eep_scale_up_before_kv_init
assert self.available_gpu_memory_for_kv_cache > 0
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs
)
@@ -408,12 +457,52 @@ class EngineCore:
# When using async scheduling we can't get draft token ids in advance,
# so we update draft token ids in the worker process and don't
# need to update draft token ids here.
if self.draft_in_model_output:
return
if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
self.scheduler.update_draft_token_ids(draft_token_ids)
def _has_kv_connector_work(self, meta: Any) -> bool:
"""Return True if kv_connector_metadata has any recv/save/send work."""
if meta is None:
return False
for attr in ("reqs_to_recv", "reqs_to_save", "reqs_to_send"):
val = getattr(meta, attr, None)
if val is not None and len(val) > 0:
return True
return False
def _has_meaningful_scheduler_output(
self, scheduler_output: SchedulerOutput
) -> bool:
"""Return False if scheduler_output is effectively empty."""
return not (
len(scheduler_output.scheduled_new_reqs) == 0
and len(scheduler_output.scheduled_cached_reqs.req_ids) == 0
and len(scheduler_output.num_scheduled_tokens) == 0
and scheduler_output.total_num_scheduled_tokens == 0
and len(scheduler_output.scheduled_spec_decode_tokens) == 0
and len(scheduler_output.scheduled_encoder_inputs) == 0
and len(scheduler_output.finished_req_ids) == 0
and (scheduler_output.scheduled_resumed_reqs is None
or len(scheduler_output.scheduled_resumed_reqs) == 0)
and not self._has_kv_connector_work(
scheduler_output.kv_connector_metadata
)
)
def _process_batch_queue_loop(self) -> None:
while True:
future, scheduler_output = self.engine_core_input_queue.get()
with self.log_error_detail(scheduler_output):
model_output = future.result()
self.engine_core_output_queue.put(
(scheduler_output, model_output, False)
)
def step_with_batch_queue(
self,
) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
@@ -434,6 +523,9 @@ class EngineCore:
batch_queue = self.batch_queue
assert batch_queue is not None
if self._use_batch_queue_ilu_opt:
return self.step_with_batch_queue_ilu_opt()
# Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking.
@@ -531,6 +623,96 @@ class EngineCore:
return engine_core_outputs, model_executed
def step_with_batch_queue_ilu_opt(
self,
) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
"""Async batch queue variant using background thread for PP ILU opt.
Uses engine_core_input_queue / engine_core_output_queue with a
background thread (_process_batch_queue_loop) that blocks on
future.result(), so the main thread never blocks on GPU compute.
"""
assert not self.is_ec_producer, (
"ec_producer is not supported in step_with_batch_queue_ilu_opt"
)
assert not self.is_pooling_model, (
"is_pooling_model is not supported in step_with_batch_queue_ilu_opt"
)
assert not self.async_scheduling, (
"async_scheduling is not supported in step_with_batch_queue_ilu_opt"
)
model_executed = False
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
has_meaningful_schedule = self._has_meaningful_scheduler_output(
scheduler_output
)
if (
self.engine_core_input_queue.qsize() <= 1
and not has_meaningful_schedule
):
has_meaningful_schedule = True
if has_meaningful_schedule:
logger.debug(
"[step_with_batch_queue_ilu_opt] scheduler_output: "
"total_num_scheduled_tokens=%s num_scheduled_tokens=%s "
"scheduled_new_reqs=%s scheduled_cached_reqs.req_ids=%s "
"resumed_req_ids=%s finished_req_ids=%s "
"has_meaningful_schedule=%s",
scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens,
[r.req_id for r in scheduler_output.scheduled_new_reqs],
scheduler_output.scheduled_cached_reqs.req_ids,
scheduler_output.scheduled_cached_reqs.resumed_req_ids,
scheduler_output.finished_req_ids,
has_meaningful_schedule,
)
if has_meaningful_schedule:
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = (
scheduler_output.total_num_scheduled_tokens > 0
)
if not model_executed:
future = cast(Future[ModelRunnerOutput], exec_future)
else:
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
if self.engine_core_input_queue.full():
scheduler_output_out, model_output_out, model_executed_out = (
self.engine_core_output_queue.get()
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output_out, model_output_out
)
self.engine_core_input_queue.put(
(future, scheduler_output)
)
return engine_core_outputs, model_executed_out
self.engine_core_input_queue.put((future, scheduler_output))
try:
scheduler_output, model_output, model_executed = (
self.engine_core_output_queue.get_nowait()
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
return engine_core_outputs, model_executed
except queue.Empty:
return None, False
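# The input/output queue pipeline above reduces to this standalone
# producer/consumer sketch: a daemon thread blocks on Future.result(), so
# the scheduling thread never stalls on compute. Names are illustrative
# only; the thread pool stands in for the model executor.
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor

in_q: "queue.Queue[tuple[Future, int]]" = queue.Queue(maxsize=4)
out_q: "queue.Queue[tuple[int, int]]" = queue.Queue(maxsize=4)

def drain_loop() -> None:
    while True:
        fut, step = in_q.get()
        out_q.put((step, fut.result()))  # the only place that blocks on compute

threading.Thread(target=drain_loop, daemon=True).start()

with ThreadPoolExecutor(max_workers=2) as pool:
    for step in range(3):
        in_q.put((pool.submit(lambda s=step: s * s), step))  # "execute_model"
    for _ in range(3):
        print(out_q.get())  # "update_from_output" consumes completed steps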
def _process_aborts_queue(self):
if not self.aborts_queue.empty():
request_ids = []
@@ -753,11 +935,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req)
return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
raise NotImplementedError
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init")
def __init__(
@@ -808,6 +1001,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config)
super().__init__(
@@ -1120,8 +1320,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
block = self.process_input_queue_block
try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited:
logger.debug("EngineCore loop active.")
@@ -1291,6 +1497,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore the READY message that the DP coordinator
# sends to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
@@ -1489,6 +1700,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0
self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(
@@ -1512,7 +1727,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self):
super().shutdown()
@@ -1533,7 +1750,11 @@ class DPEngineCoreProc(EngineCoreProc):
def resume_scheduler(self):
super().resume_scheduler()
if not self.engines_running and self.scheduler.has_unfinished_requests():
if (
self.has_coordinator
and not self.engines_running
and self.scheduler.has_unfinished_requests()
):
# Wake up other DP engines.
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave))
@@ -1575,7 +1796,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core.
if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step()
self._maybe_publish_request_counts()
@@ -1625,54 +1851,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group)
self.shutdown()
from copy import deepcopy
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if reconfig_request.new_data_parallel_rank != -1:
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
# local rank specifies device visibility, it should not be changed
assert (
reconfig_request.new_data_parallel_rank_local
== ReconfigureRankType.KEEP_CURRENT_RANK
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
if reconfig_request.new_data_parallel_rank != -2:
self.dp_rank = parallel_config.data_parallel_rank
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.model_executor.reinitialize_distributed(reconfig_request)
if reconfig_request.new_data_parallel_size > old_dp_size:
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
new_parallel_config = deepcopy(self.vllm_config.parallel_config)
old_dp_size = new_parallel_config.data_parallel_size
new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
else:
logger.info(
"Distributed environment reinitialized for DP rank %s", self.dp_rank
new_parallel_config.data_parallel_rank = (
reconfig_request.new_data_parallel_rank
)
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
new_parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
is_shutdown = (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
)
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
self,
notification_type: EEPNotificationType,
vllm_config: VllmConfig | None = None,
):
"""
Send notifications to EngineCoreClient, which can then forward
the notifications to other engine core processes. It is used for:
1) In scale up: new core engines notify existing core engines
that they are ready;
2) In scale down: core engines being removed notify EngineCoreClient
so that it can release their Ray placement groups;
3) In both scale up and scale down: notify EngineCoreClient that existing
core engines have already switched to the new parallel setup.
"""
if vllm_config is None:
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
else:
dp_rank = vllm_config.parallel_config.data_parallel_rank
notification_data = (notification_type.value, dp_rank)
outputs = EngineCoreOutputs(
utility_output=UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID,
result=UtilityResult(notification_data),
)
)
outputs.engine_index = self.engine_index
if hasattr(self, "output_thread") and self.output_thread.is_alive():
self.output_queue.put_nowait((0, outputs))
else:
encoder = MsgpackEncoder()
with (
zmq.Context() as ctx,
make_zmq_socket(
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
) as socket,
):
socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
self, notification_type: str | EEPNotificationType
):
"""
Handle notification received from EngineCoreClient
(forwarded from new core engines).
"""
assert self.eep_scaling_state is not None
if isinstance(notification_type, str):
notification_type = EEPNotificationType(notification_type)
self.eep_scaling_state.handle_notification(notification_type)
def _eep_scale_up_before_kv_init(self):
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=self.vllm_config.parallel_config,
worker_type="new",
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.process_input_queue_block = False
class EngineCoreActorMixin:

View File

@@ -28,11 +28,12 @@ from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
get_open_port,
get_open_zmq_inproc_path,
make_zmq_socket,
)
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutputs,
EngineCoreRequest,
EngineCoreRequestType,
@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
get_engine_zmq_addresses,
launch_core_engines,
)
from vllm.v1.executor import Executor
@@ -445,6 +447,63 @@ class BackgroundResources:
raise EngineDeadError()
@dataclass
class ElasticScalingCache:
existing_core_engines: list[EngineIdentity]
num_new_core_engines: int
pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
"""
Allocate stateless group ports for elastic EP.
"""
from vllm.utils.network_utils import get_open_ports_list
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
num_ep_groups = max(
1,
new_world_size_across_dp
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
)
num_eplb_groups = num_ep_groups
total_ports_needed = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
all_ports = get_open_ports_list(total_ports_needed)
new_data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
new_stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
new_stateless_dp_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
new_stateless_ep_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
new_stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
parallel_config._stateless_world_group_port_list = (
new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
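# Worked example with hypothetical sizes: world_size = 2 (TP * PP per DP
# rank), new_data_parallel_size = 4, tensor_parallel_size = 2:
#
#   new_world_size_across_dp = 2 * 4 = 8
#   num_world_groups = 1
#   num_dp_groups    = max(1, 8 // 4)       = 2
#   num_ep_groups    = max(1, 8 // (4 * 2)) = 1
#   num_eplb_groups  = num_ep_groups        = 1
#   total_ports_needed = (1 + 2 + 1 + 1) * 3 + 5 = 20
#
# The trailing 5 ports seed the DP master-port list (one is popped as the
# new master port); every group gets a consecutive 3-port triple.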
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
else:
# Engines are managed by this client.
with launch_core_engines(vllm_config, executor_class, log_stats) as (
engine_manager,
coordinator,
addresses = get_engine_zmq_addresses(vllm_config)
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
with launch_core_engines(
vllm_config,
executor_class,
log_stats,
addresses,
):
) as (engine_manager, coordinator, addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
(input_address,) = addresses.inputs
(output_address,) = addresses.outputs
self.stats_update_address = addresses.frontend_stats_publish_address
if coordinator is not None:
assert self.stats_update_address == (
coordinator.get_stats_publish_address()
)
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
@@ -545,8 +609,13 @@ class MPClient(EngineCoreClient):
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
):
raise TimeoutError(
"Timed out waiting for engines to send "
"initial message on input socket."
f"Timed out waiting for engine core processes to "
f"start. This is often caused by slow weight loading "
f"for large models. Waited "
f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
f"timeout, set the environment variable: "
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
)
identity, _ = sync_input_socket.recv_multipart()
identities.remove(identity)
@@ -877,6 +946,10 @@ class AsyncMPClient(MPClient):
output_socket = resources.output_socket
assert output_socket is not None
notification_callback_handler: (
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
async def process_outputs_socket():
try:
while True:
@@ -884,7 +957,26 @@ class AsyncMPClient(MPClient):
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output, utility_results)
if (
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
and notification_callback_handler is not None
):
assert _self_ref is not None
_self = _self_ref()
if not _self:
return
if outputs.utility_output.result is None:
continue
notification_data = outputs.utility_output.result.result
assert isinstance(notification_data, Sequence)
assert len(notification_data) == 2
asyncio.create_task(
notification_callback_handler(_self, notification_data)
)
else:
_process_utility_output(
outputs.utility_output, utility_results
)
continue
if output_handler is not None:
@@ -1081,6 +1173,8 @@ class DPAsyncMPClient(AsyncMPClient):
# Used only by DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
self.eep_scaling_cache: ElasticScalingCache | None = None
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
@@ -1101,12 +1195,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
# slice includes just the cores managed by this client.
count_slice = slice(
self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
)
async def run_engine_stats_update_task():
with (
@@ -1145,6 +1233,29 @@ class DPAsyncMPClient(AsyncMPClient):
):
# Extract new engine count from the decoded message
new_engine_count = decoded[1]
# Update engine_ranks_managed and count_slice
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
assert dp_rank == 0
assert dp_size == new_engine_count
assert not (
parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb
)
num_ranks = dp_size
self.engine_ranks_managed = list(
range(dp_rank, dp_rank + num_ranks)
)
if len(self.lb_engines) < new_engine_count:
self.lb_engines = self.lb_engines + [
[0, 0]
for _ in range(
new_engine_count - len(self.lb_engines)
)
]
else:
self.lb_engines = self.lb_engines[:new_engine_count]
# Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count)
@@ -1178,6 +1289,11 @@ class DPAsyncMPClient(AsyncMPClient):
self.current_wave = wave
self.engines_running = running
if counts is not None:
# Running and waiting counts are global from the
# Coordinator including all EngineCores. Slice to get
# just the cores managed by this client.
ranks = self.engine_ranks_managed
count_slice = slice(ranks[0], ranks[-1] + 1)
sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts
logger.debug(
@@ -1287,6 +1403,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None)
@staticmethod
async def eep_process_engine_core_notification(
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
):
cache = self.eep_scaling_cache
notification_type_str, dp_rank = notification_data
try:
notification_type = EEPNotificationType(notification_type_str)
except ValueError as e:
raise ValueError(
f"Unknown EEP notification type: {notification_type_str}"
) from e
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
from vllm.v1.engine import UtilityResult
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
# all engine cores have completed reconfiguration.
dummy_output = UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
)
_process_utility_output(dummy_output, self.utility_results)
return
assert cache is not None
if notification_type not in cache.pending_notifications:
cache.pending_notifications[notification_type] = set()
if dp_rank in cache.pending_notifications[notification_type]:
raise ValueError(
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
)
cache.pending_notifications[notification_type].add(dp_rank)
if len(cache.pending_notifications[notification_type]) >= abs(
cache.num_new_core_engines
):
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
assert cache.num_new_core_engines < 0
old_dp_size = len(cache.existing_core_engines)
new_dp_size = old_dp_size + cache.num_new_core_engines
self.resources.engine_manager.scale_down_elastic_ep(
old_dp_size, new_dp_size
)
else:
await asyncio.gather(
*[
self._call_utility_async(
"eep_handle_engine_core_notification",
notification_type,
engine=engine,
)
for engine in cache.existing_core_engines
]
)
cache.pending_notifications[notification_type] = set()
if notification_type in [
EEPNotificationType.SHUTDOWN_COMPLETE,
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
]:
self.eep_scaling_cache = None
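A minimal sketch of the aggregation rule above, assuming a hypothetical scale-up from 4 to 6 engines: each dp_rank may report a given notification type once, and the follow-up action fires once abs(num_new_core_engines) distinct ranks have reported:

# Hypothetical: scaling 4 -> 6 engines means num_new_core_engines == 2,
# so the client waits for 2 distinct dp_ranks before acting.
pending: dict[str, set[int]] = {}
num_new_core_engines = 2

def record(notification_type: str, dp_rank: int) -> bool:
    ranks = pending.setdefault(notification_type, set())
    if dp_rank in ranks:
        raise ValueError(f"Duplicate notification from dp_rank {dp_rank}")
    ranks.add(dp_rank)
    return len(ranks) >= abs(num_new_core_engines)

assert record("weights_ready", 4) is False  # first new engine reported
assert record("weights_ready", 5) is True   # quorum reached, act now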
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids or self.resources.engine_dead:
return
@@ -1333,6 +1510,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
cur_data_parallel_size, new_data_parallel_size
)
async def _eep_wait_for_setup_switch_complete(self) -> None:
"""
Wait for core engines to switch to the new setup.
In eep_process_engine_core_notification(), a dummy UtilityOutput with
EEP_NOTIFICATION_CALL_ID is processed when the RECONFIGURE_FINISHED
notification is received from engine 0. We create a future with
that call_id and wait for it to be resolved.
"""
future = asyncio.get_running_loop().create_future()
self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
self._ensure_output_queue_task()
await future
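The call-id/future handshake can be sketched independently of vLLM (all names below are hypothetical): the waiter registers a future under a well-known key, and whoever observes the completion event resolves it:

import asyncio

async def demo() -> None:
    results: dict[str, asyncio.Future] = {}
    CALL_ID = "eep-reconfigure"  # hypothetical well-known key

    async def waiter() -> str:
        fut = asyncio.get_running_loop().create_future()
        results[CALL_ID] = fut
        return await fut  # resolved by whoever sees the completion event

    async def notifier() -> None:
        await asyncio.sleep(0)  # pretend RECONFIGURE_FINISHED just arrived
        results.pop(CALL_ID).set_result("done")

    done, _ = await asyncio.gather(waiter(), notifier())
    assert done == "done"

asyncio.run(demo())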
async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
@@ -1340,38 +1531,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines)
# Phase 1: Send reconfigure messages to all existing engines and wait
# for them to be sent
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = []
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
)
reconfig_futures.append(asyncio.create_task(coro))
logger.info("All reconfigure messages sent, starting engine creation")
# Phase 2: Create new engines now that reconfig messages have been sent
# self.resources.engine_manager is guaranteed to be
# CoreEngineActorManager for RayDPClient
# Phase 2: Create new engines
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep(
self.vllm_config, new_data_parallel_size
parallel_config.eplb_config.num_redundant_experts = 0
start_new_worker_future = asyncio.to_thread(
self.resources.engine_manager.scale_up_elastic_ep,
self.vllm_config,
new_data_parallel_size,
)
wait_future = self._eep_wait_for_setup_switch_complete()
# Phase 3: Wait for new engines to be created
# and reconfig messages to be received
await asyncio.gather(start_new_worker_future, *reconfig_futures)
logger.info("[Elastic EP] Successfully started new engines")
# Create new CoreEngine objects for the new engines
new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine)
# NOTE(yongji): we don't update lb_engines here;
# we let run_engine_stats_update_task update it.
new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket
@@ -1381,16 +1591,21 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
):
raise TimeoutError(
"Timed out waiting for new engines to send initial "
"message on input socket."
f"Timed out waiting for new engine core processes to "
f"start. Waited "
f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
f"timeout, set the environment variable: "
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
)
identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
# Phase 3: Wait for all existing engines to complete reconfiguration
logger.info("Waiting for existing engines to complete reconfiguration")
await asyncio.gather(*reconfig_futures)
# NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup.
await wait_future
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Notify coordinator about scale up through existing
# stats_update_task connection
self._ensure_stats_update_task()
@@ -1399,8 +1614,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
await self.first_req_send_socket.send(scale_up_marker)
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size,
@@ -1413,7 +1626,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines)
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
@@ -1421,8 +1641,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = (
@@ -1433,23 +1658,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
reconfig_futures.append(asyncio.create_task(coro))
for _ in range(new_data_parallel_size, cur_data_parallel_size):
self.core_engines.pop()
# NOTE(yongji): Immediately stop sending requests to the engines being removed.
self.core_engines = self.core_engines[:new_data_parallel_size]
self.lb_engines = self.lb_engines[:new_data_parallel_size]
wait_future = self._eep_wait_for_setup_switch_complete()
await asyncio.gather(*reconfig_futures)
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size)
)
await self.first_req_send_socket.send(scale_down_marker)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# NOTE(yongji): Unlike scaling up, we don't strictly need to wait
# for the setup switch to complete here; this wait may be removed
# in the future.
await wait_future
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size,

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import warnings
from collections.abc import Mapping
from typing import Any, Literal
@@ -114,16 +113,6 @@ class InputProcessor:
supported_tasks: tuple[SupportedTask, ...],
) -> None:
"""Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.17. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(params, SamplingParams):
supported_generation_tasks = [
task for task in supported_tasks if task in GENERATION_TASKS

View File

@@ -92,6 +92,7 @@ class LLMEngine:
self.renderer = renderer = renderer_from_config(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.renderer,
self.model_config.io_processor_plugin,
)

View File

@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else:
ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided"
@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip
node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources
@@ -773,11 +777,50 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg)
def get_engine_zmq_addresses(
vllm_config: VllmConfig,
num_api_servers: int = 1,
) -> EngineZmqAddresses:
"""Allocate ZMQ addresses for engine-client communication."""
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_size = parallel_config.data_parallel_size
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
offline_mode = local_start_index is not None
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size)
)
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
client_local_only = False
return EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
outputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
)
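A sketch of the client_local_only decision under hypothetical configurations (the helper signature is illustrative, not vLLM API):

def client_local_only(offline_mode, local_engines_only, local_count, dp_size,
                      enable_elastic_ep):
    local = offline_mode or local_engines_only or (local_count == dp_size)
    # Elastic EP may later scale across nodes, so never bind local-only.
    return False if enable_elastic_ep else local

assert client_local_only(False, False, local_count=4, dp_size=4,
                         enable_elastic_ep=False) is True   # all engines colocated
assert client_local_only(False, False, local_count=4, dp_size=4,
                         enable_elastic_ep=True) is False   # elastic EP forces TCP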
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
@@ -796,29 +839,8 @@ def launch_core_engines(
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
offline_mode = local_start_index is not None
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size)
)
# Set up input and output addresses.
addresses = EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
outputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
)
# Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port
)

View File

@@ -115,7 +115,15 @@ class Executor(ABC):
underlying workers.
"""
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model")
compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# compile in parallel.
if compilation_times:
self.vllm_config.compilation_config.compilation_time = max(
compilation_times
)
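A minimal sketch of the propagation rule with hypothetical timings: each worker returns its own compile time, and the client records the max because workers compile concurrently, so wall-clock cost is set by the slowest one:

# Hypothetical per-worker compile times returned by collective_rpc().
compilation_times = [12.3, 11.8, 12.9]
if compilation_times:
    compilation_time = max(compilation_times)
assert compilation_time == 12.9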
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
"""

View File

@@ -38,12 +38,14 @@ from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
model_parallel_is_initialized,
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.utils.network_utils import (
get_distributed_init_method,
get_ip,
get_loopback_ip,
get_open_port,
)
@@ -128,11 +130,27 @@ class MultiprocExecutor(Executor):
# For leader node within each dp rank,
# each dp will have its own leader multiproc executor.
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
mq_connect_ip = get_ip()
logger.info(
"DP group leader: node_rank=%d, node_rank_within_dp=%d, "
"master_addr=%s, mq_connect_ip=%s (local), "
"world_size=%d, local_world_size=%d",
self.parallel_config.node_rank,
self.parallel_config.node_rank_within_dp,
self.parallel_config.master_addr,
mq_connect_ip,
self.world_size,
self.local_world_size,
)
mq_kwargs: dict[str, Any] = {}
if envs.VLLM_ENABLE_PP_ILU_OPT:
mq_kwargs["max_chunks"] = 32
self.rpc_broadcast_mq = MessageQueue(
self.world_size,
self.local_world_size,
max_chunk_bytes=max_chunk_bytes,
connect_ip=self.parallel_config.master_addr,
connect_ip=mq_connect_ip,
**mq_kwargs,
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
@@ -567,17 +585,22 @@ class WorkerProc:
)
self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
# Load model
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model()
# Initialize message queues after init_device() since multi-node setups
# (nnodes_within_dp > 1) require distributed groups to be initialized
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
@@ -872,6 +895,13 @@ class WorkerProc:
@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size

View File

@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])

View File

@@ -104,11 +104,23 @@ try:
scheduler_output, intermediate_tensors
)
if self._is_intermediate_tensors(output):
if (
self.worker.model_runner.supports_mm_inputs
and get_pp_group().is_first_rank
):
# Strip mm_features before Ray forwards it to the next PP Stage.
# PP Stage>0 only needs the intermediate tensors,
# not preprocessed multimodal data.
# scheduled_new_reqs is a required field of SchedulerOutput,
# so accessing it directly will raise AttributeError if missing.
for req in scheduler_output.scheduled_new_reqs:
req.mm_features = []
return scheduler_output, grammar_output, output
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
if not get_pp_group().is_last_rank:
if not self._is_last_rank():
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
@@ -128,6 +140,9 @@ try:
def _is_intermediate_tensors(self, output) -> bool:
return isinstance(output, IntermediateTensors)
def _is_last_rank(self) -> bool:
return get_pp_group().is_last_rank
ray_import_err = None
except ImportError as e:
@@ -362,7 +377,40 @@ def initialize_ray_cluster(
runtime_env=parallel_config.ray_runtime_env,
)
else:
ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env)
import os

import torch

# Build the Ray runtime env: forward NCCL/vLLM networking overrides, keep
# Ray's CUDA_VISIBLE_DEVICES management disabled, and propagate all VLLM_*
# environment variables to the Ray workers.
device_count = torch.cuda.device_count()
env_vars: dict[str, str] = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}
nccl_if_name = os.environ.get("NCCL_SOCKET_IFNAME")
vllm_nccl_comm = os.environ.get("VLLM_FORCE_NCCL_COMM")
if nccl_if_name is not None:
    env_vars["NCCL_SOCKET_IFNAME"] = nccl_if_name
if vllm_nccl_comm is not None:
    env_vars["VLLM_FORCE_NCCL_COMM"] = vllm_nccl_comm
env_vars.update({k: v for k, v in os.environ.items() if "VLLM" in k})
runtime_env = {"env_vars": env_vars}
if device_count >= parallel_config.world_size:
    ray.init(
        address=ray_address,
        ignore_reinit_error=True,
        num_gpus=parallel_config.world_size,
        runtime_env=runtime_env,
    )
else:
    ray.init(
        address=ray_address,
        ignore_reinit_error=True,
        runtime_env=runtime_env,
    )
device_str = current_platform.ray_device_key
if not device_str:

View File

@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method
@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()
if not is_eep_new_worker:
self.driver_worker.init_device()
self.driver_worker.load_model()
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
# it's running.
return
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.driver_worker.reinitialize_distributed(reconfig_request)
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()
def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()

View File

@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size
import vllm.envs as envs
logger = init_logger(__name__)
@@ -34,6 +35,25 @@ class KVCacheSpec:
The page size
"""
raise NotImplementedError
@property
def scale_page_size_bytes(self) -> int:
"""
The size of a scale page with `block_size` tokens in bytes.
Returns:
The scale page size
"""
raise NotImplementedError
@property
def v_cache_scale_size_bytes(self) -> int:
"""
The size of the per-layer V-cache scale tensor in bytes.
Returns:
The V-cache scale tensor size
"""
raise NotImplementedError
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
"""
@@ -78,13 +98,27 @@ class AttentionSpec(KVCacheSpec):
@property
def real_page_size_bytes(self) -> int:
return (
2
* self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
)
if envs.VLLM_ATTN_OPT_LEVEL == 1:
    # MLA and i8q/i8k/i8v allocate the same amount of memory.
    return (
        2 * self.block_size * self.num_kv_heads * self.head_size
        * get_dtype_size(torch.int8)
    )
elif envs.VLLM_ATTN_OPT_LEVEL == 2:
    # i8q/i8k/f16v allocates f16 plus int8 memory per element,
    # hence the factor of 3 relative to the int8 element size.
    return (
        3 * self.block_size * self.num_kv_heads * self.head_size
        * get_dtype_size(torch.int8)
    )
return (
    2 * self.block_size * self.num_kv_heads * self.head_size
    * get_dtype_size(self.dtype)
)
@property
def scale_page_size_bytes(self) -> int:
# One float32 scale per token per KV head when the quantized KV cache is enabled.
if envs.VLLM_ATTN_OPT_LEVEL > 0:
return self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
else:
return 0
@property
def v_cache_scale_size_bytes(self) -> int:
return self.head_size * self.num_kv_heads * get_dtype_size(torch.float32)
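To make the size accounting concrete, a worked example with hypothetical dimensions:

# Worked example with hypothetical dims: block_size=16, num_kv_heads=8,
# head_size=128, fp16 KV cache (2 bytes/elem), int8 quantized (1 byte/elem).
block_size, num_kv_heads, head_size = 16, 8, 128

baseline = 2 * block_size * num_kv_heads * head_size * 2   # K+V in fp16
opt1 = 2 * block_size * num_kv_heads * head_size * 1       # K+V in int8
opt2 = 3 * block_size * num_kv_heads * head_size * 1       # int8 K + fp16 V
scale_page = block_size * num_kv_heads * 4                 # one fp32 scale/token/head

assert (baseline, opt1, opt2, scale_page) == (65536, 32768, 49152, 512)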
@dataclass(frozen=True, kw_only=True)
@@ -118,7 +152,7 @@ class FullAttentionSpec(AttentionSpec):
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size * pcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
return cdiv(max_model_len, self.block_size) * (self.page_size_bytes + self.scale_page_size_bytes)
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
@@ -179,12 +213,28 @@ class FullAttentionSpec(AttentionSpec):
@property
def real_page_size_bytes(self) -> int:
return (
self.block_size
* self.num_kv_heads
* (self.head_size + self.head_size_v)
* get_dtype_size(self.dtype)
)
if envs.VLLM_ATTN_OPT_LEVEL == 1:
return (
self.block_size
* self.num_kv_heads
* (self.head_size + self.head_size_v)
* get_dtype_size(torch.int8)
)
elif envs.VLLM_ATTN_OPT_LEVEL == 2:
    # int8 K cache plus full-precision V cache.
    return (
        self.block_size * self.num_kv_heads * self.head_size
        * get_dtype_size(torch.int8)
        + self.block_size * self.num_kv_heads * self.head_size_v
        * get_dtype_size(self.dtype)
    )
else:
return (
self.block_size
* self.num_kv_heads
* (self.head_size + self.head_size_v)
* get_dtype_size(self.dtype)
)
@property
def v_cache_scale_size_bytes(self) -> int:
return self.head_size_v * self.num_kv_heads * get_dtype_size(torch.float32)
@dataclass(frozen=True, kw_only=True)
@@ -198,12 +248,30 @@ class MLAAttentionSpec(FullAttentionSpec):
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
# for details.
return self.block_size * 656
if envs.VLLM_USE_INT8_MLA:
return (
self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(torch.int8)
)
return (
self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
)
@property
def scale_page_size_bytes(self) -> int:
# int8 MLA stores two float32 scales per token (one per cached component).
if envs.VLLM_USE_INT8_MLA:
return (
self.block_size
* self.num_kv_heads * 2
* get_dtype_size(torch.float32)
)
else:
return 0
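A similar worked example for the MLA path, with hypothetical dimensions (the 576 head size is an illustrative assumption; the two-scales-per-token factor comes from the code above):

# Worked example with hypothetical MLA dims: block_size=64, one KV head,
# head_size=576, int8 cache with two fp32 scales per token.
block_size, num_kv_heads, head_size = 64, 1, 576

int8_page = block_size * num_kv_heads * head_size * 1      # get_dtype_size(int8)
scale_page = block_size * num_kv_heads * 2 * 4             # 2 fp32 scales per token

assert (int8_page, scale_page) == (36864, 512)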
@classmethod
def merge(cls, specs: list[Self]) -> Self:
@@ -267,7 +335,7 @@ class SlidingWindowSpec(AttentionSpec):
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
return (cdiv(num_tokens, self.block_size) + 1) * (self.page_size_bytes + self.scale_page_size_bytes)
@dataclass(frozen=True)
@@ -289,7 +357,6 @@ class MambaSpec(KVCacheSpec):
assert self.page_size_padded >= page_size
return self.page_size_padded
return page_size
@property
def scale_page_size_bytes(self) -> int:
return 0
@@ -389,6 +456,9 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
@property
def page_size_bytes(self) -> int:
return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())
@property
def scale_page_size_bytes(self) -> int:
return sum(spec.scale_page_size_bytes for spec in self.kv_cache_specs.values())
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_num_pages = max(
@@ -460,6 +530,7 @@ class KVCacheTensor:
size: int # size of the KV cache tensor in bytes
shared_by: list[str] # layer names that share the same KV cache tensor
size_scale: int = 0 # size of the v_cache_scale tensor in bytes, only used for VLLM_ATTN_OPT_LEVEL == 1
@dataclass
@@ -486,6 +557,7 @@ class KVCacheConfig:
kv_cache_tensors: list[KVCacheTensor]
"""How should model runner initialize the KV cache tensors for each layer"""
kv_cache_groups: list[KVCacheGroupSpec]
"""
The kv cache groups of the model.
For models with only one type of attention, there is only one group that
@@ -493,3 +565,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
@property
def has_mamba_layers(self) -> bool:
return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)
@property
def needs_kv_cache_zeroing(self) -> bool:
return self.has_mamba_layers

View File

@@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers:
assert gpu_shape[0] == 2
split_k_and_v = True
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim
)
assert len(kv_cache_stride_order) == len(gpu_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(gpu_shape)))
if has_layers_dim:
# in the cross layers case, the registered kv cache tensor
# shape matches the physical layout, whereas test_shape
# is the logical layout.
# To match them, we need to permute test_shape
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim
)
assert len(kv_cache_stride_order) == len(gpu_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(gpu_shape)))
# permute test_shape according to stride_order
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
# find block_size (16) dimension index
block_size_idx = test_shape.index(16)
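The sentinel trick above can be sketched in isolation (shapes and stride order are hypothetical): bake a known block size into the logical shape, permute it like the backend would, then locate the block dimension by value:

test_shape = (1024, 16, 8, 128)        # hypothetical (num_blocks, block, heads, head)
stride_order = (0, 2, 1, 3)            # hypothetical backend layout
permuted = tuple(test_shape[i] for i in stride_order)
assert permuted.index(16) == 2         # block dim moved from index 1 to 2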

View File

@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar
import numpy as np
import torch
@@ -120,6 +121,20 @@ class SamplerOutput:
logprobs_tensors: LogprobsTensors | None
T = TypeVar("T")
def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
non_none = [item for item in items if item is not None]
if len(non_none) == 0:
return None
combined = non_none[0]
for item in non_none[1:]:
combined = f(combined, item)
return combined
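Usage sketch for the helper above: None entries are skipped, and the rest are folded left-to-right.

assert _combine_non_none(set.union, [{1}, None, {2, 3}]) == {1, 2, 3}
assert _combine_non_none(set.union, [None, None]) is None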
@dataclass
class KVConnectorOutput:
# [req_ids]
@@ -146,6 +161,43 @@ class KVConnectorOutput:
and not self.invalid_block_ids
)
@classmethod
def merge(cls, *outputs: "KVConnectorOutput"):
assert len(outputs) > 0, "Cannot merge empty outputs"
finished_sending = _combine_non_none(
set.union, [output.finished_sending for output in outputs]
)
finished_recving = _combine_non_none(
set.union, [output.finished_recving for output in outputs]
)
kv_connector_stats = _combine_non_none(
lambda x, y: x.aggregate(y),
[output.kv_connector_stats for output in outputs],
)
kv_cache_events = _combine_non_none(
lambda x, y: x.merge(y),
[output.kv_cache_events for output in outputs],
)
invalid_block_ids = _combine_non_none(
set.union, [output.invalid_block_ids for output in outputs]
)
assert invalid_block_ids is not None
assert all(
output.expected_finished_count == outputs[0].expected_finished_count
for output in outputs
)
expected_finished_count = outputs[0].expected_finished_count
return cls(
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_stats=kv_connector_stats,
kv_cache_events=kv_cache_events,
invalid_block_ids=invalid_block_ids,
expected_finished_count=expected_finished_count,
)
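A hypothetical usage of merge(), assuming the dataclass defaults for the omitted fields:

a = KVConnectorOutput(finished_sending={"req-1"}, invalid_block_ids=set())
b = KVConnectorOutput(finished_sending={"req-2"}, invalid_block_ids={7})
merged = KVConnectorOutput.merge(a, b)
assert merged.finished_sending == {"req-1", "req-2"}
assert merged.invalid_block_ids == {7}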
@dataclass
class ECConnectorOutput:
@@ -153,7 +205,12 @@ class ECConnectorOutput:
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
@dataclass
class DraftTokenIds:
# [num_reqs]
req_ids: list[str]
# num_reqs x num_draft_tokens
draft_token_ids: list[list[int]]
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
@@ -191,6 +248,8 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None
draft_token_ids: DraftTokenIds | None = None
# information related to cudagraph execution
cudagraph_stats: CUDAGraphStat | None = None
@@ -209,13 +268,6 @@ class AsyncModelRunnerOutput(ABC):
pass
@dataclass
class DraftTokenIds:
# [num_reqs]
req_ids: list[str]
# num_reqs x num_draft_tokens
draft_token_ids: list[list[int]]
def make_empty_encoder_model_runner_output(
scheduler_output: "SchedulerOutput",

View File

@@ -320,6 +320,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()
FINISHED_REPETITION = enum.auto()
def __str__(self) -> str:
return self.name
@@ -344,4 +345,5 @@ _FINISHED_REASON_MAP = {
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
RequestStatus.FINISHED_REPETITION: FinishReason.REPETITION,
}

View File

@@ -202,10 +202,11 @@ def build_logitsprocs(
if custom_logitsprocs:
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
logger.warning(
"min_p, logit_bias, and min_tokens parameters won't currently work "
"with speculative decoding enabled."
"min_p and logit_bias parameters won't work with speculative decoding."
)
return LogitsProcessors(
[MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
)
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(

View File

@@ -3,6 +3,7 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar
import numpy as np
import torch
from vllm import SamplingParams
@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor):
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits
def apply_with_spec_decode(
self,
logits: torch.Tensor,
num_draft_tokens: list[int],
) -> torch.Tensor:
"""Spec-decode version of apply().
Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
Example: ``num_draft_tokens = [2, 3, 1]``
→ ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
→ request 0 owns rows 0-1, request 1 rows 2-4, request 2 row 5.
"""
if not self.min_toks:
return logits
num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
entries = [
(req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
if stop_tok_ids
]
if not entries:
return logits
all_rows: list[np.ndarray] = [] # row indices to mask
all_toks: list[np.ndarray] = [] # stop-token ids at those rows
for req_idx, min_tok, current_len, stop_toks in entries:
remaining = min_tok - current_len
# How many leading draft positions still need stop-token masking.
n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
if n_mask > 0:
offset = cumsum[req_idx]
row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
n_stop = len(stop_toks)
all_rows.append(np.repeat(row_indices, n_stop))
all_toks.append(np.tile(stop_toks, n_mask))
if all_rows:
rows_arr = np.concatenate(all_rows)
toks_arr = np.concatenate(all_toks)
# (row_indices, token_indices) for index_put_ to set -inf.
logits_slice = (
torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
)
logits.index_put_(logits_slice, self.neg_inf_tensor)
return logits
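Tracing the docstring example with a hypothetical stop token shows which (row, token) pairs get masked:

import numpy as np

# num_draft_tokens = [2, 3, 1]; request 1 still needs 2 more tokens and has
# a hypothetical stop token 50256.
num_draft = np.array([2, 3, 1], dtype=np.int64)
cumsum = np.concatenate([[0], np.cumsum(num_draft)])      # [0, 2, 5, 6]
req_idx, remaining, stop_toks = 1, 2, [50256]
n_mask = int(min(max(remaining, 0), num_draft[req_idx]))  # mask first 2 rows
rows = np.repeat(np.arange(cumsum[req_idx], cumsum[req_idx] + n_mask), len(stop_toks))
toks = np.tile(stop_toks, n_mask)
assert rows.tolist() == [2, 3] and toks.tolist() == [50256, 50256]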
def process_dict_updates(
req_entries: dict[int, T],

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from itertools import chain
from typing import TYPE_CHECKING
@@ -148,7 +148,7 @@ class BatchUpdateBuilder:
class LogitsProcessors:
"""Encapsulates initialized logitsproc objects."""
def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = []
if logitsprocs:

View File

@@ -10,12 +10,14 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -292,6 +294,12 @@ class RejectionSampler(nn.Module):
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
)
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
if isinstance(processor, MinTokensLogitsProcessor):
logits = processor.apply_with_spec_decode(
logits, metadata.num_draft_tokens
)
return logits
@staticmethod
@@ -385,14 +393,13 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_logits.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size,)](
ops.rejection_greedy_sample_torch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids
@@ -424,7 +431,7 @@ def rejection_sample(
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size,)](
ops.rejection_random_sample_torch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -434,8 +441,6 @@ def rejection_sample(
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
)
return output_token_ids

View File

@@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
from typing import Any, cast
import numpy as np
import torch
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadata,
@@ -38,14 +34,15 @@ from vllm.v1.attention.backends.tree_attn import (
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
create_vllm_config_for_draft_model,
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
extend_all_queries_by_N,
@@ -53,6 +50,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
@@ -68,6 +66,7 @@ class SpecDecodeBaseProposer:
self.vllm_config = vllm_config
assert vllm_config.speculative_config is not None
self.speculative_config = vllm_config.speculative_config
self.draft_vllm_config = create_vllm_config_for_draft_model(vllm_config)
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
@@ -79,6 +78,9 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
@@ -113,21 +115,19 @@ class SpecDecodeBaseProposer:
vllm_config.model_config
)
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.draft_attn_groups: list[AttentionGroup] = []
self.kv_cache_gid: int = -1
self.eagle3_use_aux_hidden_state: bool = (
self._get_eagle3_use_aux_hidden_state_from_config()
)
self.compilation_config = self.vllm_config.compilation_config
self.compilation_config = self.draft_vllm_config.compilation_config
# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.cudagraph_dispatcher = CudagraphDispatcher(self.draft_vllm_config)
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
@@ -353,7 +353,7 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return {name: view for name in self._draft_attn_layer_names}
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
@@ -372,6 +372,23 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
return (
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
)
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Greedy-sample draft tokens from hidden states."""
if self.use_local_argmax_reduction:
@@ -391,6 +408,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
@@ -406,6 +424,21 @@ class SpecDecodeBaseProposer:
)
assert target_hidden_states.shape[-1] == self.hidden_size
(
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
) = self.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=target_token_ids,
@@ -420,109 +453,114 @@ class SpecDecodeBaseProposer:
assert self.runner is not None
if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
)
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
sample_hidden_states = last_hidden_states[token_indices_to_sample]
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = self._greedy_sample(sample_hidden_states)
return draft_token_ids.view(-1, self.num_speculative_tokens)
if self.enable_multi_layers_mtp:
model_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(
per_layer_attn_metadata,
self.draft_vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.enable_multi_layers_mtp:
logits = self.model.compute_logits(
sample_hidden_states, spec_step_idx=spec_step_idx
)
else:
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
# Generate the remaining draft tokens.
draft_token_ids_list.append(draft_token_ids)
if spec_step_idx < self.layer_num - 1:
prev_token_ids = self.input_ids[:num_tokens].clone()
hidden_states = hidden_states[:num_tokens]
next_token_ids = draft_token_ids_list[-1].int()
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=prev_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
# Early exit if all draft tokens are generated in one pass
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
if self.uses_mrope:
positions = self.mrope_positions[:, token_indices_to_sample]
else:
positions = self.positions[token_indices_to_sample]
if self.method in (
"deepseek_mtp",
"ernie_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
):
if self.method == "mtp":
hidden_states = self.hidden_states[token_indices_to_sample]
else:
hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states)
if self.enable_multi_layers_mtp:
raise NotImplementedError(
"Speculative Decoding with multi-layer MTP and tree attention "
"is not supported yet."
)
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
logits=logits,
@@ -534,32 +572,20 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = self._greedy_sample(sample_hidden_states)
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
"decoding with num_speculative_tokens > layer_num: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}"
)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(batch_size)
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
@@ -577,7 +603,7 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
for token_index in range(self.num_speculative_tokens - 1):
for token_index in range(self.num_speculative_tokens - self.layer_num):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@@ -627,7 +653,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
block_size = attn_metadata_builder.kv_cache_spec.block_size
# Use the first draft attention group's kv_cache_spec for block_size
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // block_size
@@ -653,11 +680,13 @@ class SpecDecodeBaseProposer:
)
# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
@@ -683,7 +712,7 @@ class SpecDecodeBaseProposer:
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
self.draft_vllm_config,
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -819,18 +848,17 @@ class SpecDecodeBaseProposer:
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
builder = (
self._get_attention_metadata_builder()
if self.attn_metadata_builder is None
else self.attn_metadata_builder
)
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert len(self.draft_attn_groups) > 0
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=builder.kv_cache_spec.block_size,
block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
@@ -880,6 +908,69 @@ class SpecDecodeBaseProposer:
next_token_ids, dtype=torch.int32, device=self.input_ids.device
)
return next_token_ids
def eagle_prepare_next_token_padded(
self,
sampled_token_ids, # [num_reqs, num_sampled_tokens_per_req]
discard_request_mask, # [num_reqs]
backup_next_token_ids, # [num_reqs]
vocab_size,
):
"""
PyTorch implementation of eagle_prepare_next_token_padded kernel.
Args:
sampled_token_ids: Tensor of shape [num_reqs, num_sampled_tokens_per_req]
containing sampled token ids (-1 for rejected tokens)
discard_request_mask: Boolean tensor of shape [num_reqs] indicating
which requests should be discarded
backup_next_token_ids: Tensor of shape [num_reqs] containing backup
token ids for when no valid tokens are found
vocab_size: Vocabulary size for validity checking
Returns:
next_token_ids: Tensor of shape [num_reqs] containing the next token
to sample (last accepted token or backup)
valid_sampled_tokens_count: Tensor of shape [num_reqs] containing the
number of valid (1 + accepted) tokens
"""
num_reqs = sampled_token_ids.shape[0]
num_sampled_tokens_per_req = sampled_token_ids.shape[1]
# Initialize output tensors
next_token_ids = torch.empty(
    num_reqs, dtype=sampled_token_ids.dtype, device=sampled_token_ids.device
)
valid_sampled_tokens_count = torch.zeros(
    num_reqs, dtype=torch.int32, device=sampled_token_ids.device
)
# Process each request
for req_idx in range(num_reqs):
if discard_request_mask[req_idx]:
# Discarded request: use backup token and valid_count=0
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
valid_sampled_tokens_count[req_idx] = 0
else:
# Get sampled tokens for this request
tokens = sampled_token_ids[req_idx]
# Find valid tokens (not -1 and within vocabulary range)
is_valid = (tokens != -1) & (tokens < vocab_size)
valid_count = is_valid.sum().item()
if valid_count > 0:
# Find the last valid token index
# Get indices where is_valid is True
valid_indices = torch.where(is_valid)[0]
last_valid_idx = valid_indices[-1].item()
# Get the token at that index
last_valid_token = tokens[last_valid_idx]
next_token_ids[req_idx] = last_valid_token
else:
# No valid tokens, use backup token
next_token_ids[req_idx] = backup_next_token_ids[req_idx]
valid_sampled_tokens_count[req_idx] = valid_count
return next_token_ids, valid_sampled_tokens_count
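A usage sketch with hypothetical tensors; `proposer` stands in for an initialized SpecDecodeBaseProposer. Request 0 accepts tokens 11 and 12 and rejects the third slot; request 1 is discarded and falls back to its backup token:

import torch

sampled = torch.tensor([[11, 12, -1], [21, -1, -1]])
discard = torch.tensor([False, True])
backup = torch.tensor([99, 77], dtype=torch.int32)
next_ids, valid_counts = proposer.eagle_prepare_next_token_padded(
    sampled, discard, backup, vocab_size=32000
)
assert next_ids.tolist() == [12, 77]    # last accepted token, or backup
assert valid_counts.tolist() == [2, 0]  # discarded request counts as 0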
def prepare_next_token_ids_padded(
self,
@@ -910,31 +1001,15 @@ class SpecDecodeBaseProposer:
self.backup_next_token_ids.copy_to_gpu(num_reqs)
backup_tokens_gpu = self.backup_next_token_ids.gpu
batch_size, num_tokens = sampled_token_ids.shape
device = sampled_token_ids.device
assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
# Kernel grid: one program per request (row)
grid = (batch_size,)
# Find the next power of 2 for block sizes
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
eagle_prepare_next_token_padded_kernel[grid](
next_token_ids, valid_sampled_tokens_count = self.eagle_prepare_next_token_padded(
sampled_token_ids,
discard_request_mask,
backup_tokens_gpu,
next_token_ids,
valid_sampled_tokens_count,
gpu_input_batch.vocab_size,
num_tokens,
batch_size,
sampled_token_ids.stride(0),
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
return next_token_ids, valid_sampled_tokens_count
@@ -974,6 +1049,8 @@ class SpecDecodeBaseProposer:
)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
total_num_tokens = query_start_loc_cpu[-1].item()
@@ -981,7 +1058,9 @@ class SpecDecodeBaseProposer:
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_np=common_attn_metadata.seq_lens_np,
query_start_loc_cpu=query_start_loc_cpu,
key_start_loc=key_start_loc,
_seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
@@ -1014,9 +1093,7 @@ class SpecDecodeBaseProposer:
| list[dict[str, torch.Tensor]]
| None = None,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_groups[0][
0
].get_metadata_builder()
tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
total_num_drafts = self.cu_drafts_per_level[0]
@@ -1092,10 +1169,11 @@ class SpecDecodeBaseProposer:
common_attn_metadata=common_attn_metadata, draft_index=level + 1
)
# Apply new attention metadata to all layers.
# Apply new attention metadata to all draft layers.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for attn_group in self.draft_attn_groups:
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# Consider max model length.
attn_metadata.max_seq_len = min(
@@ -1131,7 +1209,7 @@ class SpecDecodeBaseProposer:
# Run the model.
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
self.draft_vllm_config,
num_tokens=num_input_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
@@ -1209,6 +1287,7 @@ class SpecDecodeBaseProposer:
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
@@ -1261,6 +1340,7 @@ class SpecDecodeBaseProposer:
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
key_start_loc=key_start_loc,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
@@ -1289,7 +1369,7 @@ class SpecDecodeBaseProposer:
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
vllm_config=self.draft_vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)
@@ -1302,43 +1382,17 @@ class SpecDecodeBaseProposer:
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
).keys()
)
self.model = self._get_model()
draft_attn_layer_names = (
get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
- target_attn_layer_names
# Find draft layers (attention layers added by draft model)
all_attn_layers = get_layers_from_vllm_config(
self.draft_vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
self._draft_attn_layer_names = (
set(all_attn_layers.keys()) - target_attn_layer_names
)
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
first_layer = self.indexer_layer_names[0]
self.draft_indexer_metadata_builder = (
indexer_layers[first_layer]
.get_attn_backend()
.get_builder_cls()(
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
self.indexer_layer_names,
self.vllm_config,
self.device,
)
)
else:
self.draft_indexer_metadata_builder = None
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
@@ -1568,25 +1622,17 @@ class SpecDecodeBaseProposer:
self.num_speculative_tokens if not is_graph_capturing else 1
):
if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
if use_cudagraphs:
cudagraph_runtime_mode, batch_desc = (
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
num_input_tokens = batch_desc.num_tokens
else:
cudagraph_runtime_mode = CUDAGraphMode.NONE
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
)
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self.attn_layer_names
self._draft_attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
@@ -1594,7 +1640,7 @@ class SpecDecodeBaseProposer:
with set_forward_context(
None,
self.vllm_config,
self.draft_vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -1616,31 +1662,6 @@ class SpecDecodeBaseProposer:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
Returns:
The metadata builders for EAGLE layers.
Raises:
AssertionError: If no metadata builders are found for EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break
assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers."
)
return builder
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
"""
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
@@ -1673,35 +1694,114 @@ class SpecDecodeBaseProposer:
set(
[
kv_cache_groups[layer_name]
for layer_name in self.attn_layer_names
for layer_name in self._draft_attn_layer_names
]
)
)
== 1
), "All drafting layers should belong to the same kv cache group"
def _pad_batch_across_dp(
def initialize_attn_backend(
self,
num_tokens_unpadded: int,
num_tokens_padded: int,
) -> tuple[int, torch.Tensor]:
# TODO(Flechman): support DBO ubatching
should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
!= CUDAGraphMode.NONE,
num_tokens_padded=num_tokens_padded,
uniform_decode=None,
num_scheduled_tokens_per_request=None,
kv_cache_config: KVCacheConfig,
kernel_block_sizes: list[int] | None = None,
) -> None:
"""
Initialize AttentionGroups for draft layers using kv_cache_config.
Called from the model runner's initialize_metadata_builders.
"""
all_attn_layers = get_layers_from_vllm_config(
self.draft_vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
num_tokens_dp_padded = num_tokens_padded
if num_toks_across_dp is not None:
num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
return num_tokens_dp_padded, num_toks_across_dp
# Find which kv_cache_group the draft layers belong to
self.validate_same_kv_cache_group(kv_cache_config)
kv_cache_spec = None
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
if self._draft_attn_layer_names & set(group.layer_names):
self.kv_cache_gid = gid
kv_cache_spec = group.kv_cache_spec
break
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
if kv_cache_spec is not None:
for layer_name in self._draft_attn_layer_names:
attn_backend = all_attn_layers[layer_name].get_attn_backend()
backend_key = attn_backend.full_cls_name()
if backend_key not in attention_groups:
layer_kv_cache_spec = kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name
]
kernel_block_size = (
kernel_block_sizes[self.kv_cache_gid]
if kernel_block_sizes is not None
and self.kv_cache_gid < len(kernel_block_sizes)
else None
)
attn_group = AttentionGroup(
backend=attn_backend,
layer_names=[layer_name],
kv_cache_spec=layer_kv_cache_spec,
kv_cache_group_id=self.kv_cache_gid,
)
attn_group.create_metadata_builders(
self.draft_vllm_config,
self.device,
kernel_block_size=kernel_block_size,
)
attention_groups[backend_key] = attn_group
else:
attention_groups[backend_key].layer_names.append(layer_name)
self.draft_attn_groups = list(attention_groups.values())
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination is needed when running data-parallel, since all
# DP ranks must agree on padding and the cudagraph mode
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.draft_vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.draft_vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert that the agreed-upon token count is correct;
# otherwise num_tokens_across_dp would no longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
class EagleProposer(SpecDecodeBaseProposer):

View File

@@ -0,0 +1,395 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.num_speculative_tokens == 1
if vllm_config.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self.vllm_config = vllm_config
self.device = device
self.dtype = vllm_config.model_config.dtype
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self.model: nn.Module | None = None
self.attn_layer_names: list[str] = []
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
# Maximum number of tokens for buffers
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
raise ValueError(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self.num_hidden_states = len(layer_ids)
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
dtype=self.dtype,
device=device,
)
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
def propose(
self,
sampled_token_ids: torch.Tensor,
target_hidden_states: list[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
scheduler_output: SchedulerOutput,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
without performing actual attention computation. This allows us to
extract and store hidden states for later use (e.g., KV transfer).
This proposer doesn't actually perform speculation - it returns the
sampled tokens as "draft" tokens, ensuring they always verify (match).
The main purpose is to cache hidden states, not to speculate.
Args:
sampled_token_ids: Sampled token IDs from the target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
Returns:
Tuple of:
- Draft tokens matching sampled tokens, shape [batch_size, 1]
- KV connector output (if KV transfer is active), else None
"""
assert self.model is not None and isinstance(target_hidden_states, list)
# target_hidden_states is a list of tensors (one per layer)
# Each tensor has shape [num_tokens, hidden_size]
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
num_tokens = stacked_hidden_states.shape[0]
# Copy hidden states to buffer
self.hidden_states[:num_tokens] = stacked_hidden_states
assert self.attn_metadata_builder is not None
attn_metadata = self.attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
# We assume all cache-only layers belong to the same KV cache group,
# thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
with (
set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
),
(
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
) as kv_connector_output,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids.unsqueeze(-1), kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, it is copied into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination is needed when running data-parallel, since all
# DP ranks must agree on padding and the cudagraph mode
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, (
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert that the agreed-upon token count is correct;
# otherwise num_tokens_across_dp would no longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys.
Only supports PIECEWISE cudagraphs (via mixed_mode).
Should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
assert self.vllm_config.speculative_config is not None
if (
not self.vllm_config.speculative_config.enforce_eager
and cudagraph_mode.mixed_mode()
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
):
proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
proposer_cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
assert self.model is not None, "Model must be initialized before dummy_run"
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Use our own slot mapping buffer during cudagraph capture.
if (
self.attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
def _build_attn_metadata_builder(
self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
"""Build the attention metadata builder from draft attention layers."""
if not draft_attn_layers:
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
layer = next(iter(draft_attn_layers.values()))
attn_backend = layer.get_attn_backend()
return attn_backend.get_builder_cls()(
layer.get_kv_cache_spec(self.vllm_config),
self.attn_layer_names,
self.vllm_config,
self.device,
)
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare next token IDs for speculative decoding.
Since num_speculative_tokens == 1, sampled_token_ids has shape
(batch_size, 1). For each request we either use the sampled token
(if valid and not discarded) or a backup token from the request state.
"""
num_reqs = gpu_input_batch.num_reqs
device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
],
dtype=torch.int32,
device=device,
)
assert discard_request_mask.dtype == torch.bool
# With num_speculative_tokens == 1, there is exactly one token
sampled = sampled_token_ids[:, 0]
is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
valid_sampled_tokens_count = is_valid.to(torch.int32)
use_sampled = is_valid & ~discard_request_mask[:num_reqs]
next_token_ids = torch.where(
use_sampled, sampled.to(torch.int32), backup_tokens_gpu
)
return next_token_ids, valid_sampled_tokens_count
def load_model(self, target_model: nn.Module) -> None:
"""Load the ExtractHiddenStatesModel model.
This method instantiates the ExtractHiddenStatesModel model which is used
to cache hidden states during speculative decoding. The model uses
cache-only attention (no computation, just caching KV states).
Args:
target_model: The target model (passed for compatibility with
EagleProposer interface, but not used here)
"""
# Get the target model's attention layers before loading draft model
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract]
)
assert self.vllm_config.speculative_config is not None
draft_model_config = self.vllm_config.speculative_config.draft_model_config
from vllm.compilation.backends import set_model_tag
with set_model_tag("extract_hidden_states"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=draft_model_config
)
# Identify draft model's attention layers (difference from target)
all_attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
draft_attn_layers = {
name: layer
for name, layer in all_attn_layers.items()
if name not in target_attn_layer_names
}
self.attn_layer_names = list(draft_attn_layers.keys())
assert len(draft_attn_layers) == 1, (
"ExtractHiddenStatesModel should have exactly one "
f"attention layer, found {len(draft_attn_layers)}"
)
self.attn_metadata_builder = self._build_attn_metadata_builder(
draft_attn_layers
)
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert len(self.attn_layer_names) == 1

View File

@@ -64,3 +64,45 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1,), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros(
(1, layer_num), dtype=torch.int32, device=device
)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128
class MultiLayerEagleProposer(EagleProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(vllm_config, device, runner)
self.layer_num: int = getattr(
self.speculative_config.draft_model_config.hf_text_config,
"n_predict", 0
)
self.num_speculative_tokens: int = (
self.speculative_config.num_speculative_tokens
)
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
assert multi_layer_eagle_metadata is not None
if token_indices_to_sample is None:
token_indices_to_sample = (
common_attn_metadata.query_start_loc[1:] - 1
)
MAX_SHIFT = self.layer_num
assert MAX_SHIFT > 0
prev_token_ids = target_token_ids.clone()
prev_positions = target_positions.clone()
prev_hidden_states = target_hidden_states.clone()
slot_mapping = common_attn_metadata.slot_mapping
start_token_indices = common_attn_metadata.query_start_loc[:-1]
end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
pos_for_shift = (
target_positions[0]
if target_positions.dim() == 2
else target_positions
)
start_token_pos = pos_for_shift[start_token_indices]
shift = torch.minimum(
end_token_indices - token_indices_to_sample,
start_token_pos,
)
shift = torch.clamp(shift, min=0)
token_indices_to_sample.add_(shift)
common_attn_metadata.seq_lens.sub_(shift)
cached_lens = multi_layer_eagle_metadata.cached_len
shift = torch.minimum(shift, cached_lens)
_multi_layer_eagle_shift_and_cache(
batch_size=batch_size,
max_shift=MAX_SHIFT,
src_token_ids=target_token_ids,
dst_token_ids=prev_token_ids,
src_positions=target_positions,
dst_positions=prev_positions,
src_hidden_states=target_hidden_states,
dst_hidden_states=prev_hidden_states,
src_slot_mapping=slot_mapping,
dst_slot_mapping=slot_mapping,
start_token_indices=start_token_indices,
end_token_indices=end_token_indices,
token_indices_to_sample=token_indices_to_sample,
shift=shift,
cached_lens=cached_lens,
cached_prev_token_ids=(
multi_layer_eagle_metadata.cached_token_ids
),
cached_prev_positions=(
multi_layer_eagle_metadata.cached_positions
),
cached_prev_hidden_states=(
multi_layer_eagle_metadata.cached_hidden_states
),
cached_slot_mappings=(
multi_layer_eagle_metadata.cached_slot_mappings
),
common_attn_metadata=common_attn_metadata,
)
return (
prev_token_ids,
prev_positions,
prev_hidden_states,
common_attn_metadata,
)
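# Hedged worked example (not part of this commit) of the shift computation
# above, with made-up values. Take one request whose packed window covers
# token indices [4..9] (start=4, end=9), with token_indices_to_sample=7 and
# a starting position of 12:
#   shift = clamp(min(end - sample_idx, start_pos), min=0)
#         = clamp(min(9 - 7, 12), min=0) = 2
#   token_indices_to_sample += shift  ->  9  (the sample index moves to the
#                                             end of the window)
#   seq_lens -= shift                        (the window shrinks by 2)
# The gather kernels then shift the window contents right by 2, refill the
# vacated 2-token prefix from the per-request cache (shift is first clamped
# to cached_len), and store the new tail ending at the sample index back
# into the cache for the next step.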
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
raise NotImplementedError(
"speculative_config.disable_padded_drafter_batch"
" is not currently supported for MultiLayerEagleProposer."
)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
)
if (
self._draft_attn_layer_names
and slot_mappings is not None
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
adjust_input_kwargs = {
"batch_size": 1,
"target_token_ids": self.input_ids[:num_input_tokens],
"target_positions": self._get_positions(num_input_tokens),
"target_hidden_states": self.hidden_states[:num_input_tokens],
"token_indices_to_sample": torch.tensor(
[num_input_tokens - 1],
dtype=torch.int32,
device=self.device,
),
"common_attn_metadata": CommonAttentionMetadata(
query_start_loc=torch.tensor(
[0, num_input_tokens],
dtype=torch.int32,
device=self.device,
),
query_start_loc_cpu=torch.tensor(
[0, num_input_tokens],
dtype=torch.int32,
device="cpu",
),
key_start_loc=torch.tensor(
[0, num_input_tokens],
dtype=torch.int32,
device=self.device,
),
seq_lens=torch.tensor(
[num_input_tokens],
dtype=torch.int32,
device=self.device,
),
seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
num_reqs=1,
num_actual_tokens=num_input_tokens,
max_query_len=self.num_speculative_tokens + 1,
max_seq_len=self.max_model_len,
block_table_tensor=torch.tensor(
[], dtype=torch.int32, device=self.device
),
slot_mapping=self.arange[:num_input_tokens],
logits_indices_padded=None,
num_logits_indices=None,
causal=True,
encoder_seq_lens=None,
),
"multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
layer_num=self.layer_num,
hidden_size=self.hidden_size,
device=self.device,
),
}
self.adjust_input(**adjust_input_kwargs)
for fwd_idx in range(self.layer_num):
with set_forward_context(
None,
self.draft_vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"hidden_states": self.hidden_states[:num_input_tokens],
"inputs_embeds": inputs_embeds,
"spec_step_idx": fwd_idx,
}
self.model(**model_kwargs)
def _multi_layer_eagle_shift_and_cache(
*,
batch_size: int,
max_shift: int,
src_token_ids: torch.Tensor,
dst_token_ids: torch.Tensor,
src_positions: torch.Tensor,
dst_positions: torch.Tensor,
src_hidden_states: torch.Tensor,
dst_hidden_states: torch.Tensor,
src_slot_mapping: torch.Tensor,
dst_slot_mapping: torch.Tensor,
start_token_indices: torch.Tensor,
end_token_indices: torch.Tensor,
token_indices_to_sample: torch.Tensor,
shift: torch.Tensor,
cached_lens: torch.Tensor,
cached_prev_token_ids: torch.Tensor,
cached_prev_positions: torch.Tensor,
cached_prev_hidden_states: torch.Tensor,
cached_slot_mappings: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
):
if batch_size == 0:
return
assert max_shift > 0
assert cached_prev_positions.is_contiguous()
assert cached_prev_token_ids.is_contiguous()
assert cached_prev_hidden_states.is_contiguous()
assert cached_slot_mappings.is_contiguous()
assert src_hidden_states.is_contiguous()
assert dst_hidden_states.is_contiguous()
if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
src_slot_mapping = src_slot_mapping.clone()
store_start = torch.maximum(
start_token_indices,
(token_indices_to_sample + 1 - max_shift),
)
store_lens = torch.clamp(
token_indices_to_sample - store_start + 1,
min=0,
max=max_shift,
)
max_window_len = int(
(
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
.max()
.item()
)
num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_token_ids,
dst_token_ids,
cached_prev_token_ids,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_slot_mapping,
dst_slot_mapping,
cached_slot_mappings,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_positions,
dst_positions,
cached_prev_positions,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
hidden_size = int(dst_hidden_states.shape[1])
num_hidden_blocks = max(
1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
)
_shift_and_gather_hidden_kernel[
(batch_size, num_blocks, num_hidden_blocks)
](
src_hidden_states,
dst_hidden_states,
cached_prev_hidden_states,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
HIDDEN_SIZE=hidden_size,
BLOCK_TOKENS=BLOCK_TOKENS,
BLOCK_HIDDEN=BLOCK_HIDDEN,
num_warps=4,
)
cached_lens.copy_(store_lens)
return
@triton.jit
def _shift_and_gather_cache_1d_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
):
# Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
# slot mappings, ...).
#
# For a single sequence (0-based index i within its window):
# - Prefix (i < shift):
# dst[start + i] = cached[cached_len - shift + i]
# - Body (i >= shift):
# dst[start + i] = src[start + i - shift]
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
offs = base + k
dst_idx = start + offs
window_len = end - start + 1
mask = offs < window_len
base_cached = cached_ptr + pid_seq * MAX_SHIFT
cached_idx = cached_len - shift + offs
cached_mask = offs < shift
val_cached = tl.load(
base_cached + cached_idx, mask=mask & cached_mask, other=0
)
src_idx = start + offs - shift
val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)
val = tl.where(cached_mask, val_cached, val_src)
tl.store(dst_ptr + dst_idx, val, mask=mask)
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
store_mask = m < MAX_SHIFT
dst_idx = store_start + m
val = tl.load(
dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0
)
tl.store(base_cached + m, val, mask=store_mask)
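# Hedged per-sequence reference (not part of this commit) of the kernel
# above in plain PyTorch, useful for testing. Names are local to this
# sketch. The hidden-state kernel below applies the same rule independently
# to every hidden channel.
import torch

def shift_and_gather_1d_ref(src, dst, cached, start, end, shift,
                            cached_len, store_start, store_len, max_shift):
    window_len = end - start + 1
    for i in range(window_len):
        if i < shift:
            # Prefix: refill from the tail of the per-sequence cache.
            dst[start + i] = cached[cached_len - shift + i]
        else:
            # Body: the source window shifted right by `shift`.
            dst[start + i] = src[start + i - shift]
    # Persist dst[store_start : store_start + store_len] as the new cache;
    # unused cache slots are zeroed, mirroring the masked load's `other=0`.
    for m in range(max_shift):
        cached[m] = dst[store_start + m] if m < store_len else 0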
@triton.jit
def _shift_and_gather_hidden_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
BLOCK_HIDDEN: tl.constexpr,
):
# Per-sequence "shift + gather" for hidden states.
# Layout:
# - src_ptr / dst_ptr: [num_tokens, hidden_size]
# - cached_ptr: [batch_size, MAX_SHIFT, hidden_size]
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
pid_hid = tl.program_id(2)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
tok_offs = base + k
dst_tok = start + tok_offs
n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
window_len = end - start + 1
tok_mask = tok_offs < window_len
n_mask = n < HIDDEN_SIZE
mask = tok_mask[:, None] & n_mask[None, :]
base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
cached_tok = cached_len - shift + tok_offs
cached_ptrs = (
base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
)
cached_mask = tok_offs < shift
val_cached = tl.load(
cached_ptrs, mask=mask & cached_mask[:, None], other=0
)
src_tok = start + tok_offs - shift
src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)
val = tl.where(cached_mask[:, None], val_cached, val_src)
tl.store(dst_ptrs, val, mask=mask)
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
m_mask = (m < MAX_SHIFT) & (m < store_len)
store_tok = store_start + m
dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
store_ptrs = (
base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
)
mask = m_mask[:, None] & n_mask[None, :]
val = tl.load(dst_ptrs, mask=mask, other=0)
tl.store(store_ptrs, val, mask=mask)

View File

@@ -157,12 +157,23 @@ def create_vllm_config_for_draft_model(
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the drafter.
The vllm_config is useful when loading the draft model with get_model().
This helper returns the original target config for the common case and only
rewrites rank/parallel info when the drafter is configured to run locally
on the last target PP stage. This keeps runtime behavior unchanged for the
common case while still handling PP rank remapping.
"""
old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
needs_rank_remap = old_spec_config.needs_partial_pp_draft_remap(old.parallel_config)
if not needs_rank_remap:
return old
draft_rank = old_spec_config.resolve_partial_pp_draft_rank(old.parallel_config)
new_parallel_config = replace(
old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
old_spec_config.draft_parallel_config, rank=draft_rank
)
new: VllmConfig = replace(
old,

View File

@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)

View File

@@ -85,7 +85,7 @@ class CPUWorker(Worker):
self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
@@ -118,11 +118,12 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
return self.compilation_config.compilation_time
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]

View File

@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
@@ -46,12 +45,11 @@ def _run_ar(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
tensor[3][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
return int(tensor[3, :].min().item())
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens
When running microbatched or if cudagraph is enabled (synced across ranks),
all ranks will be padded out so that they run with the same number of tokens.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Synchronize cudagraph_mode across ranks first (take min).
# This is needed before DP padding decision since we use the synced
# cudagraph mode to determine whether DP padding is needed.
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# DP padding is needed when cudagraph is enabled (synced across ranks)
# or when ubatching/DBO is active (ubatching requires uniform batch
# sizes across DP ranks currently).
# Use the synced runtime cudagraph mode rather than the compilation config
# so we can avoid padding when cudagraph is not enabled for this step.
should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
should_dp_pad,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp(
num_tokens_unpadded: int,
allow_microbatching: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
DP padding is enabled when the synced cudagraph mode across ranks is not NONE.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
padded up to the max value across all DP ranks when cudagraph is enabled.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
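# Hedged sketch (not part of this commit) of the synchronization protocol:
# every rank writes its own column of a zeroed tensor, so the summing
# all-reduce behaves like an all-gather; consensus is then min() for
# cudagraph_mode and max() for the padded token count. Toy values below
# stand in for a 3-rank deployment.
import torch

dp_size = 3
# Row layout after the all-reduce: 0 = orig tokens, 1 = padded tokens,
# 2 = should_ubatch flag, 3 = cudagraph_mode (0=NONE, 1=PIECEWISE, 2=FULL).
gathered = torch.tensor([
    [17, 64, 5],   # orig_num_tokens_per_ubatch per rank
    [32, 64, 8],   # padded_num_tokens_per_ubatch per rank
    [0, 0, 0],     # no rank requested microbatching
    [1, 1, 0],     # rank 2 runs eagerly this step
], dtype=torch.int32)
synced_cudagraph_mode = int(gathered[3].min())  # -> 0: NONE wins everywhere
should_ubatch = bool(gathered[2].min())         # sketched as unanimous agreement
should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch
if should_dp_pad:
    # All ranks pad to the max padded count so batch shapes match.
    num_tokens_after_padding = gathered[1].max().repeat(dp_size)
# Here should_dp_pad is False, so each rank keeps its own token count.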

View File

@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
return self.model_runner_output
class AsyncPoolingOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
pooler_output: torch.Tensor,
is_valid: torch.Tensor | None,
main_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
self.model_runner_output = model_runner_output
self.pooler_output = pooler_output
self.is_valid = is_valid
self.copy_event = copy_event
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
if self.is_valid is not None:
self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
else:
self.is_valid_cpu = None
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# unbind() returns a tuple; make it a list so invalid entries can be
# replaced with None below.
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu):
if not is_valid:
pooler_output[i] = None
self.model_runner_output.pooler_output = pooler_output
return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()

View File

@@ -119,6 +119,10 @@ class BlockTables:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
# Therefore, this method must return the persistent tensor
# with the same memory address as that used during the model's forward pass,
# rather than allocating a new tensor.
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
@@ -150,7 +154,14 @@ class BlockTables:
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
# This is because the padding logic is complex and kernels may access beyond
# the requested range.
self.slot_mappings.fill_(PAD_SLOT_ID)
# NOTE(woosuk): The output may be used for CUDA graph capture.
# Therefore, this method must return the persistent tensor
# with the same memory address as that used during the model's forward pass,
# rather than allocating a new tensor.
return self.slot_mappings[:, :num_tokens]
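# Hedged standalone sketch (not part of this commit) of why capture-time
# buffers must be persistent: a replayed CUDA graph reads the exact
# addresses recorded at capture, so new inputs must be written into the
# same buffer in place. Toy sizes; requires a CUDA device.
import torch

buf = torch.zeros(8, device="cuda")  # persistent input buffer
out = torch.empty(8, device="cuda")
# Warm-up on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out.copy_(buf * 2)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out.copy_(buf * 2)               # the captured op records buf's address
buf.fill_(3.0)                       # update in place: same address, new data
g.replay()                           # out is now all 6.0
# A freshly allocated tensor would never be observed by the captured graph,
# which is why get_dummy_block_tables/get_dummy_slot_mappings must hand back
# views of the persistent buffers used in the real forward pass.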

View File

@@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
@@ -29,13 +27,11 @@ class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
@@ -88,9 +84,8 @@ class CudaGraphManager:
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
@@ -113,24 +108,23 @@ class CudaGraphManager:
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
assert mrope_positions is not None
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -143,11 +137,7 @@ class CudaGraphManager:
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -164,9 +154,7 @@ class CudaGraphManager:
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
model_inputs=model_inputs,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
@@ -178,9 +166,7 @@ class CudaGraphManager:
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
@@ -206,11 +192,8 @@ class CudaGraphManager:
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model_output = model(**model_inputs)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
@@ -235,9 +218,7 @@ class CudaGraphManager:
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
@@ -256,19 +237,14 @@ class CudaGraphManager:
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model(**model_inputs)
@torch.inference_mode()
def capture(
self,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
@@ -278,9 +254,8 @@ class CudaGraphManager:
device=self.device,
capture_fn=self.capture_graph,
model=model,
model_state=model_state,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
@@ -412,51 +387,36 @@ def capture_graphs(
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
# HACK(woosuk): Special handling for DCP.
if block_tables.cp_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_batch.seq_lens,
num_reqs,
block_tables.cp_size,
block_tables.cp_rank,
block_tables.cp_interleave,
)
input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
attn_metadata = model_state.prepare_attn(
input_batch,
input_block_tables,
slot_mappings,
attn_groups,
kv_cache_config,
)
return attn_metadata, slot_mappings_by_layer

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
@@ -60,20 +59,13 @@ class InputBatch:
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits]
logits_indices: torch.Tensor
@@ -90,14 +82,16 @@ class InputBatch:
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
device = input_buffers.device
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
@@ -123,7 +117,6 @@ class InputBatch:
input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_()
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
@@ -141,12 +134,9 @@ class InputBatch:
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
mrope_positions=None,
inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
@@ -507,6 +497,38 @@ def post_update(
)
@triton.jit
def _post_update_pool_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
query_start_loc_ptr,
):
batch_id = tl.program_id(0)
query_start = tl.load(query_start_loc_ptr + batch_id)
query_end = tl.load(query_start_loc_ptr + batch_id + 1)
query_len = query_end - query_start
req_state_idx = tl.load(idx_mapping_ptr + batch_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
def post_update_pool(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_pool_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
query_start_loc,
)
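# Minimal NumPy reference for what _post_update_pool_kernel computes: each
# batch slot advances its request-state counter by that request's query
# length. Values here are made up for illustration.
import numpy as np

def post_update_pool_reference(idx_mapping, num_computed_tokens, query_start_loc):
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    np.add.at(num_computed_tokens, idx_mapping, query_lens)

idx_mapping = np.array([2, 0], dtype=np.int32)  # batch slot -> request state index
num_computed = np.zeros(4, dtype=np.int32)      # [max_num_reqs]
qsl = np.array([0, 4, 9], dtype=np.int32)       # [num_reqs + 1]
post_update_pool_reference(idx_mapping, num_computed, qsl)
print(num_computed)  # [5 0 4 0]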
@triton.jit
def _expand_idx_mapping_kernel(
idx_mapping_ptr,

View File

@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec
class EncoderCache:
def __init__(self):
# req_id -> MM features
self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
# MM hash -> encoder outputs
self.encoder_outputs: dict[str, torch.Tensor] = {}
def add_request(
self, req_id: str, mm_features: list[MultiModalFeatureSpec]
) -> None:
self.mm_features[req_id] = mm_features
def remove_request(self, req_id: str) -> None:
self.mm_features.pop(req_id, None)
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling
but is no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_outputs.clear()
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_outputs.pop(mm_hash, None)
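# Hypothetical usage sketch for the class above; real mm_features come from
# the scheduler's NewRequestData rather than hand-built lists.
import torch

cache = EncoderCache()
cache.add_request("req_0", mm_features=[])               # track per-request MM features
cache.encoder_outputs["mm_hash_0"] = torch.zeros(16, 8)  # cache embeddings by MM hash
cache.free_encoder_cache("mm_hash_0")                    # scheduler freed this input
cache.remove_request("req_0")                            # request finished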

View File

@@ -4,54 +4,32 @@ import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
model: SupportsMultiModal,
max_num_tokens: int,
hidden_size: int,
encoder_cache: EncoderCache,
dtype: torch.dtype,
device: torch.device,
):
self.model = model
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.encoder_cache = encoder_cache
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling
but is no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]]
@@ -59,7 +37,7 @@ class EncoderRunner:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
@@ -72,25 +50,17 @@ class EncoderRunner:
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
@@ -122,7 +92,7 @@ class EncoderRunner:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
@@ -148,7 +118,7 @@ class EncoderRunner:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
@@ -170,12 +140,11 @@ class EncoderRunner:
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
x = self.model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.
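# Illustrative sketch (assumed shapes) of the merge that embed_input_ids
# performs with is_multimodal: rows flagged in is_mm_embed are replaced by
# the gathered encoder outputs.
import torch

text_embeds = torch.zeros(6, 4)  # embeddings computed from input_ids
mm_embeds = torch.ones(2, 4)     # gathered encoder outputs
is_mm_embed = torch.tensor([False, True, True, False, False, False])
text_embeds[is_mm_embed] = mm_embeds  # splice MM rows in place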

View File

@@ -38,15 +38,16 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec,
init_attn_backend,
@@ -56,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import (
get_cudagraph_and_dp_padding,
make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
@@ -67,6 +65,7 @@ from vllm.v1.worker.gpu.input_batch import (
expand_idx_mapping,
get_num_sampled_and_rejected,
post_update,
post_update_pool,
prepare_pos_seq_lens,
prepare_prefill_inputs,
)
@@ -76,8 +75,9 @@ from vllm.v1.worker.gpu.kv_connector import (
get_kv_connector,
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import init_model_state
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -120,34 +120,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype
]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
if self.supports_mm_inputs:
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_states = MRopeState(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -169,6 +146,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
self.encoder_cache = None
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
@@ -212,7 +198,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
self.vllm_config,
self.uses_mrope,
self.use_aux_hidden_state_outputs,
self.device,
)
@@ -227,13 +212,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
@staticmethod
def get_supported_tasks() -> tuple[str]:
return ("generate",)
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.append("generate")
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
return tuple(tasks)
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
@@ -266,7 +262,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.model)
if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator)
prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model.
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
self.pooling_runner = PoolingRunner(self.model)
def get_model(self) -> nn.Module:
return self.model
@@ -305,6 +308,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
# HACK(woosuk)
self.speculator.set_attn(
self.model_state,
self.kv_cache_config,
self.attn_groups,
self.block_tables,
@@ -320,39 +324,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
max_query_len=input_batch.num_scheduled_tokens.max().item(),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
)
input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer
@torch.inference_mode()
def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
self,
num_tokens: int,
*args,
skip_attn: bool = True,
uniform_decode: bool = False,
**kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph
# compatibility across DP ranks.
query_len = self.cudagraph_manager.uniform_decode_query_len
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
num_tokens = num_reqs * query_len
num_tokens_per_request = [query_len] * num_reqs
else:
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens
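# Worked example of the alignment above (illustrative): with
# uniform_decode_query_len=4, num_tokens=10, and max_num_reqs=64,
# num_reqs = min(cdiv(10, 4), 64) = 3 and num_tokens = 3 * 4 = 12, so every
# dummy request covers exactly query_len tokens, as uniform-decode graphs expect.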
num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
@@ -387,7 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
# Dummy-run the EAGLE speculator's propose() to keep DP/EP ranks in sync.
if self.speculator is not None:
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -416,39 +442,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
expanded_local_pos,
)
@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
assert self.pooling_runner is not None
self.pooling_runner.dummy_pooler_run(hidden_states)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True
)
# Only run sampler on last PP rank (non-last ranks return None).
# Only run sampler/pooler on last PP rank (non-last ranks return None).
if self.is_last_pp_rank:
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
if self.pooling_runner is None:
self._dummy_sampler_run(sample_hidden_states)
else:
self._dummy_pooler_run(hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def reset_mm_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_mm_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_mm_cache()
def reset_encoder_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_encoder_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.
@@ -477,17 +500,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
inputs_embeds = None
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture(
model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=self.block_tables,
attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config,
@@ -522,21 +538,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs:
if self.encoder_cache is not None:
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash)
self.encoder_cache.free_encoder_cache(mm_hash)
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None
assert new_req_data.sampling_params is not None
req_id = new_req_data.req_id
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request(
@@ -547,34 +562,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
req_index = self.req_states.req_id_to_index[req_id]
if self.supports_mm_inputs:
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
# Pre-compute M-RoPE positions for prefill.
if self.uses_mrope:
self.mrope_states.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features,
)
if self.encoder_cache is not None:
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if new_req_data.sampling_params is not None:
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
self.model_state.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
@@ -637,9 +645,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
@@ -648,11 +653,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
@@ -676,6 +678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_local_seq_lens = None
if self.use_dcp:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens(
@@ -686,16 +689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
@@ -711,39 +705,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=dcp_local_seq_lens,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -758,37 +719,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)
@torch.inference_mode()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
return block_tables, slot_mappings
def prepare_dummy_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
return mm_embeds, is_mm_embed
return block_tables, slot_mappings
def sample(
self,
@@ -926,6 +886,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding
)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.lora_state.make_lora_inputs(
@@ -934,35 +896,61 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and self.is_first_pp_rank:
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
input_batch.inputs_embeds = inputs_embeds[
: input_batch.num_tokens_after_padding
]
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
num_reqs, num_tokens_after_padding, self.input_buffers
)
if self.uses_mrope:
input_batch.mrope_positions = self.mrope_states.mrope_positions[
:, :num_tokens_after_padding
]
if not skip_attn_for_dummy_run:
self.prepare_dummy_attn_metadata(input_batch)
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
else:
block_tables = None
slot_mappings = None
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata = None
slot_mappings_by_layer = None
if not (dummy_run and skip_attn_for_dummy_run):
assert slot_mappings is not None
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
block_tables,
slot_mappings,
self.attn_groups,
self.kv_cache_config,
)
inputs_embeds = None
if self.supports_mm_inputs and self.is_first_pp_rank:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
# NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
# to obtain inputs_embeds, because the compiled model expects this input.
inputs_embeds = self.model_state.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs,
input_batch,
self.req_states,
)
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
@@ -979,41 +967,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
positions = input_batch.positions
if self.uses_mrope:
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
if self.is_first_pp_rank:
input_ids = input_batch.input_ids
inputs_embeds = input_batch.inputs_embeds
assert intermediate_tensors is None
else:
input_ids = None
inputs_embeds = None
assert intermediate_tensors is not None
batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None,
)
with set_forward_context(
input_batch.attn_metadata,
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings,
slot_mapping=slot_mappings_by_layer,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -1021,33 +990,44 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
)
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
) # type: ignore
return None
@torch.inference_mode()
def sample_tokens(
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None # type: ignore
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced
@@ -1109,6 +1089,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
num_sampled,
@@ -1117,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
@@ -1127,3 +1110,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.draft_tokens_handler.get_draft_tokens()
@torch.inference_mode()
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None
if not self.is_last_pp_rank:
self.postprocess_pool(input_batch)
return None
assert self.pooling_runner is not None
pooler_output, is_valid = self.pooling_runner.pool(
hidden_states, input_batch, self.req_states
)
self.postprocess_pool(input_batch)
# Build the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
kv_connector_output=kv_connector_output,
)
async_output = AsyncPoolingOutput(
model_runner_output=model_runner_output,
pooler_output=pooler_output,
is_valid=is_valid,
main_stream=self.main_stream,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def postprocess_pool(self, input_batch: InputBatch) -> None:
# Update the number of computed tokens.
post_update_pool(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
input_batch.query_start_loc,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
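# Minimal NumPy reference for the prefill-counter update above: advance by
# the scheduled tokens, then clamp at prefill_len. Values are made up.
import numpy as np

computed_prefill = np.array([2, 5], dtype=np.int32)
computed_prefill[np.array([0, 1])] += np.array([3, 4], dtype=np.int32)
np.minimum(computed_prefill, np.array([4, 16], dtype=np.int32), out=computed_prefill)
print(computed_prefill)  # [4 9]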

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
def init_model_state(
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
return DefaultModelState(vllm_config, model, encoder_cache, device)
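# Hypothetical wiring sketch, mirroring how GPUModelRunner uses the factory;
# all variables here are assumed to exist in the caller's scope.
model_state = init_model_state(vllm_config, model, encoder_cache, device)
model_inputs = {
    "input_ids": input_batch.input_ids,
    "positions": input_batch.positions,
    # prepare_inputs may override the defaults, e.g. with M-RoPE positions.
    **model_state.prepare_inputs(input_batch, req_states),
}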

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class DefaultModelState(ModelState):
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = vllm_config.scheduler_config
self.model = model
self.device = device
self.supports_mm_inputs = encoder_cache is not None
self.max_model_len = self.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dtype = self.model_config.dtype
if self.supports_mm_inputs:
assert encoder_cache is not None
self.encoder_cache = encoder_cache
self.encoder_runner = EncoderRunner(
model=self.model,
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_state = MRopeState(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
if self.uses_mrope:
# Pre-compute M-RoPE positions for prefill.
assert new_req_data.prefill_token_ids is not None
self.mrope_state.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features,
)
def apply_staged_writes(self) -> None:
if self.uses_mrope:
self.mrope_state.apply_staged_writes()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
if mm_kwargs:
# Execute the multimodal encoder.
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
req_states.prefill_len.np[input_batch.idx_mapping_np],
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
input_batch.input_ids, mm_embeds, is_mm_embed
)
return inputs_embeds[: input_batch.num_tokens_after_padding]
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
if not self.uses_mrope:
# Common case (1D positions).
return {}
# Prepare M-RoPE positions.
self.mrope_state.prepare_mrope_positions(
input_batch.idx_mapping,
input_batch.query_start_loc,
req_states.prefill_len.gpu,
req_states.num_computed_tokens.gpu,
)
mrope_positions = self.mrope_state.mrope_positions[
:, : input_batch.num_tokens_after_padding
]
return {"positions": mrope_positions}
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
model_inputs = {}
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
model_inputs["inputs_embeds"] = inputs_embeds
if self.uses_mrope:
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
model_inputs["positions"] = mrope_positions
return model_inputs
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
)
return attn_metadata
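# Illustrative sketch of the override-dict pattern in prepare_inputs: for
# M-RoPE models the default 1D positions are swapped for a [3, num_tokens]
# tensor via the caller's **-merge. Shapes are made up.
import torch

defaults = {"positions": torch.zeros(10, dtype=torch.int64)}     # 1D positions
override = {"positions": torch.zeros(3, 10, dtype=torch.int64)}  # M-RoPE positions
merged = {**defaults, **override}
assert merged["positions"].shape == (3, 10)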

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState(ABC):
@abstractmethod
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
raise NotImplementedError
@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError
@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError
@abstractmethod
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
raise NotImplementedError
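# Hypothetical minimal subclass sketch for a text-only model: every MM and
# M-RoPE hook degenerates to a no-op; attention metadata would be built the
# same way DefaultModelState builds it.
class TextOnlyModelState(ModelState):
    def __init__(self, vllm_config, model, encoder_cache, device):
        self.model, self.device = model, device

    def add_request(self, req_index, new_req_data):
        pass  # nothing to precompute per request

    def apply_staged_writes(self):
        pass

    def get_mm_embeddings(self, scheduled_encoder_inputs, input_batch, req_states):
        raise NotImplementedError("text-only model has no multimodal encoder")

    def prepare_inputs(self, input_batch, req_states):
        return {}  # default input_ids/positions are already correct

    def prepare_dummy_inputs(self, num_reqs, num_tokens):
        return {}

    def prepare_attn(self, input_batch, block_tables, slot_mappings,
                     attn_groups, kv_cache_config):
        # Would delegate to build_attn_metadata as DefaultModelState does.
        raise NotImplementedError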

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState
# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. Support for other pooling tasks and models
# is still to be determined.
class PoolingRunner:
def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model)
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
if not is_pooling_model(self.model):
return []
assert "embed" in self.model.pooler.get_supported_tasks()
return ["embed"]
def pool(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
req_states: RequestState,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO(woosuk): Support different types of pooling tasks.
last_hidden_states = hidden_states[input_batch.logits_indices]
# TODO(woosuk): Make normalization optional.
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
is_valid = input_batch.seq_lens == prompt_len
return last_hidden_states, is_valid
def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
# Warm-up only: run the normalization kernel and discard the result.
F.normalize(hidden_states, p=2, dim=-1)
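# Standalone reference for the "LAST" pooling path above, with made-up shapes.
import torch
import torch.nn.functional as F

hidden_states = torch.randn(10, 8)     # [num_tokens, hidden_size]
logits_indices = torch.tensor([3, 9])  # last token of each request
emb = F.normalize(hidden_states[logits_indices], p=2, dim=-1)
is_valid = torch.tensor([4, 6]) == torch.tensor([4, 8])  # seq_len == prompt_len?
print(emb.shape, is_valid)  # torch.Size([2, 8]) tensor([ True, False])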

View File

@@ -72,7 +72,7 @@ class BadWordsState:
def apply_bad_words(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -84,7 +84,7 @@ class BadWordsState:
apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
@@ -114,17 +114,17 @@ def _bad_words_kernel(
input_ids_ptr,
expanded_local_pos_ptr,
):
logit_idx = tl.program_id(0)
token_idx = tl.program_id(0)
bw_idx = tl.program_id(1)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
if bw_idx >= num_bad_words:
return
pos = tl.load(expanded_local_pos_ptr + logit_idx)
cur_req_first_pos = logit_idx - pos
pos = tl.load(expanded_local_pos_ptr + token_idx)
cur_req_first_pos = token_idx - pos
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx)
@@ -159,7 +159,7 @@ def _bad_words_kernel(
match = match & (expected == actual)
if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))
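# Illustrative sketch of the idx_mapping -> expanded_idx_mapping rename used
# across these kernels: with speculative decoding each request owns several
# logit rows, so the per-request mapping is repeated once per row.
import numpy as np

idx_mapping = np.array([3, 7], dtype=np.int32)     # [num_reqs] request-state slots
logits_per_req = np.array([1, 3], dtype=np.int32)  # 1 + number of draft tokens
expanded_idx_mapping = np.repeat(idx_mapping, logits_per_req)
print(expanded_idx_mapping)  # [3 7 7 7]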
def apply_bad_words(
@@ -175,8 +175,8 @@ def apply_bad_words(
expanded_local_pos: torch.Tensor,
max_num_bad_words: int,
) -> None:
total_num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
num_tokens = logits.shape[0]
_bad_words_kernel[(num_tokens, max_num_bad_words)](
logits,
logits.stride(0),
expanded_idx_mapping,

View File

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
@@ -25,24 +25,24 @@ def _temperature_kernel(
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
_temperature_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + batch_idx * logits_stride + block,
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32.
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
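# Dense-torch reference for the gumbel-max trick the kernel implements:
# argmax over logits/T plus Gumbel(0, 1) noise samples from softmax(logits/T).
import torch

logits = torch.randn(2, 32000)  # [num_tokens, vocab_size]
temperature = 0.8
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
token_ids = torch.argmax(logits / temperature + gumbel, dim=-1)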
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs]
logits: torch.Tensor, # [num_tokens, vocab_size]
expanded_idx_mapping: torch.Tensor, # [num_tokens]
temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
_gumbel_sample_kernel[(num_tokens, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
seed,
pos,
temperature,

View File

@@ -121,7 +121,7 @@ class LogitBiasState:
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
@@ -131,7 +131,7 @@ class LogitBiasState:
apply_logit_bias(
logits,
idx_mapping,
expanded_idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block = tl.arange(0, BLOCK_SIZE)
@@ -186,21 +186,21 @@ def _bias_kernel(
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
)
# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
logits_ptr + token_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)
# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits_ptr + token_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
@@ -214,13 +214,13 @@ def _bias_kernel(
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
logits_ptr + token_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
@@ -237,7 +237,7 @@ def _bias_kernel(
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
_bias_kernel[(num_tokens,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
expanded_idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),

View File

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
@@ -25,7 +25,9 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
@@ -35,21 +37,23 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
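# Reference sketch of the usual min-p rule in log space, assuming
# threshold = max_logit + log(min_p) (the threshold computation sits outside
# this hunk): keep tokens whose probability is at least min_p * max_prob.
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
min_p = 0.3
threshold = logits.max(dim=-1, keepdim=True).values + torch.log(torch.tensor(min_p))
filtered = torch.where(logits < threshold, torch.tensor(float("-inf")), logits)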
def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
_min_p_kernel[(num_tokens,)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,

View File

@@ -82,7 +82,7 @@ class PenaltiesState:
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:
apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(
def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
token_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(
def bincount(
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@ def bincount(
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
prompt_bin_mask[expanded_idx_mapping] = 0
output_bin_counts[expanded_idx_mapping] = 0
num_tokens = expanded_idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
_bincount_kernel[(num_tokens, num_blocks)](
expanded_idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,

View File

@@ -56,7 +56,7 @@ class Sampler:
def __call__(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
@@ -68,7 +68,7 @@ class Sampler:
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
@@ -101,7 +101,7 @@ class Sampler:
def sample(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
@@ -111,12 +111,14 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
self.logit_bias_state.apply_logit_bias(
logits, expanded_idx_mapping, idx_mapping_np, pos
)
# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
@@ -126,27 +128,29 @@ class Sampler:
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_temperature(
logits, expanded_idx_mapping, idx_mapping_np
)
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
logits, expanded_idx_mapping, idx_mapping_np
)
# Sample the next token.
sampled = gumbel_sample(
logits,
idx_mapping,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
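gumbel_sample relies on the Gumbel-max trick: adding i.i.d. Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax is distributed exactly like sampling from softmax(logits / T), and seeding the noise per request makes the draw reproducible. A reference formulation (a sketch of the identity, not vLLM's kernel):

import torch

def gumbel_argmax(logits: torch.Tensor, temperature: float, gen: torch.Generator):
    u = torch.rand(logits.shape, generator=gen)
    gumbel = -torch.log(-torch.log(u))  # Gumbel(0, 1) noise
    return torch.argmax(logits / temperature + gumbel, dim=-1)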

View File

@@ -64,7 +64,7 @@ class SamplingStates:
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
@@ -72,23 +72,23 @@ class SamplingStates:
# No request requires temperature. Skip the kernel launch.
return
apply_temperature(logits, idx_mapping, self.temperature.gpu)
apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)
def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
@@ -96,8 +96,8 @@ class SamplingStates:
if not (do_top_k or do_top_p):
return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:

View File

@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
@@ -54,11 +55,32 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
@@ -76,12 +98,11 @@ class EagleCudaGraphManager:
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -158,6 +179,7 @@ class EagleCudaGraphManager:
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
@@ -173,6 +195,7 @@ class EagleCudaGraphManager:
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,

View File

@@ -16,7 +16,9 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
@@ -44,10 +46,13 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
@@ -77,10 +82,12 @@ class EagleSpeculator:
def set_attn(
self,
model_state: ModelState,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables,
) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups
self.block_tables = block_tables
@@ -120,8 +127,8 @@ class EagleSpeculator:
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
@@ -162,9 +169,10 @@ class EagleSpeculator:
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
@@ -172,6 +180,7 @@ class EagleSpeculator:
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
@@ -182,6 +191,8 @@ class EagleSpeculator:
def propose(
self,
input_batch: InputBatch,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
@@ -198,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -229,9 +243,9 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME
attn_metadata,
slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -277,48 +291,64 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]
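The DP synchronization above has to satisfy two constraints at once: each rank pads its batch up to a CUDA graph bucket, and all ranks must agree on one runtime mode or their collectives deadlock. A sketch of the assumed reconciliation rule (the real logic lives in get_cudagraph_and_dp_padding; this helper is illustrative only):

def resolve_runtime_mode(modes_across_dp: list[int]) -> int:
    # Assumes CUDAGraphMode values are ordered NONE < PIECEWISE < FULL, so
    # taking the minimum across DP ranks downgrades every rank to the most
    # conservative mode that any rank can run.
    return min(modes_across_dp)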

View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
"""Run two execute_model + sample_tokens iterations to JIT compile
triton kernels.
The first iteration simulates a prefill with requests of 2 prompt
tokens each. The second iteration simulates a decode step with all
requests generating 1 token each.
"""
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
num_reqs = min(
model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
)
num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
# SamplingParams exercising all sampling features.
if model_runner.is_pooling_model:
sampling_params = None
pooling_params = PoolingParams()
else:
sampling_params = SamplingParams.for_sampler_warmup()
pooling_params = None
# Step 1: Prefill all requests with 2 prompt tokens each.
new_reqs = [
NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
# Each request uses a distinct block per KV cache group.
block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
prefill_token_ids=prompt_token_ids,
)
for i in range(num_reqs)
]
prefill_output = SchedulerOutput.make_empty()
prefill_output.scheduled_new_reqs = new_reqs
prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
# Disable KV connector for warmup run.
model_runner.kv_connector.set_disabled(True)
model_runner.execute_model(prefill_output)
if not model_runner.is_pooling_model:
# Warm up sampler and perform a decode step for non-pooling models.
grammar_output = None
if model_runner.is_last_pp_rank:
# Build a GrammarOutput to exercise the structured output bitmask
# kernel during the prefill step.
vocab_size = model_runner.model_config.get_vocab_size()
bitmask_width = (vocab_size + 31) // 32
grammar_bitmask = np.full(
(len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
)
grammar_output = GrammarOutput(
structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
)
model_runner.sample_tokens(grammar_output)
# Step 2: Decode all requests with 1 token each.
cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids)
cached_req_data.new_block_ids = [None] * num_reqs
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
cached_req_data.num_output_tokens = [1] * num_reqs
decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data
decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
decode_output.total_num_scheduled_tokens = num_reqs
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
model_runner.execute_model(decode_output)
model_runner.sample_tokens(None)
# Clean up - process finish_req_ids.
cleanup_output = SchedulerOutput.make_empty()
cleanup_output.finished_req_ids = set(req_ids)
model_runner.execute_model(cleanup_output)
model_runner.kv_connector.set_disabled(False)
torch.cuda.synchronize()
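One arithmetic detail worth spelling out: the grammar bitmask packs one bit per vocabulary token into int32 words, so its width is (vocab_size + 31) // 32, and filling it with -1 sets every bit, i.e. marks every token as allowed, which is what a warmup pass wants. A quick check of that packing:

vocab_size = 32000
bitmask_width = (vocab_size + 31) // 32
assert bitmask_width == 1000
word = -1 & 0xFFFFFFFF  # -1 in two's complement: all 32 bits set
assert all((word >> bit) & 1 for bit in range(32))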

View File

@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -211,6 +220,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
@@ -425,6 +474,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@@ -623,6 +679,24 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[[i1, i2]] = self.cached_len[[i2, i1]]
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[
[i2, i1], ...
]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[
[i2, i1], ...
]
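The swaps above use advanced (list) indexing deliberately: it copies its right-hand side before assignment, so the one-line swap is safe, whereas integer element indexing returns views, and a Python tuple swap over such views writes the same value into both slots. A quick illustrative check of the difference:

import torch

a = torch.tensor([1, 2])
a[0], a[1] = a[1], a[0]   # element views alias: a becomes [2, 2]
b = torch.tensor([1, 2])
b[[0, 1]] = b[[1, 0]]     # RHS copied first: b swaps to [2, 1]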
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
@@ -745,6 +819,21 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1

File diff suppressed because it is too large

View File

@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
@@ -32,14 +31,13 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
get_world_group
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
@@ -49,7 +47,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
@@ -61,6 +58,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
from ...model_executor.model_loader import TensorizerLoader
from .gpu.warmup import warmup_kernels
from .utils import request_memory
logger = init_logger(__name__)
@@ -123,6 +122,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
@@ -316,12 +319,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -421,9 +441,10 @@ class Worker(WorkerBase):
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
# tp_rank = get_tp_group().rank_in_group
global_rank = get_world_group().rank_in_group
return {global_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@@ -461,8 +482,16 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
@@ -558,12 +587,15 @@ class Worker(WorkerBase):
logger.debug(msg)
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner)
elif get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
@@ -584,6 +616,8 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
@@ -696,6 +730,12 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if (
self.use_v2_model_runner
and self.model_runner.is_pooling_model
and output is None
):
output = self.model_runner.pool() # type: ignore
if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
):
@@ -744,7 +784,8 @@ class Worker(WorkerBase):
# Create the profiler wrapper only on the first start call
if self.profiler is None:
if self.profiler_config.profiler == "torch":
profiler_type = self.profiler_config.profiler
if profiler_type == "torch":
self.profiler = TorchProfilerWrapper(
self.profiler_config,
worker_name=trace_name,
@@ -754,14 +795,18 @@ class Worker(WorkerBase):
logger.debug(
"Starting torch profiler with trace name: %s", trace_name
)
elif self.profiler_config.profiler == "cuda":
elif profiler_type == "cuda":
self.profiler = CudaProfilerWrapper(self.profiler_config)
logger.debug("Starting CUDA profiler")
self.profiler.start()
else:
# Profiler already initialized. Restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
# Config validation should prevent this code being reached
raise ValueError(
f"Invalid profiler value of {self.profiler_config.profiler}"
)
# If profiler already initialized, restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
if self.profiler is None:
logger.warning("Profiler was not started, nothing to stop.")
@@ -787,227 +832,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state(
self,
path: str,
@@ -1023,12 +847,11 @@ class Worker(WorkerBase):
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
self.model_runner.save_tensorized_model(
def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
TensorizerLoader.save_model(
self.get_model(),
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
def init_weight_transfer_engine(self, init_info: dict) -> None:
@@ -1104,6 +927,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
def init_worker_distributed_environment(
vllm_config: VllmConfig,

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any
import torch
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return mamba_group_ids, mamba_specs[0]
@dataclasses.dataclass
class MambaCopyBuffers:
src_ptrs: CpuGpuBuffer
dst_ptrs: CpuGpuBuffer
sizes: CpuGpuBuffer
offset: int = 0
@classmethod
def create(
cls,
max_num_reqs: int,
kv_cache_config: KVCacheConfig,
copy_funcs: tuple[MambaStateCopyFunc, ...],
make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaCopyBuffers":
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
entries_per_req = sum(
len(kv_cache_config.kv_cache_groups[gid].layer_names)
for gid in mamba_group_ids
) * len(copy_funcs)
n = max_num_reqs * entries_per_req
return cls(
src_ptrs=make_buffer(n, dtype=torch.int64),
dst_ptrs=make_buffer(n, dtype=torch.int64),
sizes=make_buffer(n, dtype=torch.int32),
)
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
copy_bufs: MambaCopyBuffers,
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
) -> None:
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
src_ptrs_np = copy_bufs.src_ptrs.np
dst_ptrs_np = copy_bufs.dst_ptrs.np
sizes_np = copy_bufs.sizes.np
offset = copy_bufs.offset
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
src_ptrs_np[offset] = copy_spec.start_addr
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
sizes_np[offset] = copy_spec.num_elements * state.element_size()
offset += 1
copy_bufs.offset = offset
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
n = copy_bufs.offset
if n == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
batch_memcpy(
copy_bufs.src_ptrs.copy_to_gpu(n),
copy_bufs.dst_ptrs.copy_to_gpu(n),
copy_bufs.sizes.copy_to_gpu(n),
)
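The staging buffers replace the per-step torch.tensor(...) host allocations, which copy from pageable memory and block, with preallocated pinned memory so the host-to-device copies can be issued non-blocking. A minimal sketch of the pinned CPU/GPU buffer pair assumed here (the real CpuGpuBuffer lives in vllm.v1.utils; this stand-in is illustrative):

import torch

class PinnedStagingBuffer:
    def __init__(self, n: int, dtype: torch.dtype, device: str = "cuda"):
        self.cpu = torch.empty(n, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()  # host-side view for cheap scalar writes
        self.gpu = torch.empty(n, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Async copy of the used prefix; safe because the source is pinned.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]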
def preprocess_mamba(
@@ -117,6 +149,7 @@ def preprocess_mamba(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
Copy the mamba state of the previous step to the last
@@ -138,9 +171,7 @@ def preprocess_mamba(
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
@@ -169,9 +200,7 @@ def preprocess_mamba(
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -182,7 +211,7 @@ def preprocess_mamba(
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)
def postprocess_mamba(
@@ -193,6 +222,7 @@ def postprocess_mamba(
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
If a block is converted from a partial block to a full block in this step, copy the
@@ -203,9 +233,7 @@ def postprocess_mamba(
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
@@ -225,9 +253,7 @@ def postprocess_mamba(
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -239,4 +265,4 @@ def postprocess_mamba(
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)

View File

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any
import torch
@@ -12,13 +15,208 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadataBuilder,
MultipleOf,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
UniformTypeKVCacheSpecs,
)
logger = init_logger(__name__)
@triton.jit
def _zero_kv_blocks_kernel(
seg_addrs_ptr,
block_ids_ptr,
n_blocks,
N_SEGS: tl.constexpr,
PAGE_SIZE_EL: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Zero KV cache blocks across all segments in a single launch.
Each segment is a contiguous region within which whole blocks are laid out consecutively. For backends
where blocks are outermost (block_dim=0) there is one segment per
buffer. For backends where K/V is outermost (block_dim=1) there are
two segments per buffer (one for K, one for V).
seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
allowing segments to live in different CUDA allocations.
Programs are mapped as (block_index, seg_index, chunk_index).
"""
pid = tl.program_id(0)
chunks = PAGE_SIZE_EL // BLOCK_SIZE
work_per_block = N_SEGS * chunks
block_index = pid // work_per_block
if block_index >= n_blocks:
return
remainder = pid % work_per_block
seg_index = remainder // chunks
chunk_index = remainder % chunks
block_id = tl.load(block_ids_ptr + block_index)
seg_addr = tl.load(seg_addrs_ptr + seg_index)
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
offset = (
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
)
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
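To make the kernel's flat grid concrete: the launch uses n_blocks * N_SEGS * (PAGE_SIZE_EL // BLOCK_SIZE) programs, and each program decodes its pid into one (block, segment, chunk) triple. A pure-Python restatement of that index math (illustrative only):

def decode_pid(pid: int, n_segs: int, page_size_el: int, block_size: int):
    chunks = page_size_el // block_size
    work_per_block = n_segs * chunks
    block_index = pid // work_per_block
    remainder = pid % work_per_block
    return block_index, remainder // chunks, remainder % chunks

assert decode_pid(0, 2, 4096, 1024) == (0, 0, 0)
assert decode_pid(7, 2, 4096, 1024) == (0, 1, 3)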
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""
def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
def init_meta(
self,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
) -> None:
"""One-time precomputation for zero_block_ids.
Builds absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.
Block IDs from the scheduler reference logical blocks whose size
may differ from the kernel block size (virtual block splitting).
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
Only AttentionSpec layers are processed; Mamba layers are skipped.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None
for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
block_dim = group.backend.get_kv_cache_block_dim(
kernel_bs,
spec.num_kv_heads,
spec.head_size,
cache_dtype_str=cache_dtype,
)
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
continue
seen_ptrs.add(dp)
el = kv.element_size()
cur_bytes = kv.stride(block_dim) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
if page_size_el is None:
page_size_el = cur_page_el
else:
assert page_size_el == cur_page_el, (
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)
block_stride_bytes = cur_bytes
outer_dims = [
d
for d in range(block_dim)
if kv.stride(d) * el > block_stride_bytes
]
outer_strides = [kv.stride(d) * el for d in outer_dims]
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
seg_addrs.append(dp + off_bytes)
if not seg_addrs or page_size_el is None:
self._meta = None
return
blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
self._id_cap = 8192
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
self._meta = (
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
page_size_el,
blk_size,
len(seg_addrs),
)
def zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if not block_ids or self._meta is None:
return
seg_addrs, page_size_el, blk_size, n_segs = self._meta
n_blocks = len(block_ids)
if n_blocks > self._id_cap:
self._id_cap = n_blocks * 2
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(
self._id_cap, dtype=torch.int64, device=self.device
)
assert self._ids_pinned is not None and self._ids_gpu is not None
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
idx = self._ids_gpu[:n_blocks]
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
_zero_kv_blocks_kernel[grid](
seg_addrs,
idx,
n_blocks,
N_SEGS=n_segs,
PAGE_SIZE_EL=page_size_el,
BLOCK_SIZE=blk_size,
)
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
@@ -36,7 +234,7 @@ class AttentionGroup:
self,
vllm_config,
device,
kernel_block_size: int | None,
kernel_block_size: int | None = None,
num_metadata_builders: int = 1,
):
kv_cache_spec_builder = (
@@ -59,6 +257,119 @@ class AttentionGroup:
return self.metadata_builders[ubatch_id]
def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.
If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the largest supported size that divides kv_manager_block_size.
Args:
kv_manager_block_size: Block size of KV cache.
attn_groups: List of attention groups.
Returns:
The selected block size.
Raises:
ValueError: If no valid block size found.
"""
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""Check if the block size is supported by all backends."""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True
backends = [group.backend for group in attn_groups]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly.
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size
# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b is a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i, so we
# would already have returned kv_manager_block_size in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
def prepare_kernel_block_sizes(
kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
) -> list[int]:
"""
Generate a kernel block size matching each KV cache group's block_size.
For attention backends that support virtual block splitting,
use the supported block sizes from the backend.
For other backends (like Mamba), use the same block size (no splitting).
Args:
kv_cache_config: The KV cache configuration.
attn_groups: Attention groups indexed by KV cache group id.
Returns:
List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# pick an arbitrary one to dispatch.
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
if isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual block splitting.
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = select_common_block_size(
kv_manager_block_size, attn_groups[kv_cache_gid]
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_spec, MambaSpec):
# This is Mamba or another non-attention cache; no splitting.
kernel_block_sizes.append(kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
)
return kernel_block_sizes
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
@@ -201,6 +512,55 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def bind_kv_cache_scale(
kv_caches_scale: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches_scale: list[torch.Tensor],
num_attn_module: int | None = 1,
) -> None:
"""
Bind the allocated KV cache scales to both ModelRunner and forward
context so that they can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache scale list (`runner_kv_caches_scale`)
with kv_caches_scale.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache scale in kv_caches_scale.
Args:
kv_caches_scale: The allocated kv cache scales with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches_scale: The kv cache scale list declared by ModelRunner.
"""
# Bind kv cache scales to ModelRunner
assert len(runner_kv_caches_scale) == 0
# Convert the kv_caches_scale dict to a list of tensors in layer_index order.
index2name = defaultdict(list)
for layer_name in kv_caches_scale:
index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is an encoder-decoder model, e.g., BART.
# The cross attention and self attention in the same decoder layer
# have different layer_names but the same layer_index.
if current_platform.is_cuda() or current_platform.is_xpu():
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches_scale.append(kv_caches_scale[layer_name])
# Bind kv cache scales to forward context
for layer_name, kv_cache_scale in kv_caches_scale.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache_scale = [kv_cache_scale]
def is_residual_scattered_for_sp(

View File

@@ -87,8 +87,12 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
def compile_or_warm_up_model(self) -> float:
"""Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time in seconds.
"""
raise NotImplementedError
def check_health(self) -> None:
@@ -213,13 +217,8 @@ class WorkerWrapperBase:
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
# if self.rpc_rank in rank_mapping:
# self.rpc_rank = rank_mapping[self.rpc_rank]
old_rank = self.rpc_rank
if old_rank in rank_mapping:
self.rpc_rank = rank_mapping[old_rank]
if self.global_rank == old_rank:
self.global_rank = rank_mapping[old_rank]
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(
self,

View File

@@ -66,6 +66,23 @@ class WorkspaceManager:
],
)
def unlock(self) -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
"""
self._locked = False
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock()
def unlock_workspace() -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
After scaling operations complete, lock_workspace() should be
called again to prevent unexpected allocations.
"""
current_workspace_manager().unlock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.