Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def get_kv_cache_block_dim(
        cls,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> int:
        """Discover which tensor dim is the block index, since different
        backends lay out dims differently."""
        _S = 1234567
        shape = cls.get_kv_cache_shape(
            _S,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str=cache_dtype_str,
        )
        return shape.index(_S)
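    # Editor's sketch (not part of the diff): get_kv_cache_block_dim finds the
    # block dimension by threading a sentinel num_blocks value through
    # get_kv_cache_shape and locating it in the result. Assuming a backend
    # whose shape is (2, num_blocks, num_kv_heads, block_size, head_size):
    #
    #     shape = (2, 1234567, 8, 64, 128)   # sentinel 1234567 passed as num_blocks
    #     shape.index(1234567)               # -> 1, so dim 1 indexes blocks
    #
    # The sentinel must not collide with block_size, num_kv_heads, or
    # head_size, which is why an improbably large constant is used.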
    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
@@ -301,10 +321,13 @@ class CommonAttentionMetadata:

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor

    """(batch_size + 1,), the start location of each request in query Tensor"""

    key_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in the key/value Tensor (non-cross-attention)"""
    seq_lens: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""
    seq_lens_np: np.array

    num_reqs: int
    """Number of requests"""
@@ -394,7 +417,9 @@ class CommonAttentionMetadata:
        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            key_start_loc=self.key_start_loc[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            seq_lens_np=self.seq_lens_np[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
@@ -811,6 +836,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
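    # Editor's sketch (not part of the diff): concat_and_cache_mla writes the
    # compressed KV latent and the rotary key into one combined cache entry.
    # Assuming typical MLA dimensions (kv_lora_rank=512, rope head dim=64),
    # the shapes involved look roughly like:
    #
    #     kv_c_normed: [num_tokens, 512]       # compressed latent per token
    #     k_pe:        [num_tokens, 1, 64]     # squeezed to [num_tokens, 64]
    #     kv_cache:    [num_blocks, block_size, 576]  # latent + rope concatenated
    #
    # slot_mapping.flatten() gives one cache slot per token, so the op
    # scatters each concatenated row into its block/offset position.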


class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    """Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +903,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype.startswith("fp8")

@@ -15,13 +15,11 @@ logger = init_logger(__name__)
_ROCM_FLASH_ATTN_AVAILABLE = False

if current_platform.is_cuda():
    from vllm._custom_ops import reshape_and_cache_flash
    # from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
    #     flash_attn_varlen_func,
    #     get_scheduler_metadata,
    # )
    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
    from vllm import _custom_ops as ops

    reshape_and_cache_flash = ops.reshape_and_cache_flash
    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache, flash_attn_varlen_int8_func

elif current_platform.is_xpu():
    from vllm import _custom_ops as ops
    from vllm._xpu_ops import xpu_ops
@@ -53,67 +51,93 @@ elif current_platform.is_rocm():
    reshape_and_cache_flash = ops.reshape_and_cache_flash


def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    return 3

def get_flash_attn_version(
    requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
    if current_platform.is_xpu():
        return 2
    if current_platform.is_rocm():
        # ROCm doesn't use vllm_flash_attn; return None to skip fa_version arg
        return None
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason,
            is_fa_version_supported,
        )
    return None
    # try:
    #     from vllm.vllm_flash_attn.flash_attn_interface import (
    #         fa_version_unsupported_reason,
    #         is_fa_version_supported,
    #     )

        device_capability = current_platform.get_device_capability()
    #     device_capability = current_platform.get_device_capability()

        assert device_capability is not None
    #     assert device_capability is not None

        # 1. default version depending on platform
        fa_version = (
            3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
        )
    # # 1. default version depending on platform
    # if device_capability.major == 9 and is_fa_version_supported(3):
    #     # Hopper (SM90): prefer FA3
    #     fa_version = 3
    # elif device_capability.major == 10 and is_fa_version_supported(4):
    #     # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
    #     fa_version = 4
    # else:
    #     # Fallback to FA2
    #     fa_version = 2

        # 2. override if passed by environment or config
        from vllm.config import get_current_vllm_config_or_none
    # # 2. override if passed by environment or config
    # from vllm.config import get_current_vllm_config_or_none

        vllm_config = get_current_vllm_config_or_none()
        if (
            vllm_config is not None
            and vllm_config.attention_config.flash_attn_version is not None
        ):
            fa_version = vllm_config.attention_config.flash_attn_version
    # vllm_config = get_current_vllm_config_or_none()
    # if (
    #     vllm_config is not None
    #     and vllm_config.attention_config.flash_attn_version is not None
    # ):
    #     fa_version = vllm_config.attention_config.flash_attn_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 on Blackwell platform, "
                "defaulting to FA version 2."
            )
            fa_version = 2
    # # 3. fallback for unsupported combinations
    # if device_capability.major >= 10 and fa_version == 3:
    #     logger.warning_once(
    #         "Cannot use FA version 3 on Blackwell platform, "
    #         "defaulting to FA version 4 if supported, otherwise FA2."
    #     )
    #     fa_version = 4 if is_fa_version_supported(4) else 2

        if requires_alibi and fa_version == 3:
            logger.warning_once(
                "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
            )
            fa_version = 2
    # if requires_alibi and fa_version == 3:
    #     logger.warning_once(
    #         "Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
    #     )
    #     fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d is not supported due to %s",
                fa_version,
                fa_version_unsupported_reason(fa_version),
            )
    # if requires_alibi and fa_version == 4:
    #     logger.warning_once(
    #         "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
    #     )
    #     fa_version = 2

        assert is_fa_version_supported(fa_version)
        return fa_version
    except (ImportError, AssertionError):
        return None
    # # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
    # # supported head dimensions.
    # # See: https://github.com/Dao-AILab/flash-attention/issues/1959
    # if (
    #     fa_version == 4
    #     and device_capability.major >= 10
    #     and head_size is not None
    #     and head_size > 128
    # ):
    #     logger.warning_once(
    #         "FA4 on Blackwell does not support head_size=%d due to TMEM "
    #         "capacity limits, defaulting to FA version 2.",
    #         head_size,
    #     )
    #     fa_version = 2

    # if not is_fa_version_supported(fa_version):
    #     logger.error(
    #         "Cannot use FA version %d is not supported due to %s",
    #         fa_version,
    #         fa_version_unsupported_reason(fa_version),
    #     )

    # assert is_fa_version_supported(fa_version)
    # return fa_version
    # except (ImportError, AssertionError):
    #     return None


def flash_attn_supports_fp8() -> bool:
@@ -124,10 +148,7 @@ def flash_attn_supports_fp8() -> bool:


def flash_attn_supports_sinks() -> bool:
    if current_platform.is_xpu():
        return True
    else:
        return get_flash_attn_version() == 3
    return True


def flash_attn_supports_mla():
@@ -142,6 +163,10 @@ def flash_attn_supports_mla():
        return is_fa_version_supported(
            3
        ) and current_platform.is_device_capability_family(90)

        # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
        # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.

    except (ImportError, AssertionError):
        pass
    return False

@@ -4,7 +4,7 @@

import copy
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional, Union, List

import numpy as np
import torch
@@ -23,15 +23,15 @@ from vllm.v1.attention.backends.fa_utils import (
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from ixformer.contrib.vllm_flash_attn import merge_attn_states

if is_flash_attn_varlen_func_available():
    from vllm.v1.attention.backends.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        flash_attn_with_kvcache,
        # get_scheduler_metadata,
        reshape_and_cache_flash,
        flash_attn_varlen_int8_func
    )
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
@@ -50,9 +50,12 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
    get_dcp_local_seq_lens,
    get_kv_cache_layout,
    split_decodes_and_prefills,
    split_decodes_and_prefills
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import _custom_ops as ops
import vllm.envs as envs
import ixformer.inference.functions as ixf_ops

logger = init_logger(__name__)

@@ -63,23 +66,7 @@ class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        if (
            model_config
            and model_config.is_hybrid
            and (
                cache_config.mamba_ssm_cache_dtype == "float32"
                or cache_config.mamba_cache_dtype == "float32"
            )
        ):
            # NOTE(tdoublep): while in principle, FA supports
            # MultipleOf(16), these are the block sizes that do not
            # suffer from the NaN propagation problem described here:
            # https://github.com/Dao-AILab/flash-attention/issues/1974
            return [16, 32, 64]
        return [MultipleOf(16)]
        return [16, 32, 64]

    forward_includes_kv_cache_update: bool = False

@@ -120,7 +107,8 @@ class FlashAttentionBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        # return (2, num_blocks, block_size, num_kv_heads, head_size)
        if envs.VLLM_ATTN_OPT_LEVEL == 2:
            return (3, num_blocks, num_kv_heads, block_size, head_size)
        return (2, num_blocks, num_kv_heads, block_size, head_size)
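        # Editor's note (assumption, inferred from the forward pass later in
        # this diff): opt level 2 stores an int8 key cache plus an fp16 value
        # cache in one allocation, so a third leading slab is reserved and
        # later reinterpreted:
        #
        #     kv_cache = torch.empty(3, num_blocks, num_kv_heads, block_size,
        #                            head_size, dtype=torch.int8)  # hypothetical
        #     i8_key_cache = kv_cache[0]                      # int8 keys
        #     value_cache = kv_cache[1:].view(torch.float16)  # 2 int8 slabs = 1 fp16 slab
        #
        # Two int8 planes provide exactly the bytes of one fp16 plane, which
        # is why dim 0 grows from 2 to 3 when VLLM_ATTN_OPT_LEVEL == 2.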
    @staticmethod
@@ -139,7 +127,7 @@ class FlashAttentionBackend(AttentionBackend):
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
            stride_order = (0, 1, 2, 3, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order
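        # Editor's note (assumption based on the shapes above): this overlay
        # already allocates the cache as (2, num_blocks, num_kv_heads,
        # block_size, head_size), i.e. heads-first, so "HND" needs no
        # permutation and the stride order becomes the identity. Upstream's
        # (0, 1, 3, 2, 4) permuted a (2, num_blocks, block_size, num_kv_heads,
        # head_size) allocation into HND, e.g.
        #
        #     torch.empty(2, nb, bs, h, d).permute(0, 1, 3, 2, 4)
        #
        # which swapped the block_size and num_kv_heads axes; that swap is no
        # longer necessary here.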
@@ -188,24 +176,22 @@ class FlashAttentionBackend(AttentionBackend):
        if has_sink and device_capability < DeviceCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None


@dataclass
class FlashAttentionPrefillMetadata:
    """Prefill Specific Metadata"""

    """ Prefill Specific Metadata """
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_query_len: int


@dataclass
class FlashAttentionDecodeMetadata:
    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    max_query_len: int
    max_decode_seq_len: int

    use_graph: bool

@dataclass
class FlashAttentionMetadata:
@@ -220,11 +206,12 @@ class FlashAttentionMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    key_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
@@ -235,7 +222,6 @@ class FlashAttentionMetadata:
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    cu_prefix_kv_lens: torch.Tensor | None
    cu_suffix_kv_lens: torch.Tensor | None

@@ -247,7 +233,7 @@ class FlashAttentionMetadata:
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0

    prefill: FlashAttentionPrefillMetadata | None = None
    decode: FlashAttentionDecodeMetadata | None = None

@@ -291,7 +277,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        else AttentionCGSupport.UNIFORM_BATCH
    )
    supports_update_block_table: bool = True

    reorder_batch_threshold: ClassVar[int] = 1

    @classmethod
@@ -316,6 +302,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.compilation_config = vllm_config.compilation_config
        self.attention_config = vllm_config.attention_config

        self.decode_use_graph = vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
@@ -325,7 +312,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # self.aot_schedule = get_flash_attn_version() == 3
        self.aot_schedule = False

        try:
@@ -346,6 +332,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
        # Align decode/prefill split threshold with speculative decode query length
        # when backend supports treating spec requests as decode.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        if self.use_full_cuda_graph and self.aot_schedule:
            # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
@@ -388,15 +377,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        key_start_loc = common_attn_metadata.key_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_np = common_attn_metadata.seq_lens_np
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata)
        )

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
            )
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
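        # Editor's sketch (not part of the diff): split_decodes_and_prefills
        # assumes the batch is reordered so decodes come first, and classifies
        # a request as decode when its query length is <= decode_threshold.
        # With threshold 2 (e.g. one speculative token) and query lengths
        # [1, 2, 5, 7]:
        #
        #     num_decodes = 2          # the two requests with q_len <= 2
        #     num_decode_tokens = 3    # 1 + 2
        #     num_prefills = 2
        #     num_prefill_tokens = 12  # 5 + 7
        #
        # which is exactly what the two asserts above check against num_reqs
        # and num_actual_tokens.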
@@ -467,11 +458,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            dcp_context_kv_lens = None

        cu_prefix_query_lens = None
        cu_prefix_kv_lens = None
        cu_suffix_kv_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None
        cu_prefix_kv_lens = None
        cu_suffix_kv_lens = None

        if self.dcp_world_size > 1:
            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
@@ -507,11 +498,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                             dtype=torch.int32,
                                             device=self.device)
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len

            cu_suffix_kv_lens = torch.tensor([0,] + suffix_kv_lens.tolist(),
                                             dtype=torch.int32,
@@ -542,7 +533,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                causal=causal,
            )
        # For FA3 + full cudagraph
        max_num_splits = 0
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
@@ -552,50 +543,59 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        if num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes
            prefill_query_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )
            prefill_key_start_loc = (
                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
            )
            reqs_start = num_decodes  # prefill_start

            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]
            prefill_key_start_loc = key_start_loc[
                reqs_start:] - key_start_loc[reqs_start]
            prefill_metadata = FlashAttentionPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                key_start_loc=prefill_key_start_loc,
                max_query_len=max_query_len,
            )
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                key_start_loc=prefill_key_start_loc,
                max_query_len=max_query_len,
            )
        decode_metadata = None
        if num_decodes > 0:
            reqs_start = num_decodes
            reqs_start = num_decodes  # prefill_start
            decode_query_start_loc = query_start_loc[: reqs_start + 1]
            decode_query_lens = (
                decode_query_start_loc[1:] - decode_query_start_loc[:-1]
            )
            decode_metadata = FlashAttentionDecodeMetadata(
                block_table=block_table_tensor[:reqs_start, ...],
                query_start_loc=decode_query_start_loc,
                seq_lens=seq_lens[:reqs_start],
                max_decode_seq_len=torch.max(seq_lens[:reqs_start]).item(),
                max_query_len=decode_query_lens.max().item(),
                max_decode_seq_len=np.max(seq_lens_np[:reqs_start]).item(),
                use_graph=num_prefills==0 and self.decode_use_graph
            )

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            key_start_loc=key_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
@@ -607,8 +607,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            prefill=prefill_metadata,
            decode=decode_metadata,
            prefill = prefill_metadata,
            decode = decode_metadata,
        )
        return attn_metadata

@@ -621,6 +621,19 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        new_metadata = copy.copy(metadata)
        new_metadata.block_table = blk_table
        new_metadata.slot_mapping = slot_mapping
        # Keep nested prefill/decode block tables in sync. Decode path consumes
        # `attn_metadata.decode.block_table`, so updating only the top-level
        # `block_table` is insufficient when metadata is reused across groups.
        if metadata.decode is not None:
            new_decode = copy.copy(metadata.decode)
            reqs_start = metadata.num_decodes
            new_decode.block_table = blk_table[:reqs_start, ...]
            new_metadata.decode = new_decode
        if metadata.prefill is not None:
            new_prefill = copy.copy(metadata.prefill)
            reqs_start = metadata.num_decodes
            new_prefill.block_table = blk_table[reqs_start:, ...]
            new_metadata.prefill = new_prefill
        return new_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -667,7 +680,15 @@ class FlashAttentionImpl(AttentionImpl):
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        self.vllm_flash_attn_version = get_flash_attn_version(
            requires_alibi=alibi_slopes is not None,
            head_size=head_size,
        )
        logger.info_once(
            "Using FlashAttention version %s",
            self.vllm_flash_attn_version,
            scope="local",
        )
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

@@ -677,6 +698,7 @@ class FlashAttentionImpl(AttentionImpl):
        )

        self.sinks = sinks

        if self.sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
@@ -687,6 +709,28 @@ class FlashAttentionImpl(AttentionImpl):
            )

        self.supports_quant_query_input = True
        self.supports_per_head_quant_scales = (
            self.vllm_flash_attn_version >= 3
            if self.vllm_flash_attn_version is not None
            else False
        )
        assert envs.VLLM_ATTN_OPT_LEVEL in [0, 1, 2], "VLLM_ATTN_OPT_LEVEL only supports [0 for non-quant, 1 for I8Q_I8K_I8V, 2 for I8Q_I8K_F16V], but got {}".format(envs.VLLM_ATTN_OPT_LEVEL)
        '''
        quant_type = 0
            attention: f16 qkv
            cache: f16 kv cache
        quant_type = 1
            attention: int8q int8k int8v
            cache:
                int8 k cache && fp32 k cache scale
                int8 v cache && fp32 v cache scale (loaded from file, don't update)
        quant_type = 2
            attention: int8q int8k fp16v
            cache:
                int8 k cache && fp32 k cache scale
                fp16 v cache
        '''
        self.quant_type = int(envs.VLLM_ATTN_OPT_LEVEL)
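        # Editor's sketch (hypothetical usage, names taken from this diff):
        # the optimization level is an environment switch read once at init:
        #
        #     VLLM_ATTN_OPT_LEVEL=0  -> fp16 q/k/v, fp16 KV cache (baseline)
        #     VLLM_ATTN_OPT_LEVEL=1  -> int8 q/k/v; int8 K cache + fp32 K scales,
        #                               int8 V cache + fp32 V scales from file
        #     VLLM_ATTN_OPT_LEVEL=2  -> int8 q/k, fp16 v; int8 K cache + fp32
        #                               K scales, fp16 V cache
        #
        # so launching with e.g. `VLLM_ATTN_OPT_LEVEL=2 vllm serve ...` selects
        # the mixed int8-key/fp16-value path used in the forward pass below.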
    def forward(
        self,
@@ -698,7 +742,7 @@ class FlashAttentionImpl(AttentionImpl):
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        sqrt_alibi: bool = False,
        kv_cache_scale: torch.Tensor | None = None,
        kv_cache_scale: Union[torch.Tensor, List[torch.Tensor]] | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
@@ -711,6 +755,7 @@ class FlashAttentionImpl(AttentionImpl):
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            kv_cache_scale = [num_blocks, num_kv_heads, block_size] + [num_kv_heads, head_size]
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
@@ -718,9 +763,9 @@ class FlashAttentionImpl(AttentionImpl):
        We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
@@ -729,13 +774,12 @@ class FlashAttentionImpl(AttentionImpl):

        if attn_metadata is None:
            # Profiling run.
            # return output.fill_(0)
            return output.fill_(0).view(-1, self.num_heads * self.head_size)

            return output.view(-1, self.num_heads * self.head_size)

        softmax_scale: float = self.scale
        window_size = self.sliding_window
        alibi_slopes: torch.Tensor | None = self.alibi_slopes
        logits_soft_cap: float | None = self.logits_soft_cap
        alibi_slopes: torch.Tensor = self.alibi_slopes
        logits_soft_cap: float = self.logits_soft_cap

        attn_type = self.attn_type

@@ -761,18 +805,140 @@ class FlashAttentionImpl(AttentionImpl):
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )
            ).view(-1, self.num_heads * self.head_size)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        decode_only = has_decode and not has_prefill
        num_decode_tokens = attn_metadata.num_decode_tokens

        if self.quant_type == 0:
            key_cache, value_cache = kv_cache.unbind(0)
        elif self.quant_type == 1:
            i8_key_cache, i8_value_cache = kv_cache.unbind(0)
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape

            key_scale_cache, value_scale_cache = kv_cache_scale
            assert key_scale_cache.shape == (num_blocks, num_kv_heads, block_size) and key_scale_cache.dtype == torch.float32, f"key_scale_cache.shape {key_scale_cache.shape} != (num_blocks, num_kv_heads, block_size) or key_scale_cache.dtype {key_scale_cache.dtype} != torch.float32"
            assert value_scale_cache.shape == (num_kv_heads, head_size) and value_scale_cache.dtype == torch.float32, f"value_scale_cache.shape {value_scale_cache.shape} != (num_kv_heads, head_size) or value_scale_cache.dtype {value_scale_cache.dtype} != torch.float32"
            value_cache_info = (i8_value_cache, value_scale_cache)

        elif self.quant_type == 2:
            # key cache is int8; value cache is fp16, reinterpreted from the
            # remaining int8 slabs of the allocation
            i8_key_cache = kv_cache[0]
            num_blocks, num_kv_heads, block_size, head_size = i8_key_cache.shape
            value_cache = kv_cache[1:].view(query.dtype).reshape(num_blocks, num_kv_heads, block_size, head_size)
            key_scale_cache = kv_cache_scale
            value_cache_info = (value_cache, None)

        decode_q = query[:num_decode_tokens]
        prefill_q = query[num_decode_tokens:]
        prefill_output = output[num_decode_tokens:]
        decode_output = output[:num_decode_tokens]

        if self.quant_type == 1:
            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
                i8_value, _value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False, scale=value_cache_info[1]
                )
        elif self.quant_type == 2:
            '''
            original key cache
                num_blocks, num_kv_heads, block_size, head_size  f16
            reformatted key cache
                key_cache_i8   : num_blocks, num_kv_heads, block_size, head_size  int8
                key_scale_cache: num_blocks, num_kv_heads, block_size  fp32
            '''

            if decode_only:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=False
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
            else:
                int8_query, query_scale = ixf_ops.scaled_int8_quant_for_attn(
                    query, 2, transpose_scale=True
                )
                i8_key, key_scale = ixf_ops.scaled_int8_quant_for_attn(
                    key, 2, transpose_scale=False
                )
        else:
            if layer.quant_manager is not None and layer.quant_manager.check_enable():
                i8_value, value_scale = ixf_ops.scaled_int8_quant_for_attn(
                    value, 0, transpose_scale=False
                )
                layer.quant_manager.update_data(value_scale)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            if self.quant_type == 1:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_int8(
                        key=i8_key,
                        value=i8_value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )
            elif self.quant_type == 2:
                if has_prefill:
                    ixf_ops.reshape_and_cache_flash_mix(
                        key=i8_key,
                        value=value,
                        k_scale=key_scale,
                        key_cache=i8_key_cache,
                        value_cache=value_cache_info[0],
                        key_scale_cache=key_scale_cache,
                        slot_mapping=attn_metadata.slot_mapping,
                        kv_cache_dtype="",
                    )

            else:
                ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

            if self.kv_cache_dtype.startswith("fp8"):
                # queries are quantized in the attention layer
@@ -783,19 +949,6 @@ class FlashAttentionImpl(AttentionImpl):
                value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            q_descale = layer._q_scale.expand(descale_shape)
            k_descale = layer._k_scale.expand(descale_shape)
            v_descale = layer._v_scale.expand(descale_shape)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
@@ -805,79 +958,140 @@ class FlashAttentionImpl(AttentionImpl):
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=q_descale,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
                return output.view(-1, self.num_heads * self.head_size)
            else:
                sliding_window_size = (
                    list(self.sliding_window)
                    if self.sliding_window is not None
                    else None
                )
                if has_prefill:
                    flash_attn_varlen_func(
                        q=prefill_q,
                        k=key_cache,
                        v=value_cache,
                        cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                        cu_seqlens_k=attn_metadata.prefill.query_start_loc,
                        max_seqlen_q=attn_metadata.prefill.max_query_len,
                        max_seqlen_k=attn_metadata.max_query_len,
                        softmax_scale=softmax_scale,
                        causal=True,
                        window_size=sliding_window_size,
                        alibi_slopes=alibi_slopes,
                        softcap=logits_soft_cap,
                        sqrt_alibi=sqrt_alibi,
                        sinks=self.sinks,
                        out=prefill_output,
                        block_table=attn_metadata.prefill.block_table,
                    )
                    # key = key[num_decode_tokens:]
                    # value = value[num_decode_tokens:]

                    # int8 attn
                    if self.quant_type > 0:
                        flash_attn_varlen_int8_func(
                            q=int8_query[num_decode_tokens:],
                            k=i8_key_cache,
                            v=value_cache_info[0],
                            q_scale=query_scale[:, num_decode_tokens:],
                            k_scale=key_scale_cache,
                            v_scale=value_cache_info[1],
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                            output_dtype=query.dtype
                        )
                    else:
                        flash_attn_varlen_func(
                            q=prefill_q,
                            k=key_cache,
                            v=value_cache,
                            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
                            cu_seqlens_k=attn_metadata.prefill.key_start_loc,
                            max_seqlen_q=attn_metadata.prefill.max_query_len,
                            max_seqlen_k=attn_metadata.max_query_len,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=prefill_output,
                            block_table=attn_metadata.prefill.block_table,
                        )
                if has_decode:
                    flash_attn_with_kvcache(
                        q=decode_q.unsqueeze(1),
                        k_cache=key_cache.contiguous(),
                        v_cache=value_cache.contiguous(),
                        block_table=attn_metadata.decode.block_table,
                        cache_seqlens=attn_metadata.decode.seq_lens,
                        softmax_scale=softmax_scale,
                        causal=True,
                        window_size=sliding_window_size,
                        alibi_slopes=alibi_slopes,
                        softcap=logits_soft_cap,
                        use_sqrt_alibi=sqrt_alibi,
                        out=decode_output.unsqueeze(1),
                        max_context_len=attn_metadata.decode.max_decode_seq_len,
                        # sinks=self.sinks,
                    )
                    # for mtp + cuda graph
                    max_q_len = attn_metadata.decode.max_query_len if attn_metadata.decode is not None else attn_metadata.max_query_len
                    max_ct_len = attn_metadata.decode.max_decode_seq_len if attn_metadata.decode is not None else attn_metadata.max_seq_len
                    if self.quant_type in [1, 2]:
                        para_dict = dict(
                            output=decode_output,
                            query=int8_query[:num_decode_tokens],
                            key_cache=i8_key_cache,
                            query_scale=query_scale[:num_decode_tokens] if decode_only else query_scale[:, :num_decode_tokens].t().contiguous(),
                            key_scale_cache=key_scale_cache,
                            num_kv_heads=self.num_kv_heads,
                            scale=softmax_scale,
                            block_tables=attn_metadata.decode.block_table,
                            context_lens=attn_metadata.decode.seq_lens,
                            block_size=i8_key_cache.shape[-2],
                            softcap=logits_soft_cap,
                            alibi_slopes=alibi_slopes,
                            causal=True,
                            window_left=window_size[0],
                            window_right=window_size[1],
                            use_sqrt_alibi=sqrt_alibi,
                            use_cuda_graph=attn_metadata.decode.use_graph if decode_only else False,
                            max_context_len=max_ct_len,
                            # mtp
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            max_query_len=max_q_len,
                        )

                        if self.quant_type == 1:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache_info[0],
                                    value_scale_cache=value_cache_info[1],
                                )
                            )
                            # for kv + k_scale write fusion
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=i8_value[:num_decode_tokens],
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_int8(**para_dict)
                        elif self.quant_type == 2:
                            para_dict.update(
                                dict(
                                    value_cache=value_cache,
                                )
                            )
                            if decode_only:
                                para_dict.update(
                                    dict(
                                        save_key=i8_key[:num_decode_tokens],
                                        save_value=value[:num_decode_tokens].contiguous(),
                                        save_key_scale=key_scale[:num_decode_tokens],
                                    )
                                )
                            ixf_ops.vllm_paged_attention_mix(**para_dict)
                    else:
                        flash_attn_with_kvcache(
                            q=decode_q.unsqueeze(1),
                            k_cache=key_cache,
                            v_cache=value_cache,
                            block_table=attn_metadata.decode.block_table,
                            cache_seqlens=attn_metadata.decode.seq_lens,
                            max_query_len=max_q_len,
                            cu_query_lens=attn_metadata.decode.query_start_loc,
                            softmax_scale=softmax_scale,
                            causal=True,
                            window_size=window_size,
                            alibi_slopes=alibi_slopes,
                            softcap=logits_soft_cap,
                            use_sqrt_alibi=sqrt_alibi,
                            sinks=self.sinks,
                            out=decode_output.unsqueeze(1),
                            use_cuda_graph=attn_metadata.decode.use_graph,
                            max_context_len=max_ct_len
                        )
                # Compute attention and update output up to `num_actual_tokens`.
                return output.view(-1, self.num_heads * self.head_size)

            # flash_attn_varlen_func(
            #     q=query[:num_actual_tokens],
            #     k=key_cache,
            #     v=value_cache,
            #     out=output[:num_actual_tokens],
            #     cu_seqlens_q=cu_seqlens_q,
            #     max_seqlen_q=max_seqlen_q,
            #     seqused_k=seqused_k,
            #     max_seqlen_k=max_seqlen_k,
            #     softmax_scale=self.scale,
            #     causal=attn_metadata.causal,
            #     alibi_slopes=self.alibi_slopes,
            #     window_size=sliding_window_size,
            #     block_table=block_table,
            #     softcap=self.logits_soft_cap,
            #     scheduler_metadata=scheduler_metadata,
            #     fa_version=self.vllm_flash_attn_version,
            #     q_descale=q_descale,
            #     k_descale=k_descale,
            #     v_descale=v_descale,
            #     num_splits=attn_metadata.max_num_splits,
            #     s_aux=self.sinks,
            # )
            # return output

        # Cascade attention (rare case).
        cascade_attention(
@@ -906,12 +1120,7 @@ class FlashAttentionImpl(AttentionImpl):
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        # return output
        return (
            output[:num_actual_tokens]
            .contiguous()
            .view(-1, self.num_heads * self.head_size)
        )
        return output.view(-1, self.num_heads * self.head_size)

    def do_kv_cache_update(
        self,
@@ -935,7 +1144,7 @@ class FlashAttentionImpl(AttentionImpl):
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        reshape_and_cache_flash(
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
@@ -959,9 +1168,9 @@ class FlashAttentionImpl(AttentionImpl):
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
@@ -969,27 +1178,22 @@ class FlashAttentionImpl(AttentionImpl):

        query = query.contiguous()
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
        cu_dcp_kv_klens = attn_metadata.dcp_context_kv_lens.cumsum(dim=0, dtype=torch.int32)
        new_tensor = torch.tensor([0],
                                  device=attn_metadata.dcp_context_kv_lens.device,
                                  dtype=attn_metadata.dcp_context_kv_lens.dtype)
        cu_seqlens_k = torch.cat([new_tensor, cu_dcp_kv_klens])
        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        cu_seqlens_k = torch.cat(
            [
                torch.zeros(1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype),
                attn_metadata.dcp_context_kv_lens.cumsum(
                    dim=0, dtype=cu_seqlens_q.dtype
                ),
            ],
            dim=0,
        )
        context_attn_out, context_lse = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            # seqused_k=attn_metadata.dcp_context_kv_lens,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
@@ -998,11 +1202,6 @@ class FlashAttentionImpl(AttentionImpl):
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            # scheduler_metadata=attn_metadata.scheduler_metadata,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
        )
        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
@@ -1028,10 +1227,6 @@ class FlashAttentionImpl(AttentionImpl):
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
        )
        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
@@ -1040,7 +1235,7 @@ class FlashAttentionImpl(AttentionImpl):
            context_lse_cor,
            query_attn_out,
            query_lse,
            output,
            output
        )

    def _forward_encoder_attention(
@@ -1062,9 +1257,9 @@ class FlashAttentionImpl(AttentionImpl):
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )
        # assert self.vllm_flash_attn_version is not None, (
        #     "FlashAttention version not detected."
        # )

        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
@@ -1101,18 +1296,9 @@ class FlashAttentionImpl(AttentionImpl):
            alibi_slopes=self.alibi_slopes,
            window_size=sliding_window_size,
            softcap=self.logits_soft_cap,
            # fa_version=self.vllm_flash_attn_version,
            # q_descale=layer._q_scale.expand(descale_shape),
            # k_descale=layer._k_scale.expand(descale_shape),
            # v_descale=layer._v_scale.expand(descale_shape),
            # num_splits=1 if self.batch_invariant_enabled else 0,
        )

        return (
            output[: attn_metadata.num_actual_tokens]
            .contiguous()
            .view(-1, self.num_heads * self.head_size)
        )
        return output


def use_cascade_attention(
@@ -1203,8 +1389,6 @@ def cascade_attention(
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    # prefix_kv_lens: torch.Tensor,
    # suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: torch.Tensor | None,
@@ -1228,12 +1412,13 @@ def cascade_attention(
    )

    num_tokens = query.shape[0]
    # block_size = key_cache.shape[-3]
    block_size = key_cache.shape[-2]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
    assert q_descale is None or q_descale==1, f"q_descale is not None, q_descale: {q_descale}"
    assert k_descale is None or k_descale==1, f"k_descale is not None, k_descale: {k_descale}"
    assert v_descale is None or v_descale==1, f"v_descale is not None, v_descale: {v_descale}"

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -1241,7 +1426,6 @@ def cascade_attention(
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        # seqused_k=prefix_kv_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
@@ -1251,26 +1435,14 @@ def cascade_attention(
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        # scheduler_metadata=prefix_scheduler_metadata,
        # fa_version=fa_version,
        # q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        # k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        # v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        # s_aux is incorporated into prefix_lse inside the GPU kernel,
        # enabling its effect during the final attention merge.
        # s_aux=s_aux,
        # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        # seqused_k=suffix_kv_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
@@ -1280,14 +1452,6 @@ def cascade_attention(
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        # scheduler_metadata=suffix_scheduler_metadata,
        # fa_version=fa_version,
        # q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        # k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        # v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    # merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
    merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse, output)
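    # Editor's note (assumption): the commented call above is upstream vLLM's
    # signature, merge_attn_states(output, prefix_output, prefix_lse,
    # suffix_output, suffix_lse), while the ixformer replacement imported
    # earlier in this diff appears to take the output buffer last. A thin
    # adapter would make the argument-order swap explicit:
    #
    #     def merge_attn_states_compat(out, p_out, p_lse, s_out, s_lse):
    #         # ixformer.contrib.vllm_flash_attn.merge_attn_states (assumed order)
    #         merge_attn_states(p_out, p_lse, s_out, s_lse, out)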

@@ -13,7 +13,7 @@ from flashinfer import (
    BatchPrefillWithRaggedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
    ):
        """Plan the prefill operation with given parameters."""
        self._context.plan(
            qo_indptr_cpu,
            paged_kv_indptr_cpu,
            paged_kv_indices,
            paged_kv_last_page_len_cpu,
            num_qo_heads * dcp_world_size,
            num_kv_heads,
            head_dim,
            page_size,
            qo_indptr=qo_indptr_cpu,
            paged_kv_indptr=paged_kv_indptr_cpu,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len_cpu,
            num_qo_heads=num_qo_heads * dcp_world_size,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            page_size=page_size,
            causal=False,  # This is context run
            sm_scale=sm_scale,
            window_left=window_left,
@@ -374,13 +374,13 @@ class FlashInferBackend(AttentionBackend):

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        from vllm.platforms import current_platform

        capability = current_platform.get_device_capability()
        if capability is not None and capability.major == 10:
            return "HND"
        return None
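        # Editor's sketch (illustration only): "HND" vs "NHD" describe the
        # per-block layout of the paged KV cache. For block_size N,
        # num_kv_heads H, head_size D:
        #
        #     NHD block: [N, H, D]   # tokens outermost
        #     HND block: [H, N, D]   # heads outermost
        #
        # Returning "HND" here forces heads-first blocks on capability-10
        # devices, which the FlashInfer TRTLLM kernels expect.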
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FIPrefill:
|
||||
@@ -573,20 +573,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
|
||||
# if TRTLLM attention kernel is not used when building attn metadata
|
||||
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
|
||||
|
||||
# TRTLLM attention requires strictly contiguous KV cache tensors.
|
||||
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
|
||||
# permuted into non-contiguous views, which causes assertion failures.
|
||||
self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
|
||||
if can_use_trtllm and self._kv_transfer_enabled:
|
||||
logger.info_once(
|
||||
"TRTLLM attention is disabled because KV transfer "
|
||||
"(P/D disaggregation) is enabled. TRTLLM attention requires "
|
||||
"strictly contiguous KV cache tensors which may not be "
|
||||
"guaranteed with KV transfer."
|
||||
)
|
||||
can_use_trtllm = False
|
||||
|
||||
if (
|
||||
can_use_trtllm
|
||||
and not vllm_config.attention_config.disable_flashinfer_q_quantization
|
||||
@@ -816,6 +802,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
page_size,
|
||||
paged_kv_last_page_len_np,
|
||||
)
|
||||
self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
|
||||
self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
|
||||
)
|
||||
return paged_kv_indices

    def build(
@@ -860,9 +849,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        # KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
        if self._kv_transfer_enabled:
            prefill_use_trtllm = False
            decode_use_trtllm = (
                self.use_trtllm_decode_attention and self.dcp_world_size <= 1
            )
@@ -997,14 +983,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
                paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
                paged_kv_last_page_len=[
                    shared_kv_last_page_len_cpu,
                    paged_kv_last_page_len_cpu,
                ],
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
@@ -1082,14 +1071,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                BatchPrefillWithPagedKVCacheWrapper,
            )
            prefill_wrapper.plan(
                qo_indptr_prefill_cpu,
                paged_kv_indptr_prefill_cpu,
                paged_kv_indices,
                paged_kv_last_page_len_prefill_cpu,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                qo_indptr=qo_indptr_prefill_cpu,
                paged_kv_indptr=paged_kv_indptr_prefill_cpu,
                paged_kv_indices=paged_kv_indices,
                paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim_qk=self.head_dim,
                page_size=self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
@@ -1130,14 +1119,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            # in attn_metadata when using cudagraph.
            fast_plan_decode(
                decode_wrapper,
                self.paged_kv_indptr.cpu[: num_input_tokens + 1],
                paged_kv_indices,
                self.paged_kv_last_page_len.cpu[:num_input_tokens],
                seq_lens_cpu[:num_input_tokens],
                self.num_qo_heads * self.dcp_world_size,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
                indices=paged_kv_indices,
                last_page_len_cpu=self.paged_kv_last_page_len.cpu[
                    :num_input_tokens
                ],
                num_qo_heads=self.num_qo_heads * self.dcp_world_size,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                page_size=self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                sm_scale=self.sm_scale,
@@ -1330,32 +1320,15 @@ class FlashInferImpl(AttentionImpl):

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
            )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
            )
            kv_cache = kv_cache.view(torch_dtype)
            kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
@@ -1599,13 +1572,39 @@ class FlashInferImpl(AttentionImpl):
        )
        return output_padded

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )


def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
@@ -1642,110 +1641,56 @@ def fast_plan_decode(
    # This warm-up generates the _cached_module for the decode wrapper.
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            o_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
            None,  # block_tables
            None,  # seq_lens
            fixed_split_size,
            disable_split_kv,
            indptr=indptr_cpu,
            indices=indices,
            last_page_len=last_page_len_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            page_size=page_size,
            pos_encoding_mode=pos_encoding_mode,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_data_type,
            o_data_type=o_data_type,
            data_type=data_type,
            sm_scale=sm_scale,
            rope_scale=rope_scale,
            rope_theta=rope_theta,
            non_blocking=non_blocking,
            block_tables=None,
            seq_lens=None,
            fixed_split_size=fixed_split_size,
            disable_split_kv=disable_split_kv,
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
    q_data_type = (
        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
    fast_decode_plan(
        self,
        indptr=indptr_cpu,
        indices=indices,
        last_page_len=last_page_len_cpu,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
        pos_encoding_mode=pos_encoding_mode,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=q_data_type,
        kv_data_type=kv_data_type,
        data_type=data_type,
        sm_scale=sm_scale,
        rope_scale=rope_scale,
        rope_theta=rope_theta,
        non_blocking=non_blocking,
        fixed_split_size=fixed_split_size,
        disable_split_kv=disable_split_kv,
    )
    kv_data_type = (
        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
    )
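The dtype plumbing above (q_data_type / kv_data_type / the legacy umbrella data_type, given either as strings or torch dtypes) resolves in a fixed fallback order. A standalone sketch of that order, for clarity; this is an illustration, not the vLLM helper itself:

import torch

def resolve_plan_dtypes(data_type=None, q_data_type=None, kv_data_type=None):
    # The legacy data_type argument fills in whatever is not given explicitly.
    if data_type is not None:
        q_data_type = q_data_type or data_type
        kv_data_type = kv_data_type or data_type
    elif q_data_type is None:
        q_data_type = "float16"
    # The KV dtype falls back to the query dtype when unspecified.
    if kv_data_type is None:
        kv_data_type = q_data_type
    # Accept either strings ("float16") or torch.dtype objects.
    to_dtype = lambda d: getattr(torch, d) if isinstance(d, str) else d
    return to_dtype(q_data_type), to_dtype(kv_data_type)

assert resolve_plan_dtypes() == (torch.float16, torch.float16)
assert resolve_plan_dtypes(data_type="bfloat16") == (torch.bfloat16, torch.bfloat16)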

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
            "initialization {}".format(batch_size, self._fixed_batch_size)
        )
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
            "The size of indices should be less than or equal to the allocated buffer"
        )

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)

    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    try:
        # Make sure we pass exactly 19 arguments for tensor core version
        args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # total_num_rows
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
            window_left,
        ]
        if self._backend == "fa2":
            args.append(fixed_split_size)
            args.append(disable_split_kv)
            args.append(0)  # num_colocated_ctas
        self._plan_info = self._cached_module.plan(
            *args,
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta


@triton.jit

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any

from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadata,
    BaseMambaAttentionMetadataBuilder,
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
    metadata_cls = Mamba1AttentionMetadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
        **kwargs: Any,
    ) -> Mamba1AttentionMetadata:
        common = self._compute_common_metadata(common_attn_metadata)

        if (
            common.num_prefills > 0
            and self.vllm_config.cache_config.mamba_cache_mode == "all"
        ):
            cu_chunk_seqlen_p, _, last_chunk_indices_p = (
                self._build_chunk_metadata_tensors(
                    self.kv_cache_spec.block_size,
                    common,
                    common_attn_metadata,
                )
            )
            return replace(
                common,
                cu_chunk_seqlen_p=cu_chunk_seqlen_p,
                last_chunk_indices_p=last_chunk_indices_p,
            )

        return common

@@ -7,7 +7,6 @@ from typing import Any
import torch

from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
    AttentionBackend,
    CommonAttentionMetadata,
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):

    # Chunk-related metadata (only for prefill)
    seq_idx_p: torch.Tensor | None = None
    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
    # each chunk, its offsets into the varlen sequence dimension. It is defined
    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
    # cu_chunk_seqlen_p[i+1].
    cu_chunk_seqlen_p: torch.Tensor | None = None
    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
    # index of the last chunk for every sequence in the (prefill) batch.
    last_chunk_indices_p: torch.Tensor | None = None


class Mamba2AttentionMetadataBuilder(
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
        )
        self.chunk_size: int = chunk_size

    def _compute_chunk_metadata(
        self,
        num_prefills: int,
        num_computed_tokens_p_cpu: torch.Tensor,
        query_start_loc_p_cpu: torch.Tensor,
    ) -> tuple[list[int], list[int], list[int]]:
        """
        Compute chunk-specific metadata for Mamba2.

        The code below carefully constructs the chunks such that:
        1. Chunks contain tokens from a *single* sequence only.
        2. For every sequence, we are guaranteed that we can
           retrieve the mamba state *every* chunk_size tokens.
        Constraint (1) dramatically simplifies the mamba2 kernels.
        Constraint (2) dramatically simplifies the implementation
        of prefix caching for mamba2 (wip). We need to take care
        of the interaction with chunked prefill in order to
        satisfy constraint (2).
        """
        # TODO (tdoublep): This code could probably be optimized.
        cu_chunk_seqlen = []
        seq_idx = []
        last_chunk_indices = []
        seqlen_pos = 0

        for req_idx in range(num_prefills):
            this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
            this_new_tokens = (
                query_start_loc_p_cpu[req_idx + 1].item()
                - query_start_loc_p_cpu[req_idx].item()
            )

            # if computed tokens are not chunk-aligned, use the first
            # chunk to finish it off
            if this_num_computed % self.chunk_size != 0:
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                # how many tokens to finish the chunk?
                chunk_len = (
                    cdiv(this_num_computed, self.chunk_size) * self.chunk_size
                    - this_num_computed
                )
                # we can only use at most this_new_tokens
                chunk_len = min(chunk_len, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            n_chunks = cdiv(this_new_tokens, self.chunk_size)
            for chunk in range(n_chunks):
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                chunk_len = min(self.chunk_size, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            assert this_new_tokens == 0
            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

        cu_chunk_seqlen.append(seqlen_pos)

        return cu_chunk_seqlen, seq_idx, last_chunk_indices

    def build(
        self,
        common_prefix_len: int,
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
            else False
        )

        num_reqs = common.num_reqs
        num_prefills = common.num_prefills
        num_decode_tokens = common.num_decode_tokens

        num_computed_tokens_cpu = (
            common_attn_metadata.compute_num_computed_tokens().cpu()
        )
        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
            num_reqs - num_prefills : num_reqs
        ]
        query_start_loc_p_cpu = (
            common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
            - num_decode_tokens
        )

        cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
            num_prefills,
            num_computed_tokens_p_cpu,
            query_start_loc_p_cpu,
        )

        seq_idx_p = torch.as_tensor(
            seq_idx,
            device=common_attn_metadata.query_start_loc.device,
            dtype=torch.int32,
        )
        cu_chunk_seqlen_p = torch.as_tensor(
            cu_chunk_seqlen,
            device=common_attn_metadata.query_start_loc.device,
            dtype=torch.int32,
        )
        last_chunk_indices_p = torch.as_tensor(
            last_chunk_indices,
            device=common_attn_metadata.query_start_loc.device,
            dtype=torch.int32,
        cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
            self._build_chunk_metadata_tensors(
                self.chunk_size,
                common,
                common_attn_metadata,
            )
        )

        return replace(

@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
    # The following tensor is only used for prefix caching in align mode
    seq_lens: torch.Tensor

    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
    # each chunk, its offsets into the varlen sequence dimension. It is defined
    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
    # cu_chunk_seqlen_p[i+1].
    cu_chunk_seqlen_p: torch.Tensor | None = None
    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
    # index of the last chunk for every sequence in the (prefill) batch.
    last_chunk_indices_p: torch.Tensor | None = None

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: dict | None = None
    batch_ptr: torch.Tensor | None = None
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            common_attn_metadata, num_accepted_tokens=num_accepted_tokens
        )

    def _compute_chunk_metadata(
        self,
        chunk_size: int,
        num_prefills: int,
        num_computed_tokens_p_cpu: torch.Tensor,
        query_start_loc_p_cpu: torch.Tensor,
    ) -> tuple[list[int], list[int], list[int]]:
        """
        Compute chunk-specific metadata for Mamba models.

        The code below carefully constructs the chunks such that:
        1. Chunks contain tokens from a *single* sequence only.
        2. For every sequence, we are guaranteed that we can
           retrieve the mamba state *every* chunk_size tokens.
        Constraint (1) dramatically simplifies the mamba kernels.
        Constraint (2) dramatically simplifies the implementation
        of prefix caching for mamba (wip). We need to take care
        of the interaction with chunked prefill in order to
        satisfy constraint (2).
        """
        # TODO (tdoublep): This code could probably be optimized.
        cu_chunk_seqlen = []
        seq_idx = []
        last_chunk_indices = []
        seqlen_pos = 0

        for req_idx in range(num_prefills):
            this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
            this_new_tokens = (
                query_start_loc_p_cpu[req_idx + 1].item()
                - query_start_loc_p_cpu[req_idx].item()
            )

            # if computed tokens are not chunk-aligned, use the first
            # chunk to finish it off
            if this_num_computed % chunk_size != 0:
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                # how many tokens to finish the chunk?
                chunk_len = (
                    cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
                )
                # we can only use at most this_new_tokens
                chunk_len = min(chunk_len, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            n_chunks = cdiv(this_new_tokens, chunk_size)
            for chunk in range(n_chunks):
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                chunk_len = min(chunk_size, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            assert this_new_tokens == 0
            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

        cu_chunk_seqlen.append(seqlen_pos)

        return cu_chunk_seqlen, seq_idx, last_chunk_indices
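As a worked example of the chunking logic above (a self-contained sketch, not the builder method itself): with chunk_size=8, a request that already has 5 computed tokens and 20 new tokens first gets a 3-token chunk to reach the next chunk boundary, then chunks of 8, 8, and 1.

from math import ceil

def chunk_one_request(num_computed: int, new_tokens: int, chunk_size: int):
    # Mirrors _compute_chunk_metadata for a single request.
    chunks = []
    if num_computed % chunk_size != 0:
        # The first chunk aligns the sequence to a chunk_size boundary.
        fill = min(ceil(num_computed / chunk_size) * chunk_size - num_computed,
                   new_tokens)
        chunks.append(fill)
        new_tokens -= fill
    while new_tokens > 0:
        chunks.append(min(chunk_size, new_tokens))
        new_tokens -= chunks[-1]
    return chunks

assert chunk_one_request(5, 20, 8) == [3, 8, 8, 1]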

    def _build_chunk_metadata_tensors(
        self,
        chunk_size: int,
        common: M,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute chunk metadata and return as device tensors.
        Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
        """
        num_reqs = common.num_reqs
        num_prefills = common.num_prefills
        num_decode_tokens = common.num_decode_tokens

        num_computed_tokens_cpu = (
            common_attn_metadata.compute_num_computed_tokens().cpu()
        )
        num_computed_tokens_p_cpu = num_computed_tokens_cpu[
            num_reqs - num_prefills : num_reqs
        ]
        query_start_loc_p_cpu = (
            common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
            - num_decode_tokens
        )

        cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
            chunk_size,
            num_prefills,
            num_computed_tokens_p_cpu,
            query_start_loc_p_cpu,
        )

        device = common_attn_metadata.query_start_loc.device
        cu_chunk_seqlen_p = torch.as_tensor(
            cu_chunk_seqlen,
            device=device,
            dtype=torch.int32,
        )
        seq_idx_p = torch.as_tensor(
            seq_idx,
            device=device,
            dtype=torch.int32,
        )
        last_chunk_indices_p = torch.as_tensor(
            last_chunk_indices,
            device=device,
            dtype=torch.int32,
        )
        return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p

    def _compute_prefix_caching_block_indices(
        self,
        common_attn_metadata: CommonAttentionMetadata,

@@ -191,6 +191,8 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
        max_decode_seq_len: int = 0,
        use_cuda_graph: bool = False,
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
@@ -239,12 +241,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
        metadata = FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
            max_decode_seq_len=max_decode_seq_len,
            use_cuda_graph=use_cuda_graph,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
        return metadata


@@ -156,6 +156,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
        max_decode_seq_len: int = 0,
        use_cuda_graph: bool = False,
    ) -> FlashMLADecodeMetadata:
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # we use the max but all should be the same due to uniform length requirement
@@ -179,8 +181,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            scheduler_metadata=scheduler_metadata,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
            max_decode_seq_len=max_decode_seq_len,
            use_cuda_graph=use_cuda_graph,
            scheduler_metadata=scheduler_metadata,
        )


@@ -13,6 +13,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
    get_mla_dims,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
@@ -37,13 +42,17 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.attention.ops.flashmla import (
    FlashMLASchedMeta,
    flash_mla_sparse_fwd,
    flash_mla_sparse_prefill,
    flash_mla_with_kvcache,
    get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

import functools
from vllm import envs
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
import ixformer.inference.functions as ixf_ops
import numpy as np

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

@@ -74,7 +83,15 @@ structured as:
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
  part is not quantized for accuracy.
"""

def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
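For intuition: the helper above picks a single per-tensor scale so that the largest magnitude maps to the fp8 maximum, and returns the reciprocal scale for dequantization. A hedged round-trip sketch (requires a PyTorch build with float8 support; the tolerance is illustrative):

import torch

x = torch.randn(16, 64, dtype=torch.float32)
x_fp8, inv_scale = dynamic_per_batched_tensor_quant(x)
# Dequantize by multiplying with the returned reciprocal scale.
x_rec = x_fp8.to(torch.float32) * inv_scale
# fp8 e4m3 has a 3-bit mantissa, so errors scale with the tensor's amax.
assert torch.allclose(x, x_rec, atol=0.1 * x.abs().max().item())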

class FlashMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
@@ -558,6 +575,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.qk_nope_head_dim = mla_args["qk_nope_head_dim"]
        self.qk_rope_head_dim = mla_args["qk_rope_head_dim"]
        self.qk_head_dim = mla_args["qk_head_dim"]
        self.v_head_dim = mla_args["v_head_dim"]
        self.kv_b_proj = mla_args["kv_b_proj"]
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
@@ -580,6 +602,65 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
                (self.prefill_workspace_shape, torch.bfloat16)
            )
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
            )

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if layer.quant_method is not None and not isinstance(
                layer.quant_method, UnquantizedLinearMethod
            ):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(
                    layer.input_size_per_partition,
                    dtype=act_dtype,
                    device=get_layer_weight(layer).device,
                )
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently do not have quantized bmm's, which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit; the extra memory overhead of this is fairly low.
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        self.W_UV = W_UV
        self.W_UK = W_UK
        # self.W_UK_T = W_UK.permute(1, 2, 0)

    def _v_up_proj(self, x: torch.Tensor):
        return torch.einsum("bnl,lnv->bnv", x, self.W_UV)

    def _k_up_proj(self, q_nope):
        return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(
            -1, self.num_heads, self.kv_lora_rank
        )
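Shape-wise, the two einsums above implement the MLA up-projections: q_nope [B, n, p] is absorbed into the latent space via W_UK [l, n, p], giving [B, n, l], and attention output in latent space [B, n, l] is expanded back with W_UV [l, n, v] to [B, n, v]. A tiny sketch with dummy sizes (the dimension values are placeholders, not model constants):

import torch

B, n, p, l, v = 2, 8, 128, 512, 128  # batch, heads, nope dim, latent rank, v dim
q_nope = torch.randn(B, n, p)
W_UK = torch.randn(l, n, p)
W_UV = torch.randn(l, n, v)

q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)   # queries in latent space
assert q_latent.shape == (B, n, l)
out = torch.einsum("bnl,lnv->bnv", q_latent, W_UV)      # back to value space
assert out.shape == (B, n, v)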

    def _forward_bf16_kv(
        self,
@@ -590,12 +671,11 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
    ) -> torch.Tensor:
        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        topk_indices = triton_convert_req_index_to_global_index(
        topk_indices = ops.dsa_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            attn_metadata.block_size,
        )

        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
@@ -790,22 +870,10 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on Hopper and 128 on Blackwell
        if self.num_heads % self.prefill_padding != 0:
            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
                f"Padding num_heads from {self.num_heads} to "
                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
            q_padded = q.new_empty((q.shape[0], self.prefill_padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded
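The padding above widens q's head axis to the kernel's required width, and the result is sliced back to num_heads afterwards (see the output slicing below). The pattern in isolation, as a sketch: required_multiple stands in for the backend's prefill_padding, and zeros are used so the padded heads hold defined values.

import torch

def pad_heads(q: torch.Tensor, required_multiple: int) -> torch.Tensor:
    # q: [tokens, num_heads, head_dim]; pad the head axis up to a multiple.
    num_heads = q.shape[1]
    padded = -(-num_heads // required_multiple) * required_multiple
    if padded == num_heads:
        return q
    q_padded = q.new_zeros((q.shape[0], padded, q.shape[2]))
    q_padded[:, :num_heads, :] = q
    return q_padded

q = torch.randn(4, 16, 576)
assert pad_heads(q, 64).shape == (4, 64, 576)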

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_fwd(
        output = flash_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )[0]
        )
        output = output[:, : self.num_heads, :]
        return output

@@ -843,5 +911,5 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
            attn_out = self._forward_fp8_kv_separate_prefill_decode(
                q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
            )

        return attn_out, None

        return attn_out

@@ -8,7 +8,11 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.deep_gemm import (
    get_paged_mqa_logits_metadata,
    is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
    AttentionBackend,
@@ -21,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
    split_decodes_and_prefills,
    split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)

@@ -68,11 +73,15 @@ class DeepseekV32IndexerPrefillChunkMetadata:
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seq_lens: torch.Tensor
    cu_seqlens_q: torch.Tensor
    token_to_seq: torch.Tensor
    total_seq_lens: int
    token_start: int
    token_end: int
    num_reqs: int
    max_context_len: int
    max_q_len: int  # Maximum query length for dsa_indexer_mqa_logits_with_blocks
    max_kv_len: int  # Maximum key-value length for dsa_indexer_mqa_logits_with_blocks


@dataclass
@@ -86,9 +95,16 @@ class DeepSeekV32IndexerDecodeMetadata:
    seq_lens: torch.Tensor
    decode_lens: torch.Tensor
    requires_padding: bool
    schedule_metadata: torch.Tensor
    # schedule_metadata: torch.Tensor
    use_large_context_topk: bool
    offsets: torch.Tensor | None  # Precomputed offsets for speculative decoding
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seqlens_kv: torch.Tensor
    cu_seqlens_q: torch.Tensor
    max_context_len: int
    max_q_len: int
    max_kv_len: int


@dataclass
@@ -211,20 +227,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            if self.vllm_config.speculative_config
            else 0
        )
        if self.num_speculative_tokens > 1:
            raise ValueError(
                "Sparse MLA only supports "
                "num_speculative_tokens <= 1 because the DeepGEMM "
                "fp8_paged_mqa_logits kernel does not support next_n > 2. "
                f"Got num_speculative_tokens={self.num_speculative_tokens}."
            )
        self.reorder_batch_threshold += self.num_speculative_tokens

        sm_count = num_compute_units(self.device.index)
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )

        # Pre-allocated buffers for flattening (spec decode).
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
            dtype=torch.int32,
            device=self.device,
        )
        self.expanded_seq_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        max_num_blocks_per_req = cdiv(
            self.vllm_config.model_config.max_model_len,
            self.kv_cache_spec.block_size * get_total_cp_world_size(),
        )
        self.expanded_block_table_buffer = torch.zeros(
            (
                scheduler_config.max_num_batched_tokens,
                max_num_blocks_per_req,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        # See: DeepGEMM/csrc/apis/attention.hpp
@@ -260,18 +295,88 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            .to(torch.int32)
            .to(self.device)
        )
        cu_seqlens_q = prefill_query_start_loc.to(torch.int32).to(self.device)
        max_context_len = seq_lens_cpu[reqs_start:reqs_end].max().item()
        # max_q_len is the maximum query length among all batches in this chunk
        # prefill_query_start_loc is cumsum of lengths with shape [batch+1]
        max_q_len = (
            (prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]).max().item()
        )
        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            total_seq_lens=total_seq_lens,
            cu_seqlens_q=cu_seqlens_q,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
            token_end=token_end,
            num_reqs=reqs_end - reqs_start,
            max_context_len=max_context_len,
            max_q_len=max_q_len,
            max_kv_len=max_context_len,
        )

    def build_decode_metadata(
        self,
        common_attn_metadata,
        num_decodes,
        decode_lens,
        use_large_context_topk,
        offsets,
    ):
        decode_lens_cpu = torch.diff(
            common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
        )
        assert (
            decode_lens_cpu.max().item()
            == decode_lens_cpu.min().item()
            == 1
        ), "Only single-token decode is supported in the dsa_indexer backend"

        # Calculate decode metadata parameters
        seq_lens_decode = common_attn_metadata.seq_lens_cpu[:num_decodes]
        max_context_len = seq_lens_decode.max().item()
        max_kv_len = max_context_len
        max_q_len = 1  # Single token decode

        # Create cu_seqlens_q: cumulative sum of query lengths (all 1s)
        cu_seqlens_q = torch.arange(
            num_decodes + 1, dtype=torch.int32, device=self.device
        )

        # Create cu_seqlens_kv and related tensors using kv_spans_from_batches
        decode_query_start_loc = torch.arange(
            num_decodes + 1, dtype=torch.long
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            decode_query_start_loc, seq_lens_decode, self.device
        )

        cu_seqlens_kv = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=self.device),
                torch.cumsum(seq_lens_decode.to(self.device), dim=0).to(torch.int32),
            ]
        )

        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.block_table_tensor[
                :num_decodes, ...
            ],
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
            decode_lens=decode_lens,
            requires_padding=(
                decode_lens_cpu.max() > decode_lens_cpu.min()
            ).item(),
            use_large_context_topk=use_large_context_topk,
            offsets=offsets,
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seqlens_kv=cu_seqlens_kv,
            cu_seqlens_q=cu_seqlens_q,
            max_context_len=max_context_len,
            max_q_len=max_q_len,
            max_kv_len=max_kv_len,
            # schedule_metadata=self.scheduler_metadata_buffer,
        )
        return decode_metadata
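The indptr tensors built above follow the usual exclusive-cumsum convention: cu_seqlens_q is [0, 1, ..., num_decodes] because every decode request contributes exactly one query token, and cu_seqlens_kv is the running sum of per-request context lengths. A small sketch (kv_spans_from_batches itself is vLLM's helper; only the cumsums are reproduced here):

import torch

seq_lens_decode = torch.tensor([10, 7, 12], dtype=torch.int32)
num_decodes = seq_lens_decode.numel()

# One query token per decode request.
cu_seqlens_q = torch.arange(num_decodes + 1, dtype=torch.int32)
# Exclusive cumsum over per-request KV lengths.
cu_seqlens_kv = torch.cat(
    [torch.zeros(1, dtype=torch.int32),
     torch.cumsum(seq_lens_decode, dim=0).to(torch.int32)]
)
assert cu_seqlens_kv.tolist() == [0, 10, 17, 29]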

    def build(
        self,
        common_prefix_len: int,
@@ -323,45 +428,103 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
        )

        # Use CPU values to avoid a GPU sync, which would break async scheduling.
        requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

        # Decide which top-k kernel to use based on batch size and sequence length
        batch_size = num_decodes
        _is_large_context = common_attn_metadata.max_seq_len > 8192

        # Decision logic based on micro-benchmark results:
        # - large_context_topk wins for batch <= 128 and seq_len > 8K
        # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
        use_large_context_topk = batch_size <= 128 and _is_large_context

        next_n = 1 + self.num_speculative_tokens
        if next_n > 1:
            offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
        else:
            offsets = None

        seq_lens = common_attn_metadata.seq_lens[:num_decodes]

        # DeepGEMM is required for the paged MQA logits on CUDA devices
        if current_platform.is_cuda() and has_deep_gemm():
            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens, self.kv_cache_spec.block_size, self.num_sms
            )
        block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

        # Padded CUDA graph requests have block_table entries of -1.
        # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
        # This is safe because padded requests have seq_lens=0, so the
        # kernel produces no meaningful output for those rows.
        block_table.clamp_(min=0)
        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=block_table,
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
            decode_lens=decode_lens,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
            use_large_context_topk=use_large_context_topk,
            offsets=offsets,

        max_decode_len = int(decode_lens_cpu.max().item())
        if max_decode_len > 1:
            # Flatten multi-token decode requests into single-token
            # batch entries, expanding seq_lens and block tables so
            # the kernel always sees next_n=1.

            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
            # The context lengths are therefore
            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

            # 3 + 1 + 4 + 0 = 8
            actual_expanded = int(decode_lens_cpu.sum().item())

            # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
            expanded_base = torch.repeat_interleave(
                seq_lens - decode_lens, decode_lens
            )

            # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
            expanded_starts = torch.repeat_interleave(
                common_attn_metadata.query_start_loc[:num_decodes], decode_lens
            )

            # [0, 1, 2, 0, 0, 1, 2, 3]
            positions_within = (
                self.arange_buffer[:actual_expanded] - expanded_starts
            )

            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
            self.expanded_seq_lens_buffer[:actual_expanded] = (
                expanded_base + positions_within + 1
            )
            self.expanded_seq_lens_buffer[actual_expanded:] = 0
            seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]

            # Give each of the flattened entries the same block table row as the
            # original request.
            self.expanded_block_table_buffer[:actual_expanded] = (
                torch.repeat_interleave(block_table, decode_lens, dim=0)
            )
            if actual_expanded < num_decode_tokens:
                self.expanded_block_table_buffer[
                    actual_expanded:num_decode_tokens, 0
                ] = 0
            block_table = self.expanded_block_table_buffer[:num_decode_tokens]

            # All reqs now have decode_len=1
            self.decode_lens_buffer[:num_decode_tokens] = 1
            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
            offsets = None
            batch_size = num_decode_tokens
        else:
            next_n = 1 + self.num_speculative_tokens
            if next_n > 1:
                offsets = torch.arange(
                    next_n, device=self.device, dtype=torch.int32
                )
            else:
                offsets = None
            batch_size = num_decodes
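Tracing the flattening above with the exact numbers from the example comments, as a plain-tensor sketch that is independent of the builder's preallocated buffers:

import torch

seq_lens = torch.tensor([10, 7, 12, 0])
decode_lens = torch.tensor([3, 1, 4, 0])
query_start_loc = torch.tensor([0, 3, 4, 8])  # leading entries of the cumsum

expanded_base = torch.repeat_interleave(seq_lens - decode_lens, decode_lens)
expanded_starts = torch.repeat_interleave(query_start_loc, decode_lens)
positions = torch.arange(int(decode_lens.sum())) - expanded_starts
flattened = expanded_base + positions + 1

# Each flattened entry sees the context plus its own position within the
# decode run, so the kernel can treat every entry as a next_n=1 decode.
assert flattened.tolist() == [8, 9, 10, 7, 9, 10, 11, 12]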

        # DeepGEMM is required for the paged MQA logits on CUDA devices
        if current_platform.is_cuda() and is_deep_gemm_supported():
            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens,
                self.kv_cache_spec.block_size,
                self.num_sms,
            )

        # Decide which top-k kernel to use based on batch size and sequence length
        # Decision logic based on micro-benchmark results:
        # - large_context_topk wins for batch <= 128 and seq_len > 8K
        # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
        _is_large_context = common_attn_metadata.max_seq_len > 8192
        use_large_context_topk = batch_size <= 128 and _is_large_context

        # decode_metadata = DeepSeekV32IndexerDecodeMetadata(
        #     block_table=block_table,
        #     seq_lens=seq_lens,
        #     decode_lens=decode_lens,
        #     requires_padding=False,
        #     # schedule_metadata=self.scheduler_metadata_buffer,
        #     use_large_context_topk=use_large_context_topk,
        #     offsets=offsets,
        # )
        decode_metadata = self.build_decode_metadata(
            common_attn_metadata,
            num_decodes,
            decode_lens,
            use_large_context_topk,
            offsets,
        )

        attn_metadata = DeepseekV32IndexerMetadata(

@@ -115,6 +115,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
        max_decode_seq_len: int = 0,
        use_cuda_graph: bool = False,
    ) -> AiterMLADecodeMetadata:
        # kernel block size is always 1, although the kv block size is not 1.
        device = self.device
@@ -170,11 +172,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
            max_decode_seq_len=max_decode_seq_len,
            use_cuda_graph=use_cuda_graph,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
            max_qo_len=max_qo_len,
            attn_out_dtype=self.decode_attn_out_dtype,
        )

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.distributed.parallel_state import get_dcp_group
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
    AttentionLayer,
@@ -22,20 +23,19 @@ from vllm.v1.attention.backend import (
    is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

import ixformer.inference.functions as ixf_ops
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.parallel_state import get_dcp_group

logger = init_logger(__name__)


class TritonMLABackend(MLACommonBackend):
    # supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
    #     "auto",
    #     "bfloat16",
    # ]
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]

    @staticmethod
    def get_name() -> str:
@@ -120,10 +120,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        # layer: AttentionLayer,
        k_c_normed: torch.Tensor | None = None,
        k_pe: torch.Tensor | None = None,
        kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
        k_c_normed: torch.Tensor | None,
        k_pe: torch.Tensor | None,
        kv_c_and_k_pe_cache_scale: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
@@ -136,7 +135,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
        q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)

        B = q_nope.shape[0]

        if self.dcp_world_size > 1:
            q = torch.cat([q_nope, q_pe], dim=-1)
            q = get_dcp_group().all_gather(q, dim=1)
@@ -147,7 +146,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                            device=q_nope.device)
            if envs.VLLM_USE_INT8_MLA:
                q_int8, q_scale = ops.quant_kv(q)
                attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla_int8(
                attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
                    o,
                    q_int8,
                    q_scale,
@@ -160,7 +159,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                    return_softmax_lse=True
                )
            else:
                attn_out, softmax_lse = ixf_ops.ref_vllm_paged_attention_mla(
                attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
                    output=o,
                    query=q,
                    kv_cache=kv_c_and_k_pe_cache,
@@ -170,12 +169,12 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                    max_context_len=decode_meta.max_decode_seq_len,
                    return_softmax_lse=True)
            return attn_out, softmax_lse

        o = torch.empty(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q_nope.dtype,
                        device=q_nope.device)
                        device=q_nope.device)

        if envs.VLLM_USE_INT8_MLA:
            q = torch.cat([q_nope, q_pe], dim=-1)
@@ -193,18 +192,30 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                attn_metadata.decode.use_cuda_graph
            )
        else:
            # fused q concat & cache write
            ixf_ops.vllm_paged_attention_mla_fused(
                output=o,
                q_nope=q_nope,
                q_pe=q_pe.contiguous(),
                kv_cache=kv_c_and_k_pe_cache,
                scale=self.scale,
                block_tables=attn_metadata.decode.block_table,
                context_lens=attn_metadata.decode.seq_lens,
                max_context_len=decode_meta.max_decode_seq_len,
                k_c_normed=k_c_normed,
                k_pe=k_pe,
                use_cuda_graph=decode_meta.use_cuda_graph
            )
            if k_c_normed is None:
                q = torch.cat([q_nope, q_pe.contiguous()], dim=-1)
                ixf_ops.vllm_paged_attention_mla(
                    output=o,
                    query=q,
                    kv_cache=kv_c_and_k_pe_cache,
                    scale=self.scale,
                    block_tables=attn_metadata.decode.block_table,
                    context_lens=attn_metadata.decode.seq_lens,
                    max_context_len=decode_meta.max_decode_seq_len,
                    use_cuda_graph=decode_meta.use_cuda_graph,
                )
            else:
                ixf_ops.vllm_paged_attention_mla_fused(
                    output=o,
                    q_nope=q_nope.contiguous(),
                    q_pe=q_pe.contiguous(),
                    kv_cache=kv_c_and_k_pe_cache,
                    scale=self.scale,
                    block_tables=attn_metadata.decode.block_table,
                    context_lens=attn_metadata.decode.seq_lens,
                    max_context_len=decode_meta.max_decode_seq_len,
                    k_c_normed=k_c_normed,
                    k_pe=k_pe,
                    use_cuda_graph=decode_meta.use_cuda_graph,
                )
        return self._v_up_proj(o), None

@@ -55,6 +55,16 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        return RocmAttentionMetadataBuilder

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """RocmAiterUnifiedAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )


class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
@@ -143,6 +153,19 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_cache_dtype.startswith("fp8"):
@@ -195,6 +218,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return
        key_cache, value_cache = kv_cache.unbind(0)

        # Reshape the input keys and values and store them in the cache.
@@ -224,6 +251,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
        kv_cache: torch.Tensor,
        layer_slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return
        key_cache, value_cache = kv_cache.unbind(0)
        flash_layout = True


@@ -182,7 +182,7 @@ class RocmAttentionBackend(AttentionBackend):

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
@@ -205,6 +205,16 @@ class RocmAttentionBackend(AttentionBackend):
    def get_impl_cls() -> type["RocmAttentionImpl"]:
        return RocmAttentionImpl

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """RocmAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -244,6 +254,7 @@ class RocmAttentionImpl(AttentionImpl):
        kv_sharing_target_layer_name: int | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
@@ -266,11 +277,6 @@ class RocmAttentionImpl(AttentionImpl):

        RocmAttentionBackend.validate_head_size(head_size)

        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
            raise NotImplementedError(
                "Encoder self-attention is not implemented for RocmAttentionImpl"
            )

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
@@ -281,6 +287,54 @@ class RocmAttentionImpl(AttentionImpl):
            f"num_heads: {num_heads}."
        )

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Use encoder-specific metadata for sequence information
        query_start_loc = attn_metadata.query_start_loc
        seq_lens = attn_metadata.seq_lens
        max_query_len = attn_metadata.max_query_len

        # Call flash attention directly on Q, K, V tensors
        from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd

        context_attention_fwd(
            q=query,
            k=key,
            v=value,
            o=output,
            b_start_loc=query_start_loc,
            b_seq_len=seq_lens,
            max_input_len=max_query_len,
            is_causal=False,
            softmax_scale=self.scale,
            sliding_window_q=self.sliding_window[0],
            sliding_window_k=self.sliding_window[1],
        )
        return output

    def forward(
        self,
        layer: torch.nn.Module,
@@ -330,6 +384,16 @@ class RocmAttentionImpl(AttentionImpl):

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )
@@ -380,6 +444,8 @@ class RocmAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )
@@ -432,6 +498,8 @@ class RocmAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        layer_slot_mapping: torch.Tensor,
    ):
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache,
            layer.num_kv_heads,  # type: ignore[attr-defined]

@@ -6,6 +6,7 @@ import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import _custom_ops as ops

logger = init_logger(__name__)

@@ -151,7 +152,34 @@ def flash_mla_with_kvcache_fp8(
        descale_k,
    )
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
      Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
      For the definitions of output, max_logits and lse,
      please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    results = ops.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
    return results
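A hedged usage sketch for the wrapper above; shapes and dtypes follow the docstring, sm_scale is the usual 1/sqrt(d_qk), and the sizes chosen here are illustrative. Whether the underlying ops.sparse_prefill_fwd custom op is available depends on the build:

import torch

s_q, s_kv, h_q, h_kv, d_qk, topk = 4, 1024, 128, 1, 576, 256
q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
# int32 indices into kv per query token; -1 marks invalid slots.
indices = torch.randint(-1, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")

output, max_logits, lse = flash_mla_sparse_prefill(q, kv, indices, d_qk**-0.5)
assert output.shape == (s_q, h_q, 512)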
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
|
||||
@@ -37,8 +37,8 @@ def flash_attn_maxseqlen_wrapper(
    else:
        from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

    # if not current_platform.is_rocm() and fa_version is not None:
    #     kwargs["fa_version"] = fa_version
    if not current_platform.is_rocm() and fa_version is not None:
        kwargs["fa_version"] = fa_version

    q_len = q.size(1)
    if cu_seqlens is None:
@@ -268,3 +268,91 @@ def vit_torch_sdpa_wrapper(
    return torch.ops.vllm.torch_sdpa_wrapper(
        q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
    )


def flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

    is_reshaped = q.dim() == 4

    if is_reshaped:
        reshape_batch_size = q.shape[0]
        q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    # cuDNN <= 9.10.2.21 requires q, k to be contiguous;
    # this comes with no cost for ViTs with RoPE because
    # RoPE has already made q and k contiguous.
    q, k = q.contiguous(), k.contiguous()

    assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
    cu_seqlength = len(cu_seqlens) // 2
    batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
    batch_offsets_v = cu_seqlens[cu_seqlength:].view(-1, 1, 1, 1)
    sequence_lengths = sequence_lengths.view(-1, 1, 1, 1)
    max_seqlen = max_seqlen.item()

    output, _ = cudnn_batch_prefill_with_kv_cache(
        q,
        k,
        v,
        scale,
        workspace_buffer,
        max_token_per_sequence=max_seqlen,
        max_sequence_kv=max_seqlen,
        actual_seq_lens_q=sequence_lengths,
        actual_seq_lens_kv=sequence_lengths,
        causal=False,
        return_lse=False,
        batch_offsets_q=batch_offsets_qko,
        batch_offsets_k=batch_offsets_qko,
        batch_offsets_v=batch_offsets_v,
        batch_offsets_o=batch_offsets_qko,
    )

    if is_reshaped:
        output = einops.rearrange(output, "(b s) h d -> b s h d", b=reshape_batch_size)

    return output
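
# Editor's note: a sketch of the packed-offset convention the wrapper above
# appears to assume (illustrative, not part of this change). cu_seqlens
# carries one start offset per sequence for the q/k/o buffers in its first
# half and for the v buffer in its second half, which is why the wrapper
# asserts an even length and splits it down the middle:
#
#     # two sequences starting at tokens 0 and 3 in the packed buffers
#     cu_seqlens = torch.tensor([0, 3, 0, 3])
#     half = len(cu_seqlens) // 2
#     batch_offsets_qko = cu_seqlens[:half].view(-1, 1, 1, 1)
#     batch_offsets_v = cu_seqlens[half:].view(-1, 1, 1, 1)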


def vit_flashinfer_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op(
    op_name="flashinfer_wrapper",
    op_func=flashinfer_wrapper,
    fake_impl=vit_flashinfer_wrapper_fake,
)


def vit_flashinfer_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
    sequence_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.ops.vllm.flashinfer_wrapper(
        q, k, v, scale, workspace_buffer, cu_seqlens, max_seqlen, sequence_lengths
    )

@@ -456,6 +456,37 @@ class KVCacheManager:
        """
        return self.coordinator.get_num_common_prefix_blocks(running_request_id)

    def get_num_free_blocks(self) -> int:
        """Get the number of free blocks in the pool."""
        return self.block_pool.get_num_free_blocks()

    def get_num_blocks_needed_for_tokens(
        self,
        request_id: str,
        num_tokens_need_slot: int,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_encoder_tokens: int = 0,
        total_computed_tokens: int = 0,
        num_tokens_main_model: int = 0,
    ) -> int:
        """Estimate the number of blocks needed for a request (no allocation).

        Used e.g. to check whether enough KV cache exists for a full chunked
        prefill before admitting a request into the running queue.
        """
        if new_computed_blocks is not None:
            new_computed_block_list = new_computed_blocks.blocks
        else:
            new_computed_block_list = self.empty_kv_cache_blocks.blocks
        return self.coordinator.get_num_blocks_to_allocate(
            request_id=request_id,
            num_tokens=num_tokens_need_slot,
            new_computed_blocks=new_computed_block_list,
            num_encoder_tokens=num_encoder_tokens,
            total_computed_tokens=total_computed_tokens,
            num_tokens_main_model=num_tokens_main_model,
        )
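
    # Editor's note: a sketch of the admission check this estimator enables
    # (it mirrors the scheduler code later in this change; names are
    # illustrative):
    #
    #     blocks_needed = kv_cache_manager.get_num_blocks_needed_for_tokens(
    #         request_id, num_tokens_need_slot, new_computed_blocks
    #     )
    #     if kv_cache_manager.get_num_free_blocks() < blocks_needed:
    #         ...  # defer the request instead of failing mid-prefill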

    def take_events(self) -> list[KVCacheEvent]:
        """Take the KV cache events from the block pool.

@@ -489,6 +520,13 @@ class KVCacheManager:
        # Only create new KVCacheBlocks for non-empty blocks
        return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks

    def take_new_block_ids(self) -> list[int]:
        """Drain and return new attention block IDs for zeroing."""
        ids: list[int] = []
        for mgr in self.coordinator.single_type_managers:
            ids.extend(mgr.take_new_block_ids())
        return ids

    def new_step_starts(self) -> None:
        """Called when a new step is started."""
        self.coordinator.new_step_starts()

@@ -10,9 +10,8 @@ from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
import vllm.envs as envs

from vllm import envs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.hashing import sha256_cbor, xxhash_cbor
@@ -835,7 +834,7 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:


def get_num_blocks(
    vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int
    vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int, scale_page_size: int
) -> int:
    """
    Get the number of kv cache blocks.

@@ -846,7 +845,7 @@ def get_num_blocks(
        available_memory: Memory available for KV cache in bytes.
        page_size: The page size of the KV cache.
    """
    num_blocks = int(available_memory // page_size // num_layers)
    num_blocks = int(available_memory // (page_size + scale_page_size) // num_layers)
    num_blocks = max(num_blocks, 0)
    num_blocks = may_override_num_blocks(vllm_config, num_blocks)
    return num_blocks
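
# Editor's note: a worked example of the formula above (the numbers are
# illustrative only, not from this change). With 4 GiB available for KV
# cache, 32 layers, a 16 KiB page and a 1 KiB scale page per block per layer:
#
#     available_memory = 4 * 1024**3
#     num_blocks = available_memory // (16384 + 1024) // 32  # -> 7710
#
# i.e. the scale pages now count against the same memory budget as the
# KV pages, so num_blocks shrinks accordingly.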


@@ -857,9 +856,14 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
    Get the page size of the KV cache.
    """
    page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
    scale_page_sizes = {layer.scale_page_size_bytes for layer in kv_cache_specs}
    assert len(page_sizes) == 1
    return page_sizes.pop()

    if envs.VLLM_ATTN_OPT_LEVEL == 1:
        v_cache_scale_sizes = {layer.v_cache_scale_size_bytes for layer in kv_cache_specs}
        assert len(v_cache_scale_sizes) == 1
        return page_sizes.pop(), scale_page_sizes.pop(), v_cache_scale_sizes.pop()
    else:
        return page_sizes.pop(), scale_page_sizes.pop()

def _get_kv_cache_groups_uniform_spec(
    kv_cache_specs: dict[str, KVCacheSpec],
@@ -955,6 +959,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo


def _get_kv_cache_groups_uniform_page_size(
    vllm_config: VllmConfig,
    kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
    """
@@ -1015,6 +1020,7 @@ def _get_kv_cache_groups_uniform_page_size(
    memory per block is the same for all groups.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The KVCacheSpec of each attention layer in the model
    Returns:
        The generated KVCacheGroupSpecs
@@ -1058,19 +1064,28 @@ def _get_kv_cache_groups_uniform_page_size(
            num_padding_layers / len(layers) * 100,
        )
    num_groups = cdiv(len(layers), group_size)
    # In PP case, say if we have
    # - stage 0: full.0, sw.0, sw.1
    # - stage 1: full.1, sw.2, sw.3
    # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
    # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
    # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
    # and it will be padded to (full.0, padding), (sw.0, sw.1),
    # (padding, padding) to ensure the number of layers in each group is
    # the same and will cause memory waste.
    # To avoid this, we assign layers[i::num_groups] to the i-th group
    # instead of layers[i * group_size: (i + 1) * group_size]
    for i in range(num_groups):
        grouped_layers.append(layers[i::num_groups])
    # To support multi-layer MTP, we need to keep all MTP layers
    # in the same group.
    if (
        vllm_config.speculative_config is not None
        and vllm_config.speculative_config.enable_multi_layers_mtp
    ):
        for i in range(0, len(layers), group_size):
            grouped_layers.append(layers[i : i + group_size])
    else:
        # In PP case, say if we have
        # - stage 0: full.0, sw.0, sw.1
        # - stage 1: full.1, sw.2, sw.3
        # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
        # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
        # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
        # and it will be padded to (full.0, padding), (sw.0, sw.1),
        # (padding, padding) to ensure the number of layers in each group is
        # the same and will cause memory waste.
        # To avoid this, we assign layers[i::num_groups] to the i-th group
        # instead of layers[i * group_size: (i + 1) * group_size]
        for i in range(num_groups):
            grouped_layers.append(layers[i::num_groups])
    return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
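
# Editor's note: a small illustration of the round-robin grouping in the
# else-branch above (hypothetical layer names matching the PP comment).
# With 6 layers and num_groups = 3:
#
#     layers = ["full.0", "sw.0", "sw.1", "full.1", "sw.2", "sw.3"]
#     groups = [layers[i::3] for i in range(3)]
#     # -> [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1", "sw.3"]]
#
# Each PP stage contributes one layer to every group, so no group is ever
# empty and no padding is wasted.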


@@ -1096,9 +1111,9 @@ def get_kv_cache_config_from_groups(
        return KVCacheConfig(
            num_blocks=1,
            kv_cache_tensors=[],
            kv_cache_scale_tensors=[],
            kv_cache_groups=kv_cache_groups,
        )

    # Determine how model runners should initialize the KV cache tensors.
    if len(kv_cache_groups) == 1 and isinstance(
        kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
@@ -1118,6 +1133,12 @@ def get_kv_cache_config_from_groups(
            )
            for layer_name in kv_cache_groups[0].layer_names
        ]
        kv_cache_scale_tensors = [
            KVCacheTensor(
                size=per_layer_specs[layer_name].scale_page_size_bytes * num_blocks,
                shared_by=[layer_name],
            )
            for layer_name in kv_cache_groups[0].layer_names
        ]
    else:
        # General case:
        # We will have group_size memory pools, each is shared by one layer from
@@ -1129,55 +1150,39 @@ def get_kv_cache_config_from_groups(
        # full.1, sw.2: share another Tensor with size=available_memory//2
        group_size = max(len(group.layer_names) for group in kv_cache_groups)

        page_size = get_uniform_page_size(
        if envs.VLLM_ATTN_OPT_LEVEL == 1:
            page_size, scale_page_size, v_cache_scale_size = get_uniform_page_size(
                [group.kv_cache_spec for group in kv_cache_groups]
            )
        else:
            page_size, scale_page_size = get_uniform_page_size(
                [group.kv_cache_spec for group in kv_cache_groups]
            )
            v_cache_scale_size = 0
        assert group_size > 0, "group_size must be greater than 0"
        # num_blocks = get_num_blocks(
        #     vllm_config, group_size, available_memory, page_size
        # )
        num_blocks = get_num_blocks(
            vllm_config,
            group_size,
            available_memory,
            page_size,
            scale_page_size,
        )
        kv_cache_tensors = []
        # TODO: will add scale?
        kv_cache_scale_tensors = []
        if envs.VLLM_KV_DISABLE_CROSS_GROUP_SHARE:
            total_layers = sum(len(group.layer_names) for group in kv_cache_groups)
            num_blocks = get_num_blocks(
                vllm_config,
                total_layers,
                available_memory,
                page_size,
        for i in range(group_size):
            shared_by = []
            for j in range(len(kv_cache_groups)):
                if i < len(kv_cache_groups[j].layer_names):
                    shared_by.append(kv_cache_groups[j].layer_names[i])
            kv_cache_tensors.append(
                KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
            )
            for group in kv_cache_groups:
                for layer_name in group.layer_names:
                    kv_cache_tensors.append(
                        KVCacheTensor(size=page_size * num_blocks, shared_by=[layer_name])
                    )
            logger.warning(
                "VLLM_KV_DISABLE_CROSS_GROUP_SHARE=1: using dedicated KV tensors per layer "
                "(groups=%d, tensors=%d, num_blocks=%d)",
                len(kv_cache_groups),
                len(kv_cache_tensors),
                num_blocks,
            kv_cache_scale_tensors.append(
                KVCacheTensor(
                    size=scale_page_size * num_blocks,
                    shared_by=shared_by,
                    size_scale=v_cache_scale_size,
                )
            )
        else:
            num_blocks = get_num_blocks(
                vllm_config,
                group_size,
                available_memory,
                page_size,
            )
            for i in range(group_size):
                shared_by = []
                for j in range(len(kv_cache_groups)):
                    if i < len(kv_cache_groups[j].layer_names):
                        shared_by.append(kv_cache_groups[j].layer_names[i])
                kv_cache_tensors.append(
                    KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
                )

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_scale_tensors=kv_cache_scale_tensors,
        kv_cache_groups=kv_cache_groups,
    )
@@ -1284,7 +1289,9 @@ def get_kv_cache_groups(
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus same total page
|
||||
# size.
|
||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
||||
return _get_kv_cache_groups_uniform_page_size(
|
||||
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
|
||||
)
|
||||
|
||||
|
||||
def generate_scheduler_kv_cache_config(
|
||||
@@ -1381,11 +1388,17 @@ def _max_memory_usage_bytes_from_groups(
    # General case: group_size pools, each shared by one layer per group
    # Memory = group_size * page_size * blocks_for_max_len
    group_size = max(len(group.layer_names) for group in kv_cache_groups)
    page_size = get_uniform_page_size(
        [group.kv_cache_spec for group in kv_cache_groups]
    )
    if envs.VLLM_ATTN_OPT_LEVEL == 1:
        page_size, scale_page_size, v_cache_scale_size = get_uniform_page_size(
            [group.kv_cache_spec for group in kv_cache_groups]
        )
    else:
        page_size, scale_page_size = get_uniform_page_size(
            [group.kv_cache_spec for group in kv_cache_groups]
        )
        v_cache_scale_size = 0
    any_spec = kv_cache_groups[0].kv_cache_spec
    blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
    blocks_needed = cdiv(
        any_spec.max_memory_usage_bytes(vllm_config),
        page_size + scale_page_size + v_cache_scale_size,
    )

    return group_size * page_size * blocks_needed

@@ -1633,6 +1646,10 @@ def get_kv_cache_configs(
        for tensor in kv_cache_config.kv_cache_tensors:
            assert tensor.size % num_blocks_old == 0
            tensor.size = tensor.size // num_blocks_old * min_num_blocks

        for tensor in kv_cache_config.kv_cache_scale_tensors:
            assert tensor.size % num_blocks_old == 0
            tensor.size = tensor.size // num_blocks_old * min_num_blocks

        if len(kv_cache_config.kv_cache_groups) > 0:
            _report_kv_cache_config(vllm_config, kv_cache_config)

@@ -5,8 +5,6 @@ from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

from vllm._bc_linter import bc_linter_include

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
@@ -29,7 +27,6 @@ else:
    Request = object


@bc_linter_include
@dataclass
class NewRequestData:
    req_id: str
@@ -109,7 +106,6 @@ class NewRequestData:
    )


@bc_linter_include
@dataclass
class CachedRequestData:
    req_ids: list[str]
@@ -179,7 +175,6 @@ class CachedRequestData:
    )


@bc_linter_include
@dataclass
class SchedulerOutput:
    # list of the requests that are scheduled for the first time.
@@ -217,6 +212,9 @@ class SchedulerOutput:
    # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]

    # Request IDs that are resumed from preemption in this step.
    scheduled_resumed_reqs: list[str] | None = None

    # Request IDs that are preempted in this step.
    # Only used for v2 model runner.
    preempted_req_ids: set[str] | None = None
@@ -238,6 +236,11 @@ class SchedulerOutput:
    # EC Cache Connector metadata
    ec_connector_metadata: ECConnectorMetadata | None = None

    # Block IDs freshly allocated from the pool during this scheduling step.
    # The worker zeros the corresponding GPU memory before the blocks are used,
    # preventing stale NaN/data from corrupting attention or SSM computation.
    new_block_ids_to_zero: list[int] | None = None

    @classmethod
    def make_empty(cls) -> "SchedulerOutput":
        return cls(

@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
            return any(
                isinstance(group_spec.kv_cache_spec, MambaSpec)
                for group_spec in kv_cache_config.kv_cache_groups
            )

        self.has_mamba_layers = has_mamba_layers(kv_cache_config)
        self.has_mamba_layers = kv_cache_config.has_mamba_layers
        self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
        self.need_mamba_block_aligned_split = (
            self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
        )
@@ -320,6 +315,9 @@ class Scheduler(SchedulerInterface):
        return num_new_tokens

    def schedule(self) -> SchedulerOutput:
        if envs.VLLM_ENABLE_PP_MIX_ILU_SCHEDULING:
            return self.schedule_opt()

        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and
@@ -413,7 +411,7 @@ class Scheduler(SchedulerInterface):
                request, num_new_tokens
            )

            if num_new_tokens == 0:
            if num_new_tokens <= 0:
                # The request cannot be scheduled because one of the following
                # reasons:
                # 1. No new tokens to schedule. This may happen when
@@ -425,6 +423,8 @@ class Scheduler(SchedulerInterface):
                # 3. The encoder cache is exhausted.
                # 4. Insufficient budget for a block-aligned chunk in hybrid
                #    models with mamba cache mode "align".
                # 5. num_computed_tokens > num_tokens_with_spec due to PP
                #    timing: schedule() runs before update_from_output().
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
@@ -670,7 +670,7 @@ class Scheduler(SchedulerInterface):
                    # If chunked_prefill is disabled,
                    # we can stop the scheduling here.
                    break
                temp_num_new_tokens = num_new_tokens

                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0

@@ -688,7 +688,7 @@ class Scheduler(SchedulerInterface):
                        encoder_compute_budget,
                        shift_computed_tokens=1 if self.use_eagle else 0,
                    )
                    if num_new_tokens == 0 or num_new_tokens < temp_num_new_tokens:
                    if num_new_tokens == 0:
                        # The request cannot be scheduled.
                        break

@@ -723,6 +723,35 @@ class Scheduler(SchedulerInterface):
                        for i in encoder_inputs_to_schedule
                    )

                if not load_kv_async:
                    enable_chunked = self.scheduler_config.enable_chunked_prefill
                    tokens_still_to_compute = (
                        request.num_tokens - num_computed_tokens
                    )
                    is_chunked = (
                        enable_chunked
                        and tokens_still_to_compute > num_new_tokens
                    )
                    if is_chunked:
                        assert (
                            request.num_tokens <= self.max_model_len
                        ), "request.num_tokens must not exceed max_model_len"
                        num_tokens_need_slot = min(
                            request.num_tokens + effective_lookahead_tokens,
                            self.max_model_len,
                        )
                        blocks_needed = (
                            self.kv_cache_manager.get_num_blocks_needed_for_tokens(
                                request.request_id,
                                num_tokens_need_slot,
                                new_computed_blocks,
                                num_encoder_tokens,
                            )
                        )
                        num_free = self.kv_cache_manager.get_num_free_blocks()
                        if num_free < blocks_needed:
                            break

                new_blocks = self.kv_cache_manager.allocate_slots(
                    request,
                    num_new_tokens,
@@ -871,6 +900,12 @@ class Scheduler(SchedulerInterface):
        self.prev_step_scheduled_req_ids.clear()
        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())

        new_block_ids_to_zero = (
            (self.kv_cache_manager.take_new_block_ids() or None)
            if self.needs_kv_cache_zeroing
            else None
        )

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
@@ -886,6 +921,7 @@ class Scheduler(SchedulerInterface):
            # the previous and the current steps.
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
            new_block_ids_to_zero=new_block_ids_to_zero,
        )

        # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -909,6 +945,527 @@ class Scheduler(SchedulerInterface):
        self._update_after_schedule(scheduler_output)
        return scheduler_output

    def schedule_opt(self) -> SchedulerOutput:
        """PP mix ILU scheduling variant of schedule()."""

        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        if self._pause_state == PauseState.PAUSED_ALL:
            token_budget = 0

        # Encoder-related.
        scheduled_encoder_inputs: dict[str, list[int]] = {}
        encoder_compute_budget = self.max_num_encoder_input_tokens
        # Spec decode-related.
        scheduled_spec_decode_tokens: dict[str, list[int]] = {}

        # For logging.
        scheduled_timestamp = time.monotonic()

        self.kv_cache_manager.new_step_starts()

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]

            if (
                request.num_output_placeholders > 0
                and request.num_computed_tokens + 2 - request.num_output_placeholders
                >= request.num_prompt_tokens + request.max_tokens
            ):
                req_index += 1
                continue

            num_new_tokens = (
                request.num_tokens_with_spec
                + request.num_output_placeholders
                - request.num_computed_tokens
            )
            if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
                num_new_tokens = self.scheduler_config.long_prefill_token_threshold
            num_new_tokens = min(num_new_tokens, token_budget)

            num_new_tokens = min(
                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
            )

            # Schedule encoder inputs.
            encoder_inputs_to_schedule = None
            external_load_encoder_input: list[int] = []
            new_encoder_compute_budget = encoder_compute_budget
            if request.has_encoder_inputs:
                (
                    encoder_inputs_to_schedule,
                    num_new_tokens,
                    new_encoder_compute_budget,
                    external_load_encoder_input,
                ) = self._try_schedule_encoder_inputs(
                    request,
                    request.num_computed_tokens,
                    num_new_tokens,
                    encoder_compute_budget,
                    shift_computed_tokens=1 if self.use_eagle else 0,
                )

            if self.need_mamba_block_aligned_split:
                num_new_tokens = self._mamba_block_aligned_split(
                    request, num_new_tokens
                )

            if num_new_tokens <= 0:
                # The request cannot be scheduled because one of the following
                # reasons:
                # 1. No new tokens to schedule. This may happen when
                #    (1) PP>1 and we have already scheduled all prompt tokens
                #    but they are not finished yet.
                #    (2) Async scheduling and the request has reached either
                #    its max_total_tokens or max_model_len.
                # 2. The encoder budget is exhausted.
                # 3. The encoder cache is exhausted.
                # 4. num_computed_tokens > num_tokens_with_spec due to PP
                #    timing: schedule() runs before update_from_output().
                req_index += 1
                continue

            # Schedule newly needed KV blocks for the request.
            with record_function_or_nullcontext("schedule: allocate_slots"):
                while True:
                    new_blocks = self.kv_cache_manager.allocate_slots(
                        request,
                        num_new_tokens,
                        num_lookahead_tokens=self.num_lookahead_tokens,
                    )

                    if new_blocks is not None:
                        break

                    if self.policy == SchedulingPolicy.PRIORITY:
                        preempted_req = max(
                            self.running,
                            key=lambda r: (r.priority, r.arrival_time),
                        )
                        self.running.remove(preempted_req)
                        if preempted_req in scheduled_running_reqs:
                            preempted_req_id = preempted_req.request_id
                            scheduled_running_reqs.remove(preempted_req)
                            token_budget += num_scheduled_tokens.pop(preempted_req_id)
                            req_to_new_blocks.pop(preempted_req_id)
                            scheduled_spec_decode_tokens.pop(preempted_req_id, None)
                            preempted_encoder_inputs = scheduled_encoder_inputs.pop(
                                preempted_req_id, None
                            )
                            if preempted_encoder_inputs:
                                num_embeds_to_restore = sum(
                                    preempted_req.get_num_encoder_embeds(i)
                                    for i in preempted_encoder_inputs
                                )
                                encoder_compute_budget += num_embeds_to_restore
                            req_index -= 1
                    else:
                        preempted_req = self.running.pop()

                    self._preempt_request(preempted_req, scheduled_timestamp)
                    preempted_reqs.append(preempted_req)
                    if preempted_req == request:
                        break

            if new_blocks is None:
                break

            scheduled_running_reqs.append(request)
            request_id = request.request_id
            req_to_new_blocks[request_id] = new_blocks
            num_scheduled_tokens[request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1

            if request.spec_token_ids:
                num_scheduled_spec_tokens = (
                    num_new_tokens
                    + request.num_computed_tokens
                    - request.num_tokens
                    - request.num_output_placeholders
                )
                if num_scheduled_spec_tokens > 0:
                    spec_token_ids = request.spec_token_ids
                    if len(spec_token_ids) > num_scheduled_spec_tokens:
                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids

                request.spec_token_ids = []

            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_compute_budget = new_encoder_compute_budget
            if external_load_encoder_input:
                for i in external_load_encoder_input:
                    self.encoder_cache_manager.allocate(request, i)
                    if self.ec_connector is not None:
                        self.ec_connector.update_state_after_alloc(request, i)

        # Record the LoRAs in scheduled_running_reqs
        scheduled_loras: set[int] = set()
        if self.lora_config:
            scheduled_loras = set(
                req.lora_request.lora_int_id
                for req in scheduled_running_reqs
                if req.lora_request and req.lora_request.lora_int_id > 0
            )
            assert len(scheduled_loras) <= self.lora_config.max_loras

        # Next, schedule the WAITING requests.
        if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
            skipped_waiting_requests = create_request_queue(self.policy)

            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break

                request = self.waiting.peek_request()
                request_id = request.request_id

                if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                    is_ready = self._update_waiting_for_remote_kv(request)
                    if is_ready:
                        if request.num_preemptions:
                            request.status = RequestStatus.PREEMPTED
                        else:
                            request.status = RequestStatus.WAITING
                    else:
                        logger.debug(
                            "%s is still in WAITING_FOR_REMOTE_KVS state.",
                            request_id,
                        )
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                if request.status == RequestStatus.WAITING_FOR_FSM:
                    structured_output_req = request.structured_output_request
                    if structured_output_req and structured_output_req.grammar:
                        request.status = RequestStatus.WAITING
                    else:
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
                    assert not request.streaming_queue
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

                if (
                    self.lora_config
                    and request.lora_request
                    and (
                        len(scheduled_loras) == self.lora_config.max_loras
                        and request.lora_request.lora_int_id not in scheduled_loras
                    )
                ):
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

                num_external_computed_tokens = 0
                load_kv_async = False
                connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0

                if request.num_computed_tokens == 0:
                    new_computed_blocks, num_new_local_computed_tokens = (
                        self.kv_cache_manager.get_computed_blocks(request)
                    )

                    if self.connector is not None:
                        ext_tokens, load_kv_async = (
                            self.connector.get_num_new_matched_tokens(
                                request, num_new_local_computed_tokens
                            )
                        )

                        if ext_tokens is None:
                            self.waiting.pop_request()
                            skipped_waiting_requests.prepend_request(request)
                            continue

                        request.num_external_computed_tokens = ext_tokens
                        num_external_computed_tokens = ext_tokens

                        connector_prefix_cache_queries = (
                            request.num_tokens - num_new_local_computed_tokens
                        )
                        connector_prefix_cache_hits = num_external_computed_tokens

                    num_computed_tokens = (
                        num_new_local_computed_tokens + num_external_computed_tokens
                    )
                else:
                    new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
                    num_new_local_computed_tokens = 0
                    num_computed_tokens = request.num_computed_tokens

                encoder_inputs_to_schedule = None
                external_load_encoder_input = []
                new_encoder_compute_budget = encoder_compute_budget

                if load_kv_async:
                    assert num_external_computed_tokens > 0
                    num_new_tokens = 0
                else:
                    num_new_tokens = request.num_tokens - num_computed_tokens
                    threshold = self.scheduler_config.long_prefill_token_threshold
                    if 0 < threshold < num_new_tokens:
                        num_new_tokens = threshold

                    if (
                        not self.scheduler_config.enable_chunked_prefill
                        and num_new_tokens > token_budget
                    ):
                        break

                    num_new_tokens = min(num_new_tokens, token_budget)
                    assert num_new_tokens > 0

                    if request.has_encoder_inputs:
                        (
                            encoder_inputs_to_schedule,
                            num_new_tokens,
                            new_encoder_compute_budget,
                            external_load_encoder_input,
                        ) = self._try_schedule_encoder_inputs(
                            request,
                            num_computed_tokens,
                            num_new_tokens,
                            encoder_compute_budget,
                            shift_computed_tokens=1 if self.use_eagle else 0,
                        )
                        if num_new_tokens == 0:
                            break

                    if self.need_mamba_block_aligned_split:
                        num_new_tokens = self._mamba_block_aligned_split(
                            request,
                            num_new_tokens,
                            num_new_local_computed_tokens,
                            num_external_computed_tokens,
                        )
                        if num_new_tokens == 0:
                            break

                effective_lookahead_tokens = (
                    0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
                )

                num_encoder_tokens = 0
                if (
                    self.is_encoder_decoder
                    and request.has_encoder_inputs
                    and encoder_inputs_to_schedule
                ):
                    num_encoder_tokens = sum(
                        request.get_num_encoder_embeds(i)
                        for i in encoder_inputs_to_schedule
                    )

                if not load_kv_async:
                    enable_chunked = self.scheduler_config.enable_chunked_prefill
                    tokens_still_to_compute = (
                        request.num_tokens - num_computed_tokens
                    )
                    is_chunked = (
                        enable_chunked
                        and tokens_still_to_compute > num_new_tokens
                    )
                    if is_chunked:
                        assert (
                            request.num_tokens <= self.max_model_len
                        ), "request.num_tokens must not exceed max_model_len"
                        num_tokens_need_slot = min(
                            request.num_tokens + effective_lookahead_tokens,
                            self.max_model_len,
                        )
                        blocks_needed = (
                            self.kv_cache_manager.get_num_blocks_needed_for_tokens(
                                request.request_id,
                                num_tokens_need_slot,
                                new_computed_blocks,
                                num_encoder_tokens,
                            )
                        )
                        num_free = self.kv_cache_manager.get_num_free_blocks()
                        if num_free < blocks_needed:
                            break

                new_blocks = self.kv_cache_manager.allocate_slots(
                    request,
                    num_new_tokens,
                    num_new_computed_tokens=num_new_local_computed_tokens,
                    new_computed_blocks=new_computed_blocks,
                    num_lookahead_tokens=effective_lookahead_tokens,
                    num_external_computed_tokens=num_external_computed_tokens,
                    delay_cache_blocks=load_kv_async,
                    num_encoder_tokens=num_encoder_tokens,
                )

                if new_blocks is None:
                    if request.has_encoder_inputs:
                        self.encoder_cache_manager.free(request)
                    break

                if self.connector is not None:
                    self.connector.update_state_after_alloc(
                        request,
                        self.kv_cache_manager.get_blocks(request_id),
                        num_external_computed_tokens,
                    )
                if (
                    self.connector_prefix_cache_stats is not None
                    and connector_prefix_cache_queries != 0
                ):
                    self.connector_prefix_cache_stats.record(
                        num_tokens=connector_prefix_cache_queries,
                        num_hits=connector_prefix_cache_hits,
                        preempted=request.num_preemptions > 0,
                    )

                request = self.waiting.pop_request()
                if load_kv_async:
                    skipped_waiting_requests.prepend_request(request)
                    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                    continue

                self.running.append(request)
                if self.log_stats:
                    request.record_event(
                        EngineCoreEventType.SCHEDULED, scheduled_timestamp
                    )
                if request.status == RequestStatus.WAITING:
                    scheduled_new_reqs.append(request)
                elif request.status == RequestStatus.PREEMPTED:
                    scheduled_resumed_reqs.append(request)
                else:
                    raise RuntimeError(f"Invalid request status: {request.status}")

                if self.lora_config and request.lora_request:
                    scheduled_loras.add(request.lora_request.lora_int_id)
                req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(
                    request_id
                )
                num_scheduled_tokens[request_id] = num_new_tokens
                token_budget -= num_new_tokens
                request.status = RequestStatus.RUNNING
                request.num_computed_tokens = num_computed_tokens
                if request.num_cached_tokens < 0:
                    request.num_cached_tokens = num_computed_tokens
                if encoder_inputs_to_schedule:
                    scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                    for i in encoder_inputs_to_schedule:
                        self.encoder_cache_manager.allocate(request, i)
                    encoder_compute_budget = new_encoder_compute_budget
                if external_load_encoder_input:
                    for i in external_load_encoder_input:
                        self.encoder_cache_manager.allocate(request, i)
                        if self.ec_connector is not None:
                            self.ec_connector.update_state_after_alloc(request, i)

            if skipped_waiting_requests:
                self.waiting.prepend_requests(skipped_waiting_requests)

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens

        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
            scheduled_running_reqs
        ) <= len(self.running)

        num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
        with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
            if self.running:
                any_request_id = self.running[0].request_id
                num_common_prefix_blocks = (
                    self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
                )

        if self.use_v2_model_runner:
            scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
            scheduled_resumed_reqs = []
            new_reqs_data = [
                NewRequestData.from_request(
                    req,
                    req_to_new_blocks[req.request_id].get_block_ids(),
                    req._all_token_ids,
                )
                for req in scheduled_new_reqs
            ]
        else:
            new_reqs_data = [
                NewRequestData.from_request(
                    req, req_to_new_blocks[req.request_id].get_block_ids()
                )
                for req in scheduled_new_reqs
            ]

        with record_function_or_nullcontext("schedule: make_cached_request_data"):
            cached_reqs_data = self._make_cached_request_data(
                scheduled_running_reqs,
                scheduled_resumed_reqs,
                num_scheduled_tokens,
                scheduled_spec_decode_tokens,
                req_to_new_blocks,
            )

        self.prev_step_scheduled_req_ids.clear()
        self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())

        new_block_ids_to_zero = (
            (self.kv_cache_manager.take_new_block_ids() or None)
            if self.needs_kv_cache_zeroing
            else None
        )

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            scheduled_resumed_reqs=[r.request_id for r in scheduled_resumed_reqs],
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids={req.request_id for req in preempted_reqs},
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
            new_block_ids_to_zero=new_block_ids_to_zero,
        )

        if self.connector is not None:
            meta: KVConnectorMetadata = self.connector.build_connector_meta(
                scheduler_output
            )
            scheduler_output.kv_connector_metadata = meta

        if self.ec_connector is not None:
            ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
                scheduler_output
            )
            scheduler_output.ec_connector_metadata = ec_meta

        with record_function_or_nullcontext("schedule: update_after_schedule"):
            self._update_after_schedule(scheduler_output)
        return scheduler_output

    def _preempt_request(self, request: Request, timestamp: float) -> None:
        """Preempt a request and put it back to the waiting queue.

@@ -1193,7 +1750,6 @@ class Scheduler(SchedulerInterface):
                # available. In this case, we can't schedule any token for
                # the request in this step.
                num_new_tokens = 0
                num_new_tokens = 0
                break

        # Calculate the number of embeddings to schedule in the current range
@@ -1508,6 +2064,9 @@ class Scheduler(SchedulerInterface):
            # outputs this step.
            engine_core_outputs[0] = eco = EngineCoreOutputs()
            eco.scheduler_stats = stats

        if model_runner_output.draft_token_ids is not None:
            self.update_draft_token_ids(model_runner_output.draft_token_ids)

        return engine_core_outputs


@@ -1,10 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Sequence

from vllm.sampling_params import RepetitionDetectionParams
from vllm.v1.request import Request, RequestStatus


def _has_repeating_pattern(
    token_ids: Sequence[int],
    pattern_len: int,
    repetition_min_count: int,
) -> bool:
    """Check if the tail of token_ids contains a repeating pattern.

    Compares the last pattern_len tokens against the preceding
    (repetition_min_count - 1) repetitions of the same length.
    """
    for n in range(1, pattern_len + 1):
        target_token = token_ids[-n]
        for m in range(1, repetition_min_count):
            if token_ids[-(pattern_len * m + n)] != target_token:
                return False
    return True


def check_sequence_repetition(
    token_ids: Sequence[int],
    params: RepetitionDetectionParams,
) -> bool:
    """Check if a sequence of token IDs has a repetition pattern.

    Args:
        token_ids: List of token IDs.
        params: Repetition detection parameters.

    Returns:
        True if a repetition pattern is found, False otherwise.
    """
    max_pattern_size = params.max_pattern_size
    min_pattern_size = params.min_pattern_size
    min_count = params.min_count

    if min_pattern_size <= 0:
        min_pattern_size = 1

    if max_pattern_size <= 0 or min_count < 2 or min_pattern_size > max_pattern_size:
        return False

    for pattern_len in range(
        min_pattern_size,
        max_pattern_size + 1,
    ):
        if pattern_len * min_count > len(token_ids):
            return False

        if _has_repeating_pattern(token_ids, pattern_len, min_count):
            return True

    return False
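
# Editor's note: a worked example of the detector above (parameter values are
# illustrative; the constructor signature of RepetitionDetectionParams is
# assumed). With min_pattern_size=1, max_pattern_size=2, min_count=3, the
# tail "7 8 7 8 7 8" is flagged because its last 2 tokens repeat 3 times:
#
#     params = RepetitionDetectionParams(
#         min_pattern_size=1, max_pattern_size=2, min_count=3
#     )
#     check_sequence_repetition([1, 2, 7, 8, 7, 8, 7, 8], params)  # -> True
#     check_sequence_repetition([1, 2, 3, 4, 5, 6, 7, 8], params)  # -> False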


def remove_all(lst: list, items_to_remove: set) -> list:
    """Remove all items from a list that are in the items_to_remove set.

@@ -61,4 +115,16 @@ def check_stop(request: Request, max_model_len: int) -> bool:
    ):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    repetition_detection = sampling_params.repetition_detection
    if repetition_detection is not None and check_sequence_repetition(
        request.output_token_ids,
        repetition_detection,
    ):
        request.status = RequestStatus.FINISHED_REPETITION
        request.stop_reason = "repetition_detected"
        return True

    return False

@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        self.enable_caching = enable_caching
        self.new_block_ids: list[int] = []

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
            cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
        )
        req_blocks.extend(allocated_blocks)
        if type(self.kv_cache_spec) is FullAttentionSpec:
            self.new_block_ids.extend(b.block_id for b in allocated_blocks)

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_tokens_main_model: int
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
        else:
            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
        req_blocks.extend(new_blocks)
        if type(self.kv_cache_spec) is FullAttentionSpec:
            self.new_block_ids.extend(b.block_id for b in new_blocks)
        return new_blocks

    def take_new_block_ids(self) -> list[int]:
        """Drain and return block IDs allocated since the last call."""
        ids = self.new_block_ids
        self.new_block_ids = []
        return ids
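
    # Editor's note: the intended drain semantics, sketched (illustrative):
    # each call returns only the block IDs allocated since the previous
    # call, so the scheduler can zero exactly the blocks that are new in
    # the current step:
    #
    #     mgr.new_block_ids = [3, 7]
    #     mgr.take_new_block_ids()  # -> [3, 7]
    #     mgr.take_new_block_ids()  # -> [] until more blocks are allocated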

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """
        Cache the blocks for the request.

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product

@@ -136,7 +137,7 @@ class CudagraphDispatcher:
        num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
            num_reqs = num_tokens_padded // uniform_decode_query_len
            num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
            assert num_tokens_padded % uniform_decode_query_len == 0
        else:
            uniform_decode = False
@@ -232,8 +233,9 @@ class CudagraphDispatcher:
        num_tokens: int,
        uniform_decode: bool = False,
        has_lora: bool = False,
        disable_full: bool = False,
        num_active_loras: int = 0,
        valid_modes: AbstractSet[CUDAGraphMode] | None = None,
        invalid_modes: AbstractSet[CUDAGraphMode] | None = None,
    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
        """
        Given conditions (e.g., the batch descriptor and whether only piecewise
        graphs may be used),
@@ -246,15 +248,29 @@ class CudagraphDispatcher:
            uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
                length is uniform_decode_query_len).
            has_lora: Whether LoRA is active.
            disable_full: If True, skip FULL cudagraph checks and
                return PIECEWISE or NONE only. (can be used for features like
                cascade attention that are not supported by full cudagraphs)
            num_active_loras: Number of distinct active LoRA adapters.
            valid_modes: Set of cudagraph modes that are allowed. None means
                all modes are allowed.
            invalid_modes: Set of cudagraph modes to exclude. Subtracted from
                valid_modes to compute allowed modes. (e.g., {FULL} for
                features like cascade attention not supported by full
                cudagraphs). None means no modes are excluded.
        """
        allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()

        if invalid_modes:
            allowed_modes -= invalid_modes

        assert len(allowed_modes) >= 1, (
            f"No allowed cudagraph modes: valid_modes={valid_modes}, "
            f"invalid_modes={invalid_modes}"
        )

        if (
            not self.keys_initialized
            or self.cudagraph_mode == CUDAGraphMode.NONE
            or num_tokens > self.compilation_config.max_cudagraph_capture_size
            or allowed_modes <= {CUDAGraphMode.NONE}
        ):
            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

@@ -281,24 +297,26 @@ class CudagraphDispatcher:
            num_tokens, uniform_decode, has_lora, effective_num_active_loras
        )

        # check if key exists for full cudagraph
        # For pure FULL mode, keys are registered with uniform=False.
        batch_desc_to_check = batch_desc
        if self.cudagraph_mode == CUDAGraphMode.FULL:
            batch_desc_to_check = replace(batch_desc, uniform=False)
        if (
            not disable_full
            and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
        ):
            return CUDAGraphMode.FULL, batch_desc_to_check
        if CUDAGraphMode.FULL in allowed_modes:
            # check if key exists for full cudagraph
            # For pure FULL mode, keys are registered with uniform=False.
            batch_desc_to_check = batch_desc
            if self.cudagraph_mode == CUDAGraphMode.FULL:
                batch_desc_to_check = replace(batch_desc, uniform=False)
            if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_desc_to_check

        # also check if the relaxed key exists for more "general"
        # piecewise cudagraph
        batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
        if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, batch_desc_to_check
        if CUDAGraphMode.PIECEWISE in allowed_modes:
            # also check if the relaxed key exists for more "general"
            # piecewise cudagraph
            batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
            if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
                return CUDAGraphMode.PIECEWISE, batch_desc_to_check

        # finally, just return no cudagraphs and a trivial batch descriptor
        assert CUDAGraphMode.NONE in allowed_modes, (
            f"No matching cudagraph found and NONE is not in "
            f"allowed_modes={allowed_modes}"
        )
        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
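
    # Editor's note: the allowed-modes arithmetic above, sketched on concrete
    # values (illustrative). Excluding FULL, e.g. for cascade attention,
    # makes the dispatch fall through to the PIECEWISE and NONE checks:
    #
    #     allowed = {CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE,
    #                CUDAGraphMode.NONE}
    #     allowed -= {CUDAGraphMode.FULL}   # invalid_modes
    #     # the FULL branch is skipped; PIECEWISE is tried, then NONE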

    def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:

@@ -27,12 +27,21 @@ PauseMode = Literal["abort", "wait", "keep"]

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error", "repetition")

EEP_NOTIFICATION_CALL_ID = -1


class EEPNotificationType(enum.Enum):
    NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
    NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
    RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
    SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"


class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, abort, or error.
    Reason a request finished - stop, length, abort, error, or repetition.

    Int rather than Str for more compact serialization.

@@ -41,6 +50,7 @@ class FinishReason(enum.IntEnum):
    abort - aborted by client
    error - retryable request-level internal error (e.g., KV load failure).
        Invariant: always converted to 500 Internal Server Error.
    repetition - repetitive token pattern detected (hallucination)

    """

@@ -48,6 +58,7 @@ class FinishReason(enum.IntEnum):
    LENGTH = 1
    ABORT = 2
    ERROR = 3
    REPETITION = 4

    def __str__(self):
        return FINISH_REASON_STRINGS[self.value]
@@ -235,6 +246,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
    new_data_parallel_master_port_list: list[int]
    new_stateless_world_group_port_list: list[list[int]]
    new_stateless_dp_group_port_list: list[list[int]]
    new_stateless_ep_group_port_list: list[list[int]]
    new_stateless_eplb_group_port_list: list[list[int]]


class ReconfigureRankType(enum.IntEnum):

@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -134,6 +135,7 @@ class AsyncLLM(EngineClient):
        self.renderer = renderer = renderer_from_config(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )

@@ -647,7 +649,11 @@ class AsyncLLM(EngineClient):
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        # We use a mutable list for logger_manager so that it can be updated
        # during elastic EP scaling (see scale_elastic_ep) without creating
        # a circular reference via self.
        self._logger_ref = [self.logger_manager]
        logger_ref = self._logger_ref
        renderer = self.renderer
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

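# Editor's note: the mutable single-element list above is a standard way to
# let a long-lived closure observe a rebindable reference without capturing
# self. A minimal sketch of the idea (illustrative names, not vLLM API):
#
#     logger_ref = [make_logger()]        # the closure captures the list
#
#     async def output_handler():
#         if logger_ref[0]:
#             logger_ref[0].record(...)   # always sees the current logger
#
#     logger_ref[0] = make_new_logger()   # swapped during elastic EP scaling
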
@@ -691,8 +697,8 @@ class AsyncLLM(EngineClient):
                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                    if logger_ref[0]:
                        logger_ref[0].record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
@@ -976,17 +982,13 @@ class AsyncLLM(EngineClient):
                new_data_parallel_size,
            )
            return
        logger.info(
            "Waiting for requests to drain before scaling up to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
            logger.info(
                "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
                "waiting for requests to drain before scaling"
            )
            await self.wait_for_requests_to_drain(drain_timeout)

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
@@ -999,6 +1001,18 @@ class AsyncLLM(EngineClient):
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )
            # Update the mutable ref so output_handler picks up the
            # new logger without creating a circular reference via self.
            if hasattr(self, "_logger_ref"):
                self._logger_ref[0] = self.logger_manager
            self.logger_manager.log_engine_initialized()

        set_scaling_elastic_ep(True)
        try:
            await self.engine_core.scale_elastic_ep(new_data_parallel_size)
            self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        finally:
            set_scaling_elastic_ep(False)

    @property
    def is_running(self) -> bool:

@@ -71,6 +71,9 @@ class DPCoordinator:
|
||||
)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
# NOTE(yongji): handling scaling from intra-node to inter-node
|
||||
if parallel_config.enable_elastic_ep:
|
||||
local_only_eng = False
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
@@ -201,6 +204,7 @@ class DPCoordinatorProc:
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(publish_front, zmq.POLLIN)
|
||||
poller.register(publish_back, zmq.POLLIN)
|
||||
poller.register(output_back, zmq.POLLIN)
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
@@ -231,6 +235,22 @@ class DPCoordinatorProc:
|
||||
events = dict(events)
|
||||
wave_state_changed = False
|
||||
|
||||
if publish_back in events:
|
||||
buffer = publish_back.recv()
|
||||
if buffer == b"\x01":
|
||||
# NOTE(yongji): newly started engine subscribed
|
||||
# We need to send READY message here instead of receiving
|
||||
# SCALE_ELASTIC_EP notification from engine core client
|
||||
# as SCALE_ELASTIC_EP is only sent when
|
||||
# new engines finished initialization.
|
||||
# Subscription message, on the other hand, is sent
|
||||
# by each engine during initialization
|
||||
publish_back.send(b"READY")
|
||||
else:
|
||||
logger.error(
|
||||
"DP Coordinator receives unexpected message from engines"
|
||||
)
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer in (b"\x01", b"\x00"):
|
||||
@@ -259,7 +279,6 @@ class DPCoordinatorProc:
|
||||
# current_wave
|
||||
# we note that 0 is the wave number for the new
|
||||
# engine
|
||||
engines_running = False
|
||||
logger.info(
|
||||
"DPCoordinator scaled up from %s to %s engines",
|
||||
current_count,
|
||||
|
||||
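The READY reply above rides on ZMQ's XPUB behavior: an XPUB socket receives a b"\x01"-prefixed frame whenever a subscriber joins, which is how the coordinator learns that a newly started engine is listening. A minimal pyzmq sketch of that handshake (addresses and names are illustrative, not the vLLM ones):

import zmq

ctx = zmq.Context()
xpub = ctx.socket(zmq.XPUB)      # coordinator side (publish_back)
xpub.bind("inproc://coord-demo")
sub = ctx.socket(zmq.SUB)        # newly started engine side
sub.connect("inproc://coord-demo")
sub.setsockopt(zmq.SUBSCRIBE, b"")

event = xpub.recv()              # subscription notification arrives
assert event[:1] == b"\x01"      # 0x01 = subscribe, 0x00 = unsubscribe
xpub.send(b"READY")              # broadcast; the new engine receives it
assert sub.recv() == b"READY"
ctx.destroy(linger=0)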
@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec
import zmq

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
    EEP_NOTIFICATION_CALL_ID,
    EEPNotificationType,
    EngineCoreOutput,
    EngineCoreOutputs,
    EngineCoreRequest,
@@ -72,7 +75,6 @@ from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc
@@ -111,6 +113,9 @@ class EngineCore:

        self.available_gpu_memory_for_kv_cache = -1

        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self._eep_scale_up_before_kv_init()

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
@@ -180,13 +185,55 @@ class EngineCore:
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        base_batch_queue_size = self.model_executor.max_concurrent_batches
        if envs.VLLM_ENABLE_PP_ILU_OPT:
            self.batch_queue_size = envs.VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE
            if self.batch_queue_size <= 0:
                self.batch_queue_size = base_batch_queue_size * 2
            self._use_batch_queue_ilu_opt = True
            logger.info(
                "PP ILU opt is enabled: batch_queue_size=%d (base=%d)",
                self.batch_queue_size,
                base_batch_queue_size,
            )
        else:
            self.batch_queue_size = base_batch_queue_size
            self._use_batch_queue_ilu_opt = False
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            logger.info(
                "Batch queue is enabled with size %d (ilu_opt=%s)",
                self.batch_queue_size,
                self._use_batch_queue_ilu_opt,
            )
            self.batch_queue = deque(maxlen=self.batch_queue_size)
            if self._use_batch_queue_ilu_opt:
                self.engine_core_input_queue: queue.Queue[
                    tuple[Future[ModelRunnerOutput], SchedulerOutput]
                ] = queue.Queue(maxsize=self.batch_queue_size)
                self.engine_core_output_queue: queue.Queue[
                    tuple[SchedulerOutput, ModelRunnerOutput, bool]
                ] = queue.Queue(maxsize=self.batch_queue_size)
                self._batch_queue_loop_thread = threading.Thread(
                    target=self._process_batch_queue_loop,
                    daemon=True,
                )
                self._batch_queue_loop_thread.start()

        # When PP mix ILU scheduling or PP ILU opt is enabled with a KV
        # connector, only NixlConnector is supported.
        if vllm_config.kv_transfer_config is not None and (
            envs.VLLM_ENABLE_PP_MIX_ILU_SCHEDULING or envs.VLLM_ENABLE_PP_ILU_OPT
        ):
            kv_connector_name = vllm_config.kv_transfer_config.kv_connector
            if kv_connector_name != "NixlConnector":
                raise ValueError(
                    "When VLLM_ENABLE_PP_MIX_ILU_SCHEDULING or VLLM_ENABLE_PP_ILU_OPT "
                    "is enabled with a KV connector, only NixlConnector is supported; "
                    f"current kv_connector is {kv_connector_name!r}."
                )

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
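The ILU-opt machinery configured above is a classic bounded producer/consumer split: the main thread enqueues (future, scheduler_output) pairs, and the daemon thread is the only place that blocks on future.result(). Reduced to its shape (illustrative names, not the vLLM types):

import queue
import threading
from concurrent.futures import Future

inq: "queue.Queue[tuple[Future, str]]" = queue.Queue(maxsize=4)
outq: "queue.Queue[tuple[str, object]]" = queue.Queue(maxsize=4)

def loop() -> None:
    while True:
        fut, work = inq.get()
        outq.put((work, fut.result()))   # block here, not in the caller

threading.Thread(target=loop, daemon=True).start()

fut: Future = Future()
inq.put((fut, "batch-0"))
fut.set_result(42)                       # stands in for the GPU finishing
assert outq.get() == ("batch-0", 42)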
@@ -209,6 +256,10 @@ class EngineCore:
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        self.draft_in_model_output = (
            self.batch_queue is not None and self.use_spec_decode
        )

        self.aborts_queue = queue.Queue[list[str]]()

@@ -234,12 +285,10 @@ class EngineCore:

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
            if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
                # NOTE(yongji): should already be set
                # during _eep_scale_up_before_kv_init
                assert self.available_gpu_memory_for_kv_cache > 0
            available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                kv_cache_specs
            )
@@ -408,12 +457,52 @@ class EngineCore:
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if self.draft_in_model_output:
            return
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def _has_kv_connector_work(self, meta: Any) -> bool:
        """Return True if kv_connector_metadata has any recv/save/send work."""
        if meta is None:
            return False
        for attr in ("reqs_to_recv", "reqs_to_save", "reqs_to_send"):
            val = getattr(meta, attr, None)
            if val is not None and len(val) > 0:
                return True
        return False

    def _has_meaningful_scheduler_output(
        self, scheduler_output: SchedulerOutput
    ) -> bool:
        """Return False if scheduler_output is effectively empty."""
        return not (
            len(scheduler_output.scheduled_new_reqs) == 0
            and len(scheduler_output.scheduled_cached_reqs.req_ids) == 0
            and len(scheduler_output.num_scheduled_tokens) == 0
            and scheduler_output.total_num_scheduled_tokens == 0
            and len(scheduler_output.scheduled_spec_decode_tokens) == 0
            and len(scheduler_output.scheduled_encoder_inputs) == 0
            and len(scheduler_output.finished_req_ids) == 0
            and (scheduler_output.scheduled_resumed_reqs is None
                 or len(scheduler_output.scheduled_resumed_reqs) == 0)
            and not self._has_kv_connector_work(
                scheduler_output.kv_connector_metadata
            )
        )

    def _process_batch_queue_loop(self) -> None:
        while True:
            future, scheduler_output = self.engine_core_input_queue.get()
            with self.log_error_detail(scheduler_output):
                model_output = future.result()
            self.engine_core_output_queue.put(
                (scheduler_output, model_output, False)
            )

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
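The two helpers above reduce to the same duck-typed test: does any relevant field hold work? A compressed sketch of that predicate (generic, not the vLLM signature):

def has_work(meta) -> bool:
    # None metadata, missing attrs, and empty containers all count as idle.
    return meta is not None and any(
        getattr(meta, attr, None)
        for attr in ("reqs_to_recv", "reqs_to_save", "reqs_to_send")
    )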
@@ -434,6 +523,9 @@ class EngineCore:
        batch_queue = self.batch_queue
        assert batch_queue is not None

        if self._use_batch_queue_ilu_opt:
            return self.step_with_batch_queue_ilu_opt()

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
@@ -531,6 +623,96 @@ class EngineCore:

        return engine_core_outputs, model_executed

    def step_with_batch_queue_ilu_opt(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Async batch queue variant using background thread for PP ILU opt.

        Uses engine_core_input_queue / engine_core_output_queue with a
        background thread (_process_batch_queue_loop) that blocks on
        future.result(), so the main thread never blocks on GPU compute.
        """
        assert not self.is_ec_producer, (
            "ec_producer is not supported in step_with_batch_queue_ilu_opt"
        )
        assert not self.is_pooling_model, (
            "is_pooling_model is not supported in step_with_batch_queue_ilu_opt"
        )
        assert not self.async_scheduling, (
            "async_scheduling is not supported in step_with_batch_queue_ilu_opt"
        )

        model_executed = False

        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            has_meaningful_schedule = self._has_meaningful_scheduler_output(
                scheduler_output
            )
            if (
                self.engine_core_input_queue.qsize() <= 1
                and not has_meaningful_schedule
            ):
                has_meaningful_schedule = True
            if has_meaningful_schedule:
                logger.debug(
                    "[step_with_batch_queue_ilu_opt] scheduler_output: "
                    "total_num_scheduled_tokens=%s num_scheduled_tokens=%s "
                    "scheduled_new_reqs=%s scheduled_cached_reqs.req_ids=%s "
                    "resumed_req_ids=%s finished_req_ids=%s "
                    "has_meaningful_schedule=%s",
                    scheduler_output.total_num_scheduled_tokens,
                    scheduler_output.num_scheduled_tokens,
                    [r.req_id for r in scheduler_output.scheduled_new_reqs],
                    scheduler_output.scheduled_cached_reqs.req_ids,
                    scheduler_output.scheduled_cached_reqs.resumed_req_ids,
                    scheduler_output.finished_req_ids,
                    has_meaningful_schedule,
                )

            if has_meaningful_schedule:
                exec_future = self.model_executor.execute_model(
                    scheduler_output, non_block=True
                )
                model_executed = (
                    scheduler_output.total_num_scheduled_tokens > 0
                )

                if not model_executed:
                    future = cast(Future[ModelRunnerOutput], exec_future)
                else:
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )

                if self.engine_core_input_queue.full():
                    scheduler_output_out, model_output_out, model_executed_out = (
                        self.engine_core_output_queue.get()
                    )
                    engine_core_outputs = self.scheduler.update_from_output(
                        scheduler_output_out, model_output_out
                    )
                    self.engine_core_input_queue.put(
                        (future, scheduler_output)
                    )
                    return engine_core_outputs, model_executed_out

                self.engine_core_input_queue.put((future, scheduler_output))

        try:
            scheduler_output, model_output, model_executed = (
                self.engine_core_output_queue.get_nowait()
            )
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output
            )
            return engine_core_outputs, model_executed
        except queue.Empty:
            return None, False

    def _process_aborts_queue(self):
        if not self.aborts_queue.empty():
            request_ids = []
@@ -753,11 +935,22 @@ class EngineCore:
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

    def _eep_scale_up_before_kv_init(self):
        raise NotImplementedError

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        raise NotImplementedError


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
    addresses: EngineZmqAddresses

    @instrument(span_name="EngineCoreProc init")
    def __init__(
@@ -808,6 +1001,13 @@ class EngineCoreProc(EngineCore):
        # and "hybrid" LB modes.
        self.publish_dp_lb_stats = internal_dp_balancing

        self.addresses = addresses
        self.process_input_queue_block = True
        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self._eep_send_engine_core_notification(
                EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
                vllm_config=vllm_config,
            )
        self._init_data_parallel(vllm_config)

        super().__init__(
@@ -1120,8 +1320,14 @@ class EngineCoreProc(EngineCore):
            if logger.isEnabledFor(DEBUG):
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)
            block = self.process_input_queue_block
            try:
                req = self.input_queue.get(block=block)
                self._handle_client_request(*req)
            except queue.Empty:
                break
            if not block:
                break

        if waited:
            logger.debug("EngineCore loop active.")
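The change above swaps an unconditional blocking get() for one gated by process_input_queue_block, so the loop keeps stepping the engine while elastic scaling is in flight instead of parking on the queue. The polling shape in isolation (a sketch, not the vLLM loop):

import queue

def drain(q: "queue.Queue", block: bool) -> list:
    """Take everything queued; optionally block for the first item only."""
    handled = []
    try:
        handled.append(q.get(block=block))
        while True:
            handled.append(q.get_nowait())
    except queue.Empty:
        pass
    return handled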
@@ -1291,6 +1497,11 @@ class EngineCoreProc(EngineCore):
            for input_socket, _ in poller.poll():
                # (RequestType, RequestData)
                type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                # NOTE(yongji): ignore READY message sent by DP coordinator
                # that is used to notify newly started engines
                if type_frame.buffer == b"READY":
                    assert input_socket == coord_socket
                    continue
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
@@ -1489,6 +1700,10 @@ class DPEngineCoreProc(EngineCoreProc):
        self.current_wave = 0
        self.last_counts = (0, 0)

        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        self.eep_scaling_state: ElasticEPScalingState | None = None

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
@@ -1512,7 +1727,9 @@ class DPEngineCoreProc(EngineCoreProc):
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
        self.dp_group, self.dp_store = (
            vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
        )

    def shutdown(self):
        super().shutdown()
@@ -1533,7 +1750,11 @@ class DPEngineCoreProc(EngineCoreProc):

    def resume_scheduler(self):
        super().resume_scheduler()
        if not self.engines_running and self.scheduler.has_unfinished_requests():
        if (
            self.has_coordinator
            and not self.engines_running
            and self.scheduler.has_unfinished_requests()
        ):
            # Wake up other DP engines.
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(start_wave=self.current_wave))
@@ -1575,7 +1796,12 @@ class DPEngineCoreProc(EngineCoreProc):
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            if self.eep_scaling_state is not None:
                _ = self.eep_scaling_state.progress()
                if self.eep_scaling_state.is_complete():
                    self.process_input_queue_block = True
                    self.eep_scaling_state = None

            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

@@ -1625,54 +1851,129 @@ class DPEngineCoreProc(EngineCoreProc):
    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()
        from copy import deepcopy

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )
        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        new_parallel_config = deepcopy(self.vllm_config.parallel_config)
        old_dp_size = new_parallel_config.data_parallel_size
        new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            new_parallel_config.data_parallel_rank = (
                reconfig_request.new_data_parallel_rank
            )
            new_parallel_config.data_parallel_master_ip = (
                reconfig_request.new_data_parallel_master_ip
            )
            new_parallel_config.data_parallel_master_port = (
                reconfig_request.new_data_parallel_master_port
            )
            new_parallel_config._data_parallel_master_port_list = (
                reconfig_request.new_data_parallel_master_port_list
            )

        is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
        is_shutdown = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=new_parallel_config,
            worker_type="removing" if is_shutdown else "existing",
            scale_type="scale_down" if is_scale_down else "scale_up",
            reconfig_request=reconfig_request,
        )
        self.process_input_queue_block = False
        logger.info(
            "[Elastic EP] Received reconfiguration request and starting scaling up/down"
        )

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        """
        Send notifications to EngineCoreClient, which can then forward
        the notifications to other engine core processes. It is used for:
        1) In scale up: new core engines to notify existing core engines
           that they are ready;
        2) In scale down: removing core engines to notify EngineCoreClient
           so EngineCoreClient can release their ray placement groups;
        3) Both scale up/down: to notify EngineCoreClient that existing
           core engines have already switched to the new parallel setup.
        """
        if vllm_config is None:
            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
        else:
            dp_rank = vllm_config.parallel_config.data_parallel_rank
        notification_data = (notification_type.value, dp_rank)
        outputs = EngineCoreOutputs(
            utility_output=UtilityOutput(
                call_id=EEP_NOTIFICATION_CALL_ID,
                result=UtilityResult(notification_data),
            )
        )
        outputs.engine_index = self.engine_index

        if hasattr(self, "output_thread") and self.output_thread.is_alive():
            self.output_queue.put_nowait((0, outputs))
        else:
            encoder = MsgpackEncoder()
            with (
                zmq.Context() as ctx,
                make_zmq_socket(
                    ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
                ) as socket,
            ):
                socket.send_multipart(encoder.encode(outputs))

    def eep_handle_engine_core_notification(
        self, notification_type: str | EEPNotificationType
    ):
        """
        Handle notification received from EngineCoreClient
        (forwarded from new core engines).
        """
        assert self.eep_scaling_state is not None
        if isinstance(notification_type, str):
            notification_type = EEPNotificationType(notification_type)
        self.eep_scaling_state.handle_notification(notification_type)

    def _eep_scale_up_before_kv_init(self):
        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=self.vllm_config.parallel_config,
            worker_type="new",
            scale_type="scale_up",
            reconfig_request=None,
        )
        self.model_executor.collective_rpc("init_device")
        self.model_executor.collective_rpc("load_model")
        self._eep_send_engine_core_notification(
            EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("receive_weights",)
        )
        self.available_gpu_memory_for_kv_cache = (
            ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("prepare_new_worker",)
        )
        self.process_input_queue_block = False


class EngineCoreActorMixin:
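The else branch of _eep_send_engine_core_notification covers the window before the output thread exists: the message is pushed once over a short-lived PUSH socket, with LINGER giving the context time to flush on close. That fallback in miniature (generic helper, assumed address):

import zmq

def push_once(addr: str, payload: bytes) -> None:
    with zmq.Context() as ctx:             # ctx.term() runs on exit
        sock = ctx.socket(zmq.PUSH)
        sock.setsockopt(zmq.LINGER, 4000)  # ms allowed to flush on close
        sock.connect(addr)
        sock.send(payload)
        sock.close()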
@@ -28,11 +28,12 @@ from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
    close_sockets,
    get_open_port,
    get_open_zmq_inproc_path,
    make_zmq_socket,
)
from vllm.v1.engine import (
    EEP_NOTIFICATION_CALL_ID,
    EEPNotificationType,
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (
    CoreEngineActorManager,
    CoreEngineProcManager,
    get_engine_zmq_addresses,
    launch_core_engines,
)
from vllm.v1.executor import Executor
@@ -445,6 +447,63 @@ class BackgroundResources:
            raise EngineDeadError()


@dataclass
class ElasticScalingCache:
    existing_core_engines: list[EngineIdentity]
    num_new_core_engines: int
    pending_notifications: dict[EEPNotificationType, set[int]]


def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
    """
    Allocate stateless group ports for elastic EP.
    """
    from vllm.utils.network_utils import get_open_ports_list

    assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
    world_size = parallel_config.world_size
    new_world_size_across_dp = world_size * new_data_parallel_size
    num_world_groups = 1
    num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
    num_ep_groups = max(
        1,
        new_world_size_across_dp
        // (new_data_parallel_size * parallel_config.tensor_parallel_size),
    )
    num_eplb_groups = num_ep_groups
    total_ports_needed = (
        num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
    ) * 3 + 5
    all_ports = get_open_ports_list(total_ports_needed)
    new_data_parallel_master_port_list = all_ports[-5:]
    all_ports = all_ports[:-5]
    new_stateless_world_group_port_list = [
        all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
    ]
    start_idx = num_world_groups * 3
    new_stateless_dp_group_port_list = [
        all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
    ]
    start_idx += num_dp_groups * 3
    new_stateless_ep_group_port_list = [
        all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
    ]
    start_idx += num_ep_groups * 3
    new_stateless_eplb_group_port_list = [
        all_ports[i : i + 3]
        for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
    ]

    parallel_config._stateless_world_group_port_list = (
        new_stateless_world_group_port_list
    )
    parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
    parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
    parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
    parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
    parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
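A worked instance of the port budget computed above, with illustrative numbers (one engine = 2 TP workers, scaling to DP=4): each stateless group needs 3 ports, and 5 extra ports feed the DP master-port list.

world_size, new_dp, tp = 2, 4, 2
across_dp = world_size * new_dp                  # 8 ranks in total
num_world = 1
num_dp = max(1, across_dp // new_dp)             # 2 DP groups
num_ep = max(1, across_dp // (new_dp * tp))      # 1 EP group
num_eplb = num_ep                                # 1 EPLB group
total = (num_world + num_dp + num_ep + num_eplb) * 3 + 5
assert total == 20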
class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
            input_address = client_addresses["input_address"]
            output_address = client_addresses["output_address"]
            self.stats_update_address = client_addresses.get("stats_update_address")
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True
            )
            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.PULL
            )
        else:
            # Engines are managed by this client.
            with launch_core_engines(vllm_config, executor_class, log_stats) as (
                engine_manager,
                coordinator,
            addresses = get_engine_zmq_addresses(vllm_config)
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
            )
            self.resources.output_socket = make_zmq_socket(
                self.ctx, addresses.outputs[0], zmq.PULL
            )

            with launch_core_engines(
                vllm_config,
                executor_class,
                log_stats,
                addresses,
            ):
            ) as (engine_manager, coordinator, addresses):
                self.resources.coordinator = coordinator
                self.resources.engine_manager = engine_manager

            (input_address,) = addresses.inputs
            (output_address,) = addresses.outputs
            self.stats_update_address = addresses.frontend_stats_publish_address
            if coordinator is not None:
                assert self.stats_update_address == (
                    coordinator.get_stats_publish_address()
                )

            # Create input and output sockets.
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True
            )
            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.PULL
            )

        parallel_config = vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_index
@@ -545,8 +609,13 @@ class MPClient(EngineCoreClient):
                timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
            ):
                raise TimeoutError(
                    "Timed out waiting for engines to send "
                    "initial message on input socket."
                    f"Timed out waiting for engine core processes to "
                    f"start. This is often caused by slow weight loading "
                    f"for large models. Waited "
                    f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
                    f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
                    f"timeout, set the environment variable: "
                    f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
                )
            identity, _ = sync_input_socket.recv_multipart()
            identities.remove(identity)
@@ -877,6 +946,10 @@ class AsyncMPClient(MPClient):
        output_socket = resources.output_socket
        assert output_socket is not None

        notification_callback_handler: (
            Callable[[AsyncMPClient, Sequence[Any]], Any] | None
        ) = getattr(self.__class__, "eep_process_engine_core_notification", None)

        async def process_outputs_socket():
            try:
                while True:
@@ -884,7 +957,26 @@ class AsyncMPClient(MPClient):
                    resources.validate_alive(frames)
                    outputs: EngineCoreOutputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output, utility_results)
                        if (
                            outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
                            and notification_callback_handler is not None
                        ):
                            assert _self_ref is not None
                            _self = _self_ref()
                            if not _self:
                                return
                            if outputs.utility_output.result is None:
                                continue
                            notification_data = outputs.utility_output.result.result
                            assert isinstance(notification_data, Sequence)
                            assert len(notification_data) == 2
                            asyncio.create_task(
                                notification_callback_handler(_self, notification_data)
                            )
                        else:
                            _process_utility_output(
                                outputs.utility_output, utility_results
                            )
                        continue

                    if output_handler is not None:
@@ -1081,6 +1173,8 @@ class DPAsyncMPClient(AsyncMPClient):
        # Used only by DPLBAsyncMPClient subclass.
        self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]

        self.eep_scaling_cache: ElasticScalingCache | None = None

        self.first_req_sock_addr = get_open_zmq_inproc_path()
        self.first_req_send_socket = self.resources.first_req_send_socket = (
            make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
@@ -1101,12 +1195,6 @@ class DPAsyncMPClient(AsyncMPClient):
        assert self.stats_update_address is not None
        stats_addr: str = self.stats_update_address
        assert len(self.engine_ranks_managed) > 0
        # NOTE: running and waiting counts are all global from
        # the Coordinator include all global EngineCores. This
        # slice includes just the cores managed by this client.
        count_slice = slice(
            self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
        )

        async def run_engine_stats_update_task():
            with (
@@ -1145,6 +1233,29 @@ class DPAsyncMPClient(AsyncMPClient):
                        ):
                            # Extract new engine count from the decoded message
                            new_engine_count = decoded[1]
                            # Update engine_ranks_managed and count_slice
                            parallel_config = self.vllm_config.parallel_config
                            dp_size = parallel_config.data_parallel_size
                            dp_rank = parallel_config.data_parallel_rank
                            assert dp_rank == 0
                            assert dp_size == new_engine_count
                            assert not (
                                parallel_config.data_parallel_hybrid_lb
                                or parallel_config.data_parallel_external_lb
                            )
                            num_ranks = dp_size
                            self.engine_ranks_managed = list(
                                range(dp_rank, dp_rank + num_ranks)
                            )
                            if len(self.lb_engines) < new_engine_count:
                                self.lb_engines = self.lb_engines + [
                                    [0, 0]
                                    for _ in range(
                                        new_engine_count - len(self.lb_engines)
                                    )
                                ]
                            else:
                                self.lb_engines = self.lb_engines[:new_engine_count]
                            # Send scale up notification to coordinator
                            scale_msg = msgspec.msgpack.encode(
                                ("SCALE_ELASTIC_EP", new_engine_count)
@@ -1178,6 +1289,11 @@ class DPAsyncMPClient(AsyncMPClient):
                        self.current_wave = wave
                        self.engines_running = running
                        if counts is not None:
                            # Running and waiting counts are global from the
                            # Coordinator including all EngineCores. Slice to get
                            # just the cores managed by this client.
                            ranks = self.engine_ranks_managed
                            count_slice = slice(ranks[0], ranks[-1] + 1)
                            sliced_counts = counts[count_slice]
                            self.lb_engines = sliced_counts
                            logger.debug(
@@ -1287,6 +1403,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
        for req_id in outputs.finished_requests:
            self.reqs_in_flight.pop(req_id, None)

    @staticmethod
    async def eep_process_engine_core_notification(
        self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
    ):
        cache = self.eep_scaling_cache
        notification_type_str, dp_rank = notification_data
        try:
            notification_type = EEPNotificationType(notification_type_str)
        except ValueError as e:
            raise ValueError(
                f"Unknown EEP notification type: {notification_type_str}"
            ) from e

        if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
            from vllm.v1.engine import UtilityResult

            # NOTE(yongji): process a dummy UtilityOutput to resolve the future
            # awaited in _eep_wait_for_setup_switch_complete(), signaling that
            # all engine cores have completed reconfiguration.
            dummy_output = UtilityOutput(
                call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
            )
            _process_utility_output(dummy_output, self.utility_results)
            return
        assert cache is not None
        if notification_type not in cache.pending_notifications:
            cache.pending_notifications[notification_type] = set()
        if dp_rank in cache.pending_notifications[notification_type]:
            raise ValueError(
                f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
            )
        cache.pending_notifications[notification_type].add(dp_rank)
        if len(cache.pending_notifications[notification_type]) >= abs(
            cache.num_new_core_engines
        ):
            if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
                assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
                assert cache.num_new_core_engines < 0
                old_dp_size = len(cache.existing_core_engines)
                new_dp_size = old_dp_size + cache.num_new_core_engines
                self.resources.engine_manager.scale_down_elastic_ep(
                    old_dp_size, new_dp_size
                )
            else:
                await asyncio.gather(
                    *[
                        self._call_utility_async(
                            "eep_handle_engine_core_notification",
                            notification_type,
                            engine=engine,
                        )
                        for engine in cache.existing_core_engines
                    ]
                )
            cache.pending_notifications[notification_type] = set()
            if notification_type in [
                EEPNotificationType.SHUTDOWN_COMPLETE,
                EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
            ]:
                self.eep_scaling_cache = None

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if not request_ids or self.resources.engine_dead:
            return
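The pending_notifications bookkeeping above is a per-type barrier: each DP rank may arrive once, and the action fires only when all expected ranks have reported. The same logic in isolation (generic sketch, not the vLLM types):

def make_barrier(expected: int):
    seen: set[int] = set()

    def arrive(rank: int) -> bool:
        if rank in seen:
            raise ValueError(f"duplicate notification from rank {rank}")
        seen.add(rank)
        if len(seen) >= expected:
            seen.clear()     # reset, as the dict entry is reset above
            return True      # every rank reported; fan out to engines
        return False

    return arrive

barrier = make_barrier(expected=2)
assert barrier(0) is False
assert barrier(1) is True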
@@ -1333,6 +1510,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
            cur_data_parallel_size, new_data_parallel_size
        )

    async def _eep_wait_for_setup_switch_complete(self) -> None:
        """
        Wait for core engines to switch to the new setup.

        In eep_process_engine_core_notification(), a dummy UtilityOutput with
        EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
        notification is received from engine 0. We create a future with
        that call_id and wait for it to be resolved.
        """
        future = asyncio.get_running_loop().create_future()
        self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
        self._ensure_output_queue_task()
        await future

    async def _scale_up_elastic_ep(
        self, cur_data_parallel_size: int, new_data_parallel_size: int
    ) -> None:
@@ -1340,38 +1531,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
        and reconfiguring existing ones."""
        cur_data_parallel_size = len(self.core_engines)

        # Phase 1: Send reconfigure messages to all existing engines and wait
        # for them to be sent
        self.eep_scaling_cache = ElasticScalingCache(
            existing_core_engines=self.core_engines.copy(),
            num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
            pending_notifications=dict(),
        )

        parallel_config = self.vllm_config.parallel_config
        allocate_stateless_group_ports(parallel_config, new_data_parallel_size)

        # Phase 1: Send reconfig messages to existing engines
        reconfig_futures = []
        self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
        for engine in self.core_engines:
            reconfig_request = ReconfigureDistributedRequest(
                new_data_parallel_size=new_data_parallel_size,
                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
                new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
                new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
                new_data_parallel_master_port=parallel_config.data_parallel_master_port,
                new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
                new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
                new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
                new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
                new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
            )
            coro = self._call_utility_async(
                "reinitialize_distributed", reconfig_request, engine=engine
            )
            reconfig_futures.append(asyncio.create_task(coro))

        logger.info("All reconfigure messages sent, starting engine creation")

        # Phase 2: Create new engines now that reconfig messages have been sent
        # self.resources.engine_manager is guaranteed to be
        # CoreEngineActorManager for RayDPClient
        # Phase 2: Create new engines
        assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
        self.resources.engine_manager.scale_up_elastic_ep(
            self.vllm_config, new_data_parallel_size
        parallel_config.eplb_config.num_redundant_experts = 0
        start_new_worker_future = asyncio.to_thread(
            self.resources.engine_manager.scale_up_elastic_ep,
            self.vllm_config,
            new_data_parallel_size,
        )
        wait_future = self._eep_wait_for_setup_switch_complete()

        # Phase 3: Wait for new engines to be created
        # and reconfig messages to be received
        await asyncio.gather(start_new_worker_future, *reconfig_futures)
        logger.info("[Elastic EP] Successfully started new engines")

        # Create new CoreEngine objects for the new engines
        new_engine_identities = set()
        for i in range(cur_data_parallel_size, new_data_parallel_size):
            new_engine = i.to_bytes(2, "little")
            self.core_engines.append(new_engine)
            # NOTE(yongji): we don't update lb_engines here,
            # we let run_engine_stats_update_task update it.
            new_engine_identities.add(new_engine)

        # Wait for ready messages from new engines on the input socket
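The scale-up above overlaps a blocking engine launch (Ray placement-group work, moved off the event loop with asyncio.to_thread) with the per-engine reconfigure RPCs, then awaits both with one gather. The same overlap pattern with stand-in tasks:

import asyncio
import time

def launch_engines() -> str:            # blocking, like the Ray launch
    time.sleep(0.1)
    return "engines-up"

async def reconfig(rank: int) -> str:   # async RPC per existing engine
    await asyncio.sleep(0.05)
    return f"rank-{rank} reconfigured"

async def main() -> None:
    results = await asyncio.gather(
        asyncio.to_thread(launch_engines),
        *(reconfig(r) for r in range(2)),
    )
    print(results)

asyncio.run(main())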
@@ -1381,16 +1591,21 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
            ):
                raise TimeoutError(
                    "Timed out waiting for new engines to send initial "
                    "message on input socket."
                    f"Timed out waiting for new engine core processes to "
                    f"start. Waited "
                    f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
                    f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
                    f"timeout, set the environment variable: "
                    f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
                )
            identity, _ = sync_input_socket.recv_multipart()
            new_engine_identities.discard(identity)

        # Phase 3: Wait for all existing engines to complete reconfiguration
        logger.info("Waiting for existing engines to complete reconfiguration")
        await asyncio.gather(*reconfig_futures)

        # NOTE(yongji): Before we schedule any requests on the new workers,
        # we should wait for them to switch to the new setup.
        await wait_future
        # Update the parallel config
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        # Notify coordinator about scale up through existing
        # stats_update_task connection
        self._ensure_stats_update_task()
@@ -1399,8 +1614,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
        )
        await self.first_req_send_socket.send(scale_up_marker)

        # Update the parallel config
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        logger.info(
            "[Elastic EP] Scale up completed, new data parallel size: %s",
            new_data_parallel_size,
@@ -1413,7 +1626,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
        reconfiguring existing engine cores."""
        cur_data_parallel_size = len(self.core_engines)

        self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
        self.eep_scaling_cache = ElasticScalingCache(
            existing_core_engines=self.core_engines.copy(),
            num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
            pending_notifications=dict(),
        )

        parallel_config = self.vllm_config.parallel_config
        allocate_stateless_group_ports(parallel_config, new_data_parallel_size)

        reconfig_futures = []
        for cur_dp_rank, engine in enumerate(self.core_engines):
@@ -1421,8 +1641,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                new_data_parallel_size=new_data_parallel_size,
                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
                new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
                new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
                new_data_parallel_master_port=parallel_config.data_parallel_master_port,
                new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
                new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
                new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
                new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
                new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
            )
            if cur_dp_rank >= new_data_parallel_size:
                reconfig_request.new_data_parallel_rank = (
@@ -1433,23 +1658,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
            )
            reconfig_futures.append(asyncio.create_task(coro))

        for _ in range(new_data_parallel_size, cur_data_parallel_size):
            self.core_engines.pop()
        # NOTE(yongji): Immediately stop sending requests to the removing engines.
        self.core_engines = self.core_engines[:new_data_parallel_size]
        self.lb_engines = self.lb_engines[:new_data_parallel_size]
        wait_future = self._eep_wait_for_setup_switch_complete()

        await asyncio.gather(*reconfig_futures)

        assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
        self.resources.engine_manager.scale_down_elastic_ep(
            cur_data_parallel_size, new_data_parallel_size
        )

        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        self._ensure_stats_update_task()
        scale_down_marker = msgspec.msgpack.encode(
            ("SCALE_ELASTIC_EP", new_data_parallel_size)
        )
        await self.first_req_send_socket.send(scale_down_marker)

        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        # NOTE(yongji): Unlike scaling up,
        # here we don't actually need to wait for the setup switch to complete.
        # We may want to remove it in the future.
        await wait_future
        logger.info(
            "[Elastic EP] Scale down completed, new data parallel size: %s",
            new_data_parallel_size,

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
import warnings
from collections.abc import Mapping
from typing import Any, Literal

@@ -114,16 +113,6 @@ class InputProcessor:
        supported_tasks: tuple[SupportedTask, ...],
    ) -> None:
        """Raise `ValueError` if SamplingParams or PoolingParams is not valid."""
        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.17. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(params, SamplingParams):
            supported_generation_tasks = [
                task for task in supported_tasks if task in GENERATION_TASKS

@@ -92,6 +92,7 @@ class LLMEngine:
        self.renderer = renderer = renderer_from_config(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )


@@ -277,6 +277,8 @@ class CoreEngineActorManager:
        else:
            ray.init()

        vllm_config.parallel_config.allocate_elastic_ep_ports()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if placement_groups is provided"
@@ -584,6 +586,8 @@ class CoreEngineActorManager:

            node_ip = node.node_ip
            node_id = node.node_id
            if device_str not in available_resources[node_id]:
                continue
            available_gpus = int(available_resources[node_id][device_str])

            # Get total GPUs on this node from the node's resources
@@ -773,11 +777,50 @@ class CoreEngineActorManager:
            ray.util.remove_placement_group(pg)


def get_engine_zmq_addresses(
    vllm_config: VllmConfig,
    num_api_servers: int = 1,
) -> EngineZmqAddresses:
    """Allocate ZMQ addresses for engine-client communication."""
    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    local_start_index = parallel_config.data_parallel_rank_local
    dp_size = parallel_config.data_parallel_size
    host = parallel_config.data_parallel_master_ip
    local_engines_only = parallel_config.local_engines_only

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
    # examples/offline_inference/data_parallel.py.
    offline_mode = local_start_index is not None

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    client_local_only = (
        offline_mode or local_engines_only or (local_engine_count == dp_size)
    )
    # NOTE(yongji): handling scaling from intra-node to inter-node
    if parallel_config.enable_elastic_ep:
        client_local_only = False

    return EngineZmqAddresses(
        inputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
        outputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
    )


@contextlib.contextmanager
def launch_core_engines(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    addresses: EngineZmqAddresses,
    num_api_servers: int = 1,
) -> Iterator[
    tuple[
@@ -796,29 +839,8 @@ def launch_core_engines(
    host = parallel_config.data_parallel_master_ip
    local_engines_only = parallel_config.local_engines_only

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
    # examples/offline_inference/data_parallel.py.
    offline_mode = local_start_index is not None

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    client_local_only = (
        offline_mode or local_engines_only or (local_engine_count == dp_size)
    )

    # Set up input and output addresses.
    addresses = EngineZmqAddresses(
        inputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
        outputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
    )

    # Run the DP Coordinator process with rank 0 when in online DP mode.
    # The coordinator is needed for:
    # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
@@ -885,6 +907,10 @@ def launch_core_engines(
    # will be False.
    handshake_local_only = offline_mode or local_engine_count == dp_size

    # NOTE(yongji): handling scaling from intra-node to inter-node
    if parallel_config.enable_elastic_ep:
        handshake_local_only = False

    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port
    )

@@ -115,7 +115,15 @@ class Executor(ABC):
        underlying workers.
        """
        self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
        self.collective_rpc("compile_or_warm_up_model")
        compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
        # Propagate compilation time from workers back to the main process.
        # With TP>1, compilation happens in worker processes, so the main
        # process config is never updated. Use max across workers since they
        # compile in parallel.
        if compilation_times:
            self.vllm_config.compilation_config.compilation_time = max(
                compilation_times
            )

    def register_failure_callback(self, callback: FailureCallback):  # noqa: B027
        """
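The propagation above relies on collective_rpc returning one value per worker; since workers compile in parallel, the wall-clock cost is the slowest worker, hence max() rather than a sum. Illustrative numbers:

compilation_times = [12.3, 11.8, 12.9, 12.1]   # seconds, one per worker
compilation_time = max(compilation_times)       # 12.9 -- not sum(): 49.1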
@@ -38,12 +38,14 @@ from vllm.distributed.parallel_state import (
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from vllm.envs import enable_envs_cache
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tracing import instrument, maybe_init_worker_tracer
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_loopback_ip,
|
||||
get_open_port,
|
||||
)
|
||||
@@ -128,11 +130,27 @@ class MultiprocExecutor(Executor):
|
||||
# For leader node within each dp rank,
|
||||
# each dp will have its own leader multiproc executor.
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
mq_connect_ip = get_ip()
|
||||
logger.info(
|
||||
"DP group leader: node_rank=%d, node_rank_within_dp=%d, "
|
||||
"master_addr=%s, mq_connect_ip=%s (local), "
|
||||
"world_size=%d, local_world_size=%d",
|
||||
self.parallel_config.node_rank,
|
||||
self.parallel_config.node_rank_within_dp,
|
||||
self.parallel_config.master_addr,
|
||||
mq_connect_ip,
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
)
|
||||
mq_kwargs: dict[str, Any] = {}
|
||||
if envs.VLLM_ENABLE_PP_ILU_OPT:
|
||||
mq_kwargs["max_chunks"] = 32
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
connect_ip=self.parallel_config.master_addr,
|
||||
connect_ip=mq_connect_ip,
|
||||
**mq_kwargs,
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
# Create workers
|
||||
@@ -567,17 +585,22 @@ class WorkerProc:
    )
self.async_output_copy_thread.start()

# Initialize device
self.worker.init_device()

# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
    enable_ep=vllm_config.parallel_config.enable_expert_parallel
)

# Load model
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
    self.worker.init_device()
    # Update process title now that parallel groups are initialized
    self.setup_proc_title_and_log_prefix(
        enable_ep=vllm_config.parallel_config.enable_expert_parallel
    )
self.worker.load_model()
# Initialize message queues after init_device() since multi-node setups
# (nnodes_within_dp > 1) require distributed groups to be initialized
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()

# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
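(As read from this hunk, the resulting startup order for a regular,
non-elastic-EP worker is: init_device -> set process title ->
_init_message_queues, which needs the distributed groups created by
init_device, -> load_model.)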

@@ -872,6 +895,13 @@ class WorkerProc:

@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
    # Check if parallel groups are initialized first
    if not model_parallel_is_initialized():
        # Parallel groups not yet initialized, use default process name
        set_process_title(name="Worker")
        decorate_logs("Worker")
        return

    dp_size = get_dp_group().world_size
    dp_rank = get_dp_group().rank_in_group
    pp_size = get_pp_group().world_size

@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
    all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))

self.collective_rpc("init_device")
self.collective_rpc("load_model")
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
    self.collective_rpc("init_device")
    self.collective_rpc("load_model")

for pp_rank in range(self.parallel_config.pipeline_parallel_size):
    self.pp_tp_workers.append([])
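The same elastic-EP gating recurs across the executors in this commit; a
condensed sketch of the pattern (the helper below is illustrative, not vLLM
API):

    def maybe_init_worker(executor) -> None:
        # Skip device init and weight loading for workers launched by an
        # elastic-EP scale-up; they are initialized later by the EP machinery.
        if not envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            executor.collective_rpc("init_device")
            executor.collective_rpc("load_model")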

@@ -104,11 +104,23 @@ try:
        scheduler_output, intermediate_tensors
    )
if self._is_intermediate_tensors(output):
    if (
        self.worker.model_runner.supports_mm_inputs
        and get_pp_group().is_first_rank
    ):
        # Strip mm_features before Ray forwards it to the next PP stage.
        # PP stage > 0 only needs the intermediate tensors,
        # not preprocessed multimodal data.

        # scheduled_new_reqs is a required field of SchedulerOutput,
        # so accessing it directly will raise AttributeError if missing.
        for req in scheduler_output.scheduled_new_reqs:
            req.mm_features = []
    return scheduler_output, grammar_output, output

if isinstance(output, AsyncModelRunnerOutput):
    output = output.get_output()
if not get_pp_group().is_last_rank:
if not self._is_last_rank():
    # Case where there are no scheduled requests
    # but may still be finished requests.
    assert not output or not output.req_ids

@@ -128,6 +140,9 @@ try:
def _is_intermediate_tensors(self, output) -> bool:
    return isinstance(output, IntermediateTensors)

def _is_last_rank(self) -> bool:
    return get_pp_group().is_last_rank

ray_import_err = None

except ImportError as e:

@@ -362,7 +377,40 @@ def initialize_ray_cluster(
        runtime_env=parallel_config.ray_runtime_env,
    )
else:
    ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env)
    import os

    import torch

    import vllm.envs as envs

    device_count = torch.cuda.device_count()
    # Build the Ray runtime_env: forward NCCL-related settings, disable
    # Ray's CUDA_VISIBLE_DEVICES rewriting, and propagate all VLLM_* envs.
    runtime_env: dict = {"env_vars": {}}
    nccl_if_name = os.environ.get("NCCL_SOCKET_IFNAME")
    vllm_nccl_comm = os.environ.get("VLLM_FORCE_NCCL_COMM")
    if nccl_if_name is not None:
        runtime_env["env_vars"]["NCCL_SOCKET_IFNAME"] = nccl_if_name
    if vllm_nccl_comm is not None:
        runtime_env["env_vars"]["VLLM_FORCE_NCCL_COMM"] = vllm_nccl_comm
    runtime_env["env_vars"]["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    all_vllm_envs = {k: v for k, v in os.environ.items() if "VLLM" in k}
    runtime_env["env_vars"].update(all_vllm_envs)
    # ray.init(address=ray_address, ignore_reinit_error=True, runtime_env=runtime_env)
    if device_count >= parallel_config.world_size:
        ray.init(
            address=ray_address,
            ignore_reinit_error=True,
            num_gpus=parallel_config.world_size,
            runtime_env=runtime_env,
        )
    else:
        ray.init(address=ray_address, ignore_reinit_error=True, runtime_env=runtime_env)
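For instance, with NCCL_SOCKET_IFNAME=eth0 set and no VLLM_* variables in the
environment, the runtime_env built above comes out as (illustrative values):

    {"env_vars": {"NCCL_SOCKET_IFNAME": "eth0",
                  "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}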

device_str = current_platform.ray_device_key
if not device_str:

@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method

@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
    max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)

is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()
if not is_eep_new_worker:
    self.driver_worker.init_device()
    self.driver_worker.load_model()

def _distributed_args(self) -> tuple[str, int, int]:
    """Return (distributed_init_method, rank, local_rank)."""

@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
    # it's running.
    return

def reinitialize_distributed(
    self, reconfig_request: ReconfigureDistributedRequest
) -> None:
    self.driver_worker.reinitialize_distributed(reconfig_request)
    if (
        reconfig_request.new_data_parallel_rank
        == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
    ):
        self.shutdown()

def shutdown(self) -> None:
    if worker := self.driver_worker:
        worker.shutdown()

@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size
import vllm.envs as envs

logger = init_logger(__name__)

@@ -34,6 +35,25 @@ class KVCacheSpec:
        The page size
    """
    raise NotImplementedError

@property
def scale_page_size_bytes(self) -> int:
    """
    The size of a scale page with `block_size` tokens in bytes.

    Returns:
        The scale page size
    """
    raise NotImplementedError

@property
def v_cache_scale_size_bytes(self) -> int:
    """
    The size of the v-cache scale tensor in bytes.

    Returns:
        The v-cache scale size
    """
    raise NotImplementedError

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    """

@@ -78,13 +98,27 @@ class AttentionSpec(KVCacheSpec):

@property
def real_page_size_bytes(self) -> int:
    return (
        2
        * self.block_size
        * self.num_kv_heads
        * self.head_size
        * get_dtype_size(self.dtype)
    )
    if envs.VLLM_ATTN_OPT_LEVEL == 1:
        # MLA and i8q/i8k/i8v allocate the same amount of memory.
        return (
            2 * self.block_size * self.num_kv_heads * self.head_size
            * get_dtype_size(torch.int8)
        )
    elif envs.VLLM_ATTN_OPT_LEVEL == 2:
        # i8q/i8k/f16v allocates f16 plus int8 memory, hence the factor of 3.
        return (
            3 * self.block_size * self.num_kv_heads * self.head_size
            * get_dtype_size(torch.int8)
        )
    return (
        2 * self.block_size * self.num_kv_heads * self.head_size
        * get_dtype_size(self.dtype)
    )

@property
def scale_page_size_bytes(self) -> int:
    # For MLA we only store a single latent vector
    if envs.VLLM_ATTN_OPT_LEVEL > 0:
        return self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
    else:
        return 0

@property
def v_cache_scale_size_bytes(self) -> int:
    return self.head_size * self.num_kv_heads * get_dtype_size(torch.float32)
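A worked example of the page sizes above, assuming block_size=16,
num_kv_heads=8, head_size=128, and an fp16 cache dtype (numbers are
illustrative only):

    # default:      2 * 16 * 8 * 128 * 2 bytes = 65536 per page
    # opt level 1:  2 * 16 * 8 * 128 * 1 bytes = 32768 (int8 K and V)
    # opt level 2:  3 * 16 * 8 * 128 * 1 bytes = 49152 (int8 K + fp16 V)
    # scale page (opt level > 0): 16 * 8 * 4 bytes = 512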


@dataclass(frozen=True, kw_only=True)

@@ -118,7 +152,7 @@ class FullAttentionSpec(AttentionSpec):
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size * pcp_world_size > 1:
    max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
return cdiv(max_model_len, self.block_size) * (
    self.page_size_bytes + self.scale_page_size_bytes
)

@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:

@@ -179,12 +213,28 @@ class FullAttentionSpec(AttentionSpec):

@property
def real_page_size_bytes(self) -> int:
    return (
        self.block_size
        * self.num_kv_heads
        * (self.head_size + self.head_size_v)
        * get_dtype_size(self.dtype)
    )
    if envs.VLLM_ATTN_OPT_LEVEL == 1:
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(torch.int8)
        )
    elif envs.VLLM_ATTN_OPT_LEVEL == 2:
        return (
            self.block_size * self.num_kv_heads * self.head_size
            * get_dtype_size(torch.int8)
            + self.block_size * self.num_kv_heads * self.head_size_v
            * get_dtype_size(self.dtype)
        )
    else:
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

@property
def v_cache_scale_size_bytes(self) -> int:
    return self.head_size_v * self.num_kv_heads * get_dtype_size(torch.float32)


@dataclass(frozen=True, kw_only=True)

@@ -198,12 +248,30 @@ class MLAAttentionSpec(FullAttentionSpec):
    # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
    # for details.
    return self.block_size * 656
if envs.VLLM_USE_INT8_MLA:
    return (
        self.block_size
        * self.num_kv_heads
        * self.head_size
        * get_dtype_size(torch.int8)
    )
return (
    self.block_size
    * self.num_kv_heads
    * self.head_size
    * get_dtype_size(self.dtype)
)

@property
def scale_page_size_bytes(self) -> int:
    # For MLA we only store a single latent vector
    if envs.VLLM_USE_INT8_MLA:
        return (
            self.block_size
            * self.num_kv_heads * 2
            * get_dtype_size(torch.float32)
        )
    else:
        return 0
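For the int8-MLA path, with block_size=64 and num_kv_heads=1 (typical for
MLA's single latent head; the numbers are illustrative), the scale page works
out to 64 * 1 * 2 * 4 = 512 bytes per block.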

@classmethod
def merge(cls, specs: list[Self]) -> Self:

@@ -267,7 +335,7 @@ class SlidingWindowSpec(AttentionSpec):
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
return (cdiv(num_tokens, self.block_size) + 1) * (
    self.page_size_bytes + self.scale_page_size_bytes
)


@dataclass(frozen=True)

@@ -289,7 +357,6 @@ class MambaSpec(KVCacheSpec):
    assert self.page_size_padded >= page_size
    return self.page_size_padded
return page_size

@property
def scale_page_size_bytes(self) -> int:
    return 0

@@ -389,6 +456,9 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
@property
def page_size_bytes(self) -> int:
    return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

@property
def scale_page_size_bytes(self) -> int:
    return sum(spec.scale_page_size_bytes for spec in self.kv_cache_specs.values())

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_num_pages = max(

@@ -460,6 +530,7 @@ class KVCacheTensor:

size: int  # size of the KV cache tensor in bytes
shared_by: list[str]  # layer names that share the same KV cache tensor
# size of the v_cache_scale tensor in bytes; only used for VLLM_ATTN_OPT_LEVEL == 1
size_scale: int = 0


@dataclass

@@ -486,6 +557,7 @@ class KVCacheConfig:
kv_cache_tensors: list[KVCacheTensor]
"""How should model runner initialize the KV cache tensors for each layer"""
kv_cache_scale_tensors: list[KVCacheTensor]
kv_cache_groups: list[KVCacheGroupSpec]
"""
The kv cache groups of the model.
For models with only one type of attention, there is only one group that

@@ -493,3 +565,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details.
"""

@property
def has_mamba_layers(self) -> bool:
    return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)

@property
def needs_kv_cache_zeroing(self) -> bool:
    return self.has_mamba_layers

@@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers:
assert gpu_shape[0] == 2
split_k_and_v = True

try:
    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
        include_num_layers_dimension=has_layers_dim
    )
    assert len(kv_cache_stride_order) == len(gpu_shape)
except (AttributeError, NotImplementedError):
    kv_cache_stride_order = tuple(range(len(gpu_shape)))
if has_layers_dim:
    # in the cross layers case, the registered kv cache tensor
    # shape matches the physical layout, whereas test_shape
    # is the logical layout.
    # To match them, we need to permute test_shape
    try:
        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
            include_num_layers_dimension=has_layers_dim
        )
        assert len(kv_cache_stride_order) == len(gpu_shape)
    except (AttributeError, NotImplementedError):
        kv_cache_stride_order = tuple(range(len(gpu_shape)))

    # permute test_shape according to stride_order
    test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)

# find block_size (16) dimension index
block_size_idx = test_shape.index(16)
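A tiny illustration of the permutation above (the shapes are made up):

    test_shape = (2, 1024, 16, 8, 128)        # logical layout
    kv_cache_stride_order = (1, 0, 2, 3, 4)   # physical layout order
    permuted = tuple(test_shape[i] for i in kv_cache_stride_order)
    # permuted == (1024, 2, 16, 8, 128); block_size 16 is now at index 2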

@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar

import numpy as np
import torch

@@ -120,6 +121,20 @@ class SamplerOutput:
logprobs_tensors: LogprobsTensors | None


T = TypeVar("T")


def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
    non_none = [item for item in items if item is not None]
    if len(non_none) == 0:
        return None

    combined = non_none[0]
    for item in non_none[1:]:
        combined = f(combined, item)
    return combined
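Example behavior (illustrative):

    _combine_non_none(set.union, [{"a"}, None, {"b"}])  # -> {"a", "b"}
    _combine_non_none(set.union, [None, None])           # -> None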


@dataclass
class KVConnectorOutput:
    # [req_ids]

@@ -146,6 +161,43 @@ class KVConnectorOutput:
    and not self.invalid_block_ids
)

@classmethod
def merge(cls, *outputs: "KVConnectorOutput"):
    assert len(outputs) > 0, "Cannot merge empty outputs"
    finished_sending = _combine_non_none(
        set.union, [output.finished_sending for output in outputs]
    )
    finished_recving = _combine_non_none(
        set.union, [output.finished_recving for output in outputs]
    )
    kv_connector_stats = _combine_non_none(
        lambda x, y: x.aggregate(y),
        [output.kv_connector_stats for output in outputs],
    )
    kv_cache_events = _combine_non_none(
        lambda x, y: x.merge(y),
        [output.kv_cache_events for output in outputs],
    )
    invalid_block_ids = _combine_non_none(
        set.union, [output.invalid_block_ids for output in outputs]
    )
    assert invalid_block_ids is not None

    assert all(
        output.expected_finished_count == outputs[0].expected_finished_count
        for output in outputs
    )
    expected_finished_count = outputs[0].expected_finished_count

    return cls(
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        kv_connector_stats=kv_connector_stats,
        kv_cache_events=kv_cache_events,
        invalid_block_ids=invalid_block_ids,
        expected_finished_count=expected_finished_count,
    )
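Typical use is merging the per-part outputs of one step, e.g.
KVConnectorOutput.merge(out_part0, out_part1): set-valued fields are unioned,
stats and events are combined pairwise, and expected_finished_count must
agree across all inputs.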


@dataclass
class ECConnectorOutput:

@@ -153,7 +205,12 @@ class ECConnectorOutput:
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None


@dataclass
class DraftTokenIds:
    # [num_reqs]
    req_ids: list[str]
    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass

@@ -191,6 +248,8 @@ class ModelRunnerOutput:

# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None

draft_token_ids: DraftTokenIds | None = None

# information related to cudagraph execution
cudagraph_stats: CUDAGraphStat | None = None

@@ -209,13 +268,6 @@ class AsyncModelRunnerOutput(ABC):
    pass


@dataclass
class DraftTokenIds:
    # [num_reqs]
    req_ids: list[str]
    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


def make_empty_encoder_model_runner_output(
    scheduler_output: "SchedulerOutput",

@@ -320,6 +320,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()
FINISHED_REPETITION = enum.auto()

def __str__(self) -> str:
    return self.name

@@ -344,4 +345,5 @@ _FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
    RequestStatus.FINISHED_REPETITION: FinishReason.REPETITION,
}

@@ -202,10 +202,11 @@ def build_logitsprocs(
if custom_logitsprocs:
    raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
logger.warning(
    "min_p, logit_bias, and min_tokens parameters won't currently work "
    "with speculative decoding enabled."
    "min_p and logit_bias parameters won't work with speculative decoding."
)
return LogitsProcessors(
    [MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
)
return LogitsProcessors()

custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
@@ -3,6 +3,7 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar

import numpy as np
import torch

from vllm import SamplingParams

@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor):
    logits.index_put_(self.logits_slice, self.neg_inf_tensor)
    return logits

def apply_with_spec_decode(
    self,
    logits: torch.Tensor,
    num_draft_tokens: list[int],
) -> torch.Tensor:
    """Spec-decode version of apply().
    Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
    Example: ``num_draft_tokens = [2, 3, 1]``
    -> ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
    -> request 0 owns rows 0-1, request 1 rows 2-4, request 2 row 5.
    """
    if not self.min_toks:
        return logits

    num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
    cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])

    entries = [
        (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
        for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
        if stop_tok_ids
    ]

    if not entries:
        return logits

    all_rows: list[np.ndarray] = []  # row indices to mask
    all_toks: list[np.ndarray] = []  # stop-token ids at those rows

    for req_idx, min_tok, current_len, stop_toks in entries:
        remaining = min_tok - current_len
        # How many leading draft positions still need stop-token masking.
        n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))

        if n_mask > 0:
            offset = cumsum[req_idx]
            row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
            n_stop = len(stop_toks)
            all_rows.append(np.repeat(row_indices, n_stop))
            all_toks.append(np.tile(stop_toks, n_mask))

    if all_rows:
        rows_arr = np.concatenate(all_rows)
        toks_arr = np.concatenate(all_toks)
        # (row_indices, token_indices) for index_put_ to set -inf.
        logits_slice = (
            torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
            torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
        )
        logits.index_put_(logits_slice, self.neg_inf_tensor)

    return logits
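The row bookkeeping can be checked with the docstring's own numbers:

    import numpy as np
    num_draft_tokens = [2, 3, 1]
    cumsum = np.concatenate([[0], np.cumsum(num_draft_tokens)])  # [0 2 5 6]
    req_idx, n_mask = 1, 2        # mask the first 2 rows of request 1
    rows = np.arange(cumsum[req_idx], cumsum[req_idx] + n_mask)  # [2 3]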


def process_dict_updates(
    req_entries: dict[int, T],

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from itertools import chain
from typing import TYPE_CHECKING

@@ -148,7 +148,7 @@ class BatchUpdateBuilder:
class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None:
    def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        if logitsprocs:

@@ -10,12 +10,14 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm import _custom_ops as ops

logger = init_logger(__name__)

@@ -292,6 +294,12 @@ class RejectionSampler(nn.Module):
    logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
)

for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
    if isinstance(processor, MinTokensLogitsProcessor):
        logits = processor.apply_with_spec_decode(
            logits, metadata.num_draft_tokens
        )

return logits

@staticmethod

@@ -385,14 +393,13 @@ def rejection_sample(
if not sampling_metadata.all_random:
    # Rejection sampling for greedy sampling requests.
    target_argmax = target_logits.argmax(dim=-1)
    rejection_greedy_sample_kernel[(batch_size,)](
    ops.rejection_greedy_sample_torch(
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        target_argmax,
        bonus_token_ids,
        is_greedy,
        max_spec_len,
    )
    if sampling_metadata.all_greedy:
        return output_token_ids

@@ -424,7 +431,7 @@ def rejection_sample(
)

# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size,)](
ops.rejection_random_sample_torch(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,

@@ -434,8 +441,6 @@ def rejection_sample(
    recovered_token_ids,
    uniform_probs,
    is_greedy,
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS=draft_probs is None,
)
return output_token_ids

@@ -3,7 +3,7 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast
from typing import Any, cast

import numpy as np
import torch
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,

@@ -38,14 +34,15 @@ from vllm.v1.attention.backends.tree_attn import (
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
    PADDING_SLOT_ID,
    compute_new_slot_mapping,
    copy_and_expand_eagle_inputs_kernel,
    create_vllm_config_for_draft_model,
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
    extend_all_queries_by_N,

@@ -53,6 +50,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup

logger = init_logger(__name__)

@@ -68,6 +66,7 @@ class SpecDecodeBaseProposer:
self.vllm_config = vllm_config
assert vllm_config.speculative_config is not None
self.speculative_config = vllm_config.speculative_config
self.draft_vllm_config = create_vllm_config_for_draft_model(vllm_config)
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model

@@ -79,6 +78,9 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens

self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1

# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).

@@ -113,21 +115,19 @@ class SpecDecodeBaseProposer:
    vllm_config.model_config
)

self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.draft_attn_groups: list[AttentionGroup] = []
self.kv_cache_gid: int = -1
self.eagle3_use_aux_hidden_state: bool = (
    self._get_eagle3_use_aux_hidden_state_from_config()
)

self.compilation_config = self.vllm_config.compilation_config
self.compilation_config = self.draft_vllm_config.compilation_config

# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.cudagraph_dispatcher = CudagraphDispatcher(self.draft_vllm_config)

# persistent buffers for cuda graph
self.input_ids = torch.zeros(

@@ -353,7 +353,7 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return {name: view for name in self._draft_attn_layer_names}

def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Initialize cudagraph dispatcher keys for eagle.

@@ -372,6 +372,23 @@ class SpecDecodeBaseProposer:

self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

def adjust_input(
    self,
    batch_size: int,
    target_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    target_hidden_states: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    return (
        target_token_ids,
        target_positions,
        target_hidden_states,
        common_attn_metadata,
    )

def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Greedy-sample draft tokens from hidden states."""
    if self.use_local_argmax_reduction:

@@ -391,6 +408,7 @@ class SpecDecodeBaseProposer:
    token_indices_to_sample: torch.Tensor | None,
    common_attn_metadata: CommonAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    num_rejected_tokens_gpu: torch.Tensor | None = None,
    slot_mappings: dict[str, torch.Tensor]

@@ -406,6 +424,21 @@ class SpecDecodeBaseProposer:
    )
assert target_hidden_states.shape[-1] == self.hidden_size

(
    target_token_ids,
    target_positions,
    target_hidden_states,
    common_attn_metadata,
) = self.adjust_input(
    batch_size=batch_size,
    target_token_ids=target_token_ids,
    target_positions=target_positions,
    target_hidden_states=target_hidden_states,
    token_indices_to_sample=token_indices_to_sample,
    common_attn_metadata=common_attn_metadata,
    multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)

num_tokens, token_indices_to_sample, common_attn_metadata = (
    self.set_inputs_first_pass(
        target_token_ids=target_token_ids,
@@ -420,109 +453,114 @@ class SpecDecodeBaseProposer:

assert self.runner is not None

if self.attn_metadata_builder is None:
    attn_metadata_builder = self._get_attention_metadata_builder()
else:
    attn_metadata_builder = self.attn_metadata_builder
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
    attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
        common_attn_metadata=common_attn_metadata, draft_index=0
    )
    for layer_name in attn_group.layer_names:
        per_layer_attn_metadata[layer_name] = attn_metadata

attn_metadata = attn_metadata_builder.build_for_drafting(
    common_attn_metadata=common_attn_metadata, draft_index=0
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
    self._determine_batch_execution_and_padding(num_tokens)
)
# FIXME: support hybrid kv for draft model (remove separate indexer)
if self.draft_indexer_metadata_builder:
    draft_indexer_metadata = (
        self.draft_indexer_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=0,

draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
    if self.supports_mm_inputs:
        mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

        self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
            self.input_ids[:num_tokens],
            multimodal_embeddings=mm_embeds,
            is_multimodal=is_mm_embed,
        )
    )
else:
    draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
    per_layer_attn_metadata[layer_name] = attn_metadata

for layer_name in self.indexer_layer_names:
    assert draft_indexer_metadata is not None
    per_layer_attn_metadata[layer_name] = draft_indexer_metadata

num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)

cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
    num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
    num_tokens_across_dp[self.dp_rank] = num_input_tokens

if self.supports_mm_inputs:
    mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

    self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
        self.input_ids[:num_tokens],
        multimodal_embeddings=mm_embeds,
        is_multimodal=is_mm_embed,
    )

    input_ids = None
    inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
    input_ids = self.input_ids[:num_input_tokens]
    inputs_embeds = None

model_kwargs = {
    "input_ids": input_ids,
    "positions": self._get_positions(num_input_tokens),
    "inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
    model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

with set_forward_context(
    per_layer_attn_metadata,
    self.vllm_config,
    num_tokens=num_input_tokens,
    num_tokens_across_dp=num_tokens_across_dp,
    cudagraph_runtime_mode=cudagraph_runtime_mode,
    slot_mapping=self._get_slot_mapping(
        num_input_tokens, common_attn_metadata.slot_mapping
    ),
):
    ret_hidden_states = self.model(**model_kwargs)
    if not self.model_returns_tuple():
        last_hidden_states = ret_hidden_states
        hidden_states = last_hidden_states
        input_ids = None
        inputs_embeds = self.inputs_embeds[:num_input_tokens]
    else:
        last_hidden_states, hidden_states = ret_hidden_states
        input_ids = self.input_ids[:num_input_tokens]
        inputs_embeds = None

sample_hidden_states = last_hidden_states[token_indices_to_sample]
    model_kwargs = {
        "input_ids": input_ids,
        "positions": self._get_positions(num_input_tokens),
        "inputs_embeds": inputs_embeds,
    }
    if self.pass_hidden_states_to_model:
        model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1 or self.parallel_drafting:
    draft_token_ids = self._greedy_sample(sample_hidden_states)
    return draft_token_ids.view(-1, self.num_speculative_tokens)
    if self.enable_multi_layers_mtp:
        model_kwargs["spec_step_idx"] = spec_step_idx

    with set_forward_context(
        per_layer_attn_metadata,
        self.draft_vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        slot_mapping=self._get_slot_mapping(
            num_input_tokens, common_attn_metadata.slot_mapping
        ),
    ):
        ret_hidden_states = self.model(**model_kwargs)
        if not self.model_returns_tuple():
            last_hidden_states = ret_hidden_states
            hidden_states = last_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states

    sample_hidden_states = last_hidden_states[token_indices_to_sample]
    if self.enable_multi_layers_mtp:
        logits = self.model.compute_logits(
            sample_hidden_states, spec_step_idx=spec_step_idx
        )
    else:
        logits = self.model.compute_logits(sample_hidden_states)

    draft_token_ids = logits.argmax(dim=-1)

    # Generate the remaining draft tokens.
    draft_token_ids_list.append(draft_token_ids)

    if spec_step_idx < self.layer_num - 1:
        prev_token_ids = self.input_ids[:num_tokens].clone()
        hidden_states = hidden_states[:num_tokens]
        next_token_ids = draft_token_ids_list[-1].int()

        num_tokens, token_indices_to_sample, common_attn_metadata = (
            self.set_inputs_first_pass(
                target_token_ids=prev_token_ids,
                next_token_ids=next_token_ids,
                target_positions=target_positions,
                target_hidden_states=hidden_states,
                token_indices_to_sample=token_indices_to_sample,
                cad=common_attn_metadata,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
            )
        )

# Early exit if all draft tokens are generated in one pass
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
    draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
    return draft_token_ids
if self.uses_mrope:
    positions = self.mrope_positions[:, token_indices_to_sample]
else:
    positions = self.positions[token_indices_to_sample]
if self.method in (
    "deepseek_mtp",
    "ernie_mtp",
    "longcat_flash_mtp",
    "pangu_ultra_moe_mtp",
):
if self.method == "mtp":
    hidden_states = self.hidden_states[token_indices_to_sample]
else:
    hidden_states = hidden_states[token_indices_to_sample]

if isinstance(attn_metadata, TreeAttentionMetadata):
    # Draft using tree attention - requires full logits for top-k
    logits = self.model.compute_logits(sample_hidden_states)
    if self.enable_multi_layers_mtp:
        raise NotImplementedError(
            "Speculative Decoding with multi-layer MTP and tree attention "
            "is not supported yet."
        )
    # Draft using tree attention.
    draft_token_ids_list = self.propose_tree(
        batch_size=batch_size,
        logits=logits,

@@ -534,32 +572,20 @@ class SpecDecodeBaseProposer:
    # [batch_size, num_tree_tokens]
    return torch.cat(draft_token_ids_list, dim=1)

draft_token_ids = self._greedy_sample(sample_hidden_states)

if self.allowed_attn_types is not None and not isinstance(
    attn_metadata, self.allowed_attn_types
):
    raise ValueError(
        f"Unsupported attention metadata type for speculative "
        "decoding with num_speculative_tokens > 1: "
        "decoding with num_speculative_tokens > layer_num: "
        f"{type(attn_metadata)}. Supported types are: "
        f"{self.allowed_attn_types}"
    )

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
    num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
    self._determine_batch_execution_and_padding(batch_size)
)

cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
    batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
    batch_size_across_dp[self.dp_rank] = input_batch_size

common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]

@@ -577,7 +603,7 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None

for token_index in range(self.num_speculative_tokens - 1):
for token_index in range(self.num_speculative_tokens - self.layer_num):
    # Update the inputs.
    # cast to int32 is crucial when eagle model is compiled.
    # tensor.argmax() returns int64 by default.

@@ -627,7 +653,8 @@ class SpecDecodeBaseProposer:
    common_attn_metadata._num_computed_tokens_cpu += 1

    # Compute the slot mapping.
    block_size = attn_metadata_builder.kv_cache_spec.block_size
    # Use the first draft attention group's kv_cache_spec for block_size
    block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
    if self.uses_mrope:
        # all dimensions of positions are the same
        block_numbers = clamped_positions[0] // block_size

@@ -653,11 +680,13 @@ class SpecDecodeBaseProposer:
    )

    # Rebuild attention metadata
    attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
        common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
    )
    for layer_name in self.attn_layer_names:
        per_layer_attn_metadata[layer_name] = attn_metadata
    for attn_group in self.draft_attn_groups:
        attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=token_index + 1,
        )
        for layer_name in attn_group.layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

    # copy inputs to buffer for cudagraph
    self.input_ids[:batch_size] = input_ids

@@ -683,7 +712,7 @@ class SpecDecodeBaseProposer:

    with set_forward_context(
        per_layer_attn_metadata,
        self.vllm_config,
        self.draft_vllm_config,
        num_tokens=input_batch_size,
        num_tokens_across_dp=batch_size_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,

@@ -819,18 +848,17 @@ class SpecDecodeBaseProposer:
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
builder = (
    self._get_attention_metadata_builder()
    if self.attn_metadata_builder is None
    else self.attn_metadata_builder
)
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert len(self.draft_attn_groups) > 0
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping(
    cad=cad,
    new_positions=self.positions[:total_num_output_tokens],
    is_rejected_token_mask=self.is_rejected_token_mask[
        :total_num_output_tokens
    ],
    block_size=builder.kv_cache_spec.block_size,
    block_size=block_size,
    num_new_tokens=self.net_num_new_slots_per_request,
    max_model_len=self.max_model_len,
)

@@ -880,6 +908,69 @@ class SpecDecodeBaseProposer:
    next_token_ids, dtype=torch.int32, device=self.input_ids.device
)
return next_token_ids

def eagle_prepare_next_token_padded(
    self,
    sampled_token_ids,  # [num_reqs, num_sampled_tokens_per_req]
    discard_request_mask,  # [num_reqs]
    backup_next_token_ids,  # [num_reqs]
    vocab_size,
):
    """
    PyTorch implementation of eagle_prepare_next_token_padded kernel.

    Args:
        sampled_token_ids: Tensor of shape [num_reqs, num_sampled_tokens_per_req]
            containing sampled token ids (-1 for rejected tokens)
        discard_request_mask: Boolean tensor of shape [num_reqs] indicating
            which requests should be discarded
        backup_next_token_ids: Tensor of shape [num_reqs] containing backup
            token ids for when no valid tokens are found
        vocab_size: Vocabulary size for validity checking

    Returns:
        next_token_ids: Tensor of shape [num_reqs] containing the next token
            to sample (last accepted token or backup)
        valid_sampled_tokens_count: Tensor of shape [num_reqs] containing the
            number of valid (1 + accepted) tokens
    """
    num_reqs = sampled_token_ids.shape[0]
    num_sampled_tokens_per_req = sampled_token_ids.shape[1]

    # Initialize output tensors
    next_token_ids = torch.empty(
        num_reqs, dtype=sampled_token_ids.dtype, device=sampled_token_ids.device
    )
    valid_sampled_tokens_count = torch.zeros(
        num_reqs, dtype=torch.int32, device=sampled_token_ids.device
    )

    # Process each request
    for req_idx in range(num_reqs):
        if discard_request_mask[req_idx]:
            # Discarded request: use backup token and valid_count=0
            next_token_ids[req_idx] = backup_next_token_ids[req_idx]
            valid_sampled_tokens_count[req_idx] = 0
        else:
            # Get sampled tokens for this request
            tokens = sampled_token_ids[req_idx]

            # Find valid tokens (not -1 and within vocabulary range)
            is_valid = (tokens != -1) & (tokens < vocab_size)
            valid_count = is_valid.sum().item()

            if valid_count > 0:
                # Find the last valid token index
                # Get indices where is_valid is True
                valid_indices = torch.where(is_valid)[0]
                last_valid_idx = valid_indices[-1].item()

                # Get the token at that index
                last_valid_token = tokens[last_valid_idx]
                next_token_ids[req_idx] = last_valid_token
            else:
                # No valid tokens, use backup token
                next_token_ids[req_idx] = backup_next_token_ids[req_idx]

            valid_sampled_tokens_count[req_idx] = valid_count

    return next_token_ids, valid_sampled_tokens_count
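An illustrative call (hypothetical tensors; vocab_size=32000):

    sampled = torch.tensor([[5, -1, -1], [7, 9, -1]])
    discard = torch.tensor([False, False])
    backup = torch.tensor([0, 0])
    next_ids, counts = proposer.eagle_prepare_next_token_padded(
        sampled, discard, backup, vocab_size=32000
    )
    # next_ids -> [5, 9] (last accepted token per request)
    # counts   -> [1, 2]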

def prepare_next_token_ids_padded(
    self,

@@ -910,31 +1001,15 @@ class SpecDecodeBaseProposer:
self.backup_next_token_ids.copy_to_gpu(num_reqs)
backup_tokens_gpu = self.backup_next_token_ids.gpu

batch_size, num_tokens = sampled_token_ids.shape
device = sampled_token_ids.device

assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32

next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

# Kernel grid: one program per request (row)
grid = (batch_size,)

# Find the next power of 2 for block sizes
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
eagle_prepare_next_token_padded_kernel[grid](

next_token_ids, valid_sampled_tokens_count = self.eagle_prepare_next_token_padded(
    sampled_token_ids,
    discard_request_mask,
    backup_tokens_gpu,
    next_token_ids,
    valid_sampled_tokens_count,
    gpu_input_batch.vocab_size,
    num_tokens,
    batch_size,
    sampled_token_ids.stride(0),
    BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)

return next_token_ids, valid_sampled_tokens_count

@@ -974,6 +1049,8 @@ class SpecDecodeBaseProposer:
)

query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc

new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

total_num_tokens = query_start_loc_cpu[-1].item()

@@ -981,7 +1058,9 @@ class SpecDecodeBaseProposer:
spec_common_attn_metadata = CommonAttentionMetadata(
    query_start_loc=common_attn_metadata.query_start_loc,
    seq_lens=common_attn_metadata.seq_lens,
    seq_lens_np=common_attn_metadata.seq_lens_np,
    query_start_loc_cpu=query_start_loc_cpu,
    key_start_loc=key_start_loc,
    _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
    _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
    num_reqs=common_attn_metadata.num_reqs,

@@ -1014,9 +1093,7 @@ class SpecDecodeBaseProposer:
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> list[torch.Tensor]:
    tree_attn_metadata_builder = self.runner.attn_groups[0][
        0
    ].get_metadata_builder()
    tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
    assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

    total_num_drafts = self.cu_drafts_per_level[0]

@@ -1092,10 +1169,11 @@ class SpecDecodeBaseProposer:
    common_attn_metadata=common_attn_metadata, draft_index=level + 1
)

# Apply new attention metadata to all layers.
# Apply new attention metadata to all draft layers.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
    per_layer_attn_metadata[layer_name] = attn_metadata
for attn_group in self.draft_attn_groups:
    for layer_name in attn_group.layer_names:
        per_layer_attn_metadata[layer_name] = attn_metadata

# Consider max model length.
attn_metadata.max_seq_len = min(

@@ -1131,7 +1209,7 @@ class SpecDecodeBaseProposer:
# Run the model.
with set_forward_context(
    per_layer_attn_metadata,
    self.vllm_config,
    self.draft_vllm_config,
    num_tokens=num_input_tokens,
    cudagraph_runtime_mode=cudagraph_runtime_mode,
    slot_mapping=self._get_slot_mapping(

@@ -1209,6 +1287,7 @@ class SpecDecodeBaseProposer:

device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
key_start_loc = common_attn_metadata.key_start_loc
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]

@@ -1261,6 +1340,7 @@ class SpecDecodeBaseProposer:
query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
key_start_loc=key_start_loc,
_seq_lens_cpu=new_seq_lens_cpu,
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,

@@ -1289,7 +1369,7 @@ class SpecDecodeBaseProposer:

with set_model_tag("eagle_head"):
    model = get_model(
        vllm_config=self.vllm_config,
        vllm_config=self.draft_vllm_config,
        model_config=self.speculative_config.draft_model_config,
        load_config=self.speculative_config.draft_load_config,
    )
@@ -1302,43 +1382,17 @@ class SpecDecodeBaseProposer:
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
).keys()
|
||||
)
|
||||
|
||||
self.model = self._get_model()
|
||||
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
- target_attn_layer_names
|
||||
# Find draft layers (attention layers added by draft model)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.draft_vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
self._draft_attn_layer_names = (
|
||||
set(all_attn_layers.keys()) - target_attn_layer_names
|
||||
)
|
||||
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
|
||||
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
|
||||
self.indexer_layer_names = list(draft_indexer_layer_names)
|
||||
|
||||
if self.indexer_layer_names:
|
||||
first_layer = self.indexer_layer_names[0]
|
||||
self.draft_indexer_metadata_builder = (
|
||||
indexer_layers[first_layer]
|
||||
.get_attn_backend()
|
||||
.get_builder_cls()(
|
||||
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
|
||||
self.indexer_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.draft_indexer_metadata_builder = None
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
# Even if the target model is multimodal, we can also use
|
||||
@@ -1568,25 +1622,17 @@ class SpecDecodeBaseProposer:
|
||||
self.num_speculative_tokens if not is_graph_capturing else 1
|
||||
):
|
||||
if fwd_idx <= 1:
|
||||
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
|
||||
)
|
||||
if use_cudagraphs:
|
||||
cudagraph_runtime_mode, batch_desc = (
|
||||
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens, use_cudagraphs=use_cudagraphs
|
||||
)
|
||||
num_input_tokens = batch_desc.num_tokens
|
||||
else:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
num_input_tokens = num_tokens_dp_padded
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
)
|
||||
|
||||
# Make sure to use EAGLE's own buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
self._draft_attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
and next(iter(self._draft_attn_layer_names)) in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
@@ -1594,7 +1640,7 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
self.draft_vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
@@ -1616,31 +1662,6 @@ class SpecDecodeBaseProposer:
|
||||
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
self.model(**kwargs)

    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builders for EAGLE layers.

        Returns:
            The metadata builders for EAGLE layers.

        Raises:
            AssertionError: If no metadata builders are found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
@@ -1673,35 +1694,114 @@ class SpecDecodeBaseProposer:
                set(
                    [
                        kv_cache_groups[layer_name]
                        for layer_name in self.attn_layer_names
                        for layer_name in self._draft_attn_layer_names
                    ]
                )
            )
            == 1
        ), "All drafting layers should belong to the same kv cache group"

    def _pad_batch_across_dp(
    def initialize_attn_backend(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor]:
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
            != CUDAGraphMode.NONE,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        kv_cache_config: KVCacheConfig,
        kernel_block_sizes: list[int] | None = None,
    ) -> None:
        """
        Initialize AttentionGroups for draft layers using kv_cache_config.
        Called from the model runner's initialize_metadata_builders.
        """
        all_attn_layers = get_layers_from_vllm_config(
            self.draft_vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        num_tokens_dp_padded = num_tokens_padded
        if num_toks_across_dp is not None:
            num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
        return num_tokens_dp_padded, num_toks_across_dp
        # Find which kv_cache_group the draft layers belong to
        self.validate_same_kv_cache_group(kv_cache_config)
        kv_cache_spec = None
        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
            if self._draft_attn_layer_names & set(group.layer_names):
                self.kv_cache_gid = gid
                kv_cache_spec = group.kv_cache_spec
                break

        attention_groups: dict[tuple[str, str], AttentionGroup] = {}
        if kv_cache_spec is not None:
            for layer_name in self._draft_attn_layer_names:
                attn_backend = all_attn_layers[layer_name].get_attn_backend()
                backend_key = attn_backend.full_cls_name()
                if backend_key not in attention_groups:
                    layer_kv_cache_spec = kv_cache_spec
                    if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                        layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                            layer_name
                        ]

                    kernel_block_size = (
                        kernel_block_sizes[self.kv_cache_gid]
                        if kernel_block_sizes is not None
                        and self.kv_cache_gid < len(kernel_block_sizes)
                        else None
                    )
                    attn_group = AttentionGroup(
                        backend=attn_backend,
                        layer_names=[layer_name],
                        kv_cache_spec=layer_kv_cache_spec,
                        kv_cache_group_id=self.kv_cache_gid,
                    )
                    attn_group.create_metadata_builders(
                        self.draft_vllm_config,
                        self.device,
                        kernel_block_size=kernel_block_size,
                    )
                    attention_groups[backend_key] = attn_group
                else:
                    attention_groups[backend_key].layer_names.append(layer_name)

        self.draft_attn_groups = list(attention_groups.values())
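
# --- Illustrative sketch (not part of the diff): how the loop above groups
# draft layers by attention backend. `layers` and the layer API used here
# are assumptions for illustration only.
def group_layers_by_backend(layers: dict) -> dict[str, list[str]]:
    groups: dict[str, list[str]] = {}
    for name, layer in layers.items():
        key = layer.get_attn_backend().full_cls_name()
        groups.setdefault(key, []).append(name)
    return groups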

    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
        )
        num_tokens_padded = batch_desc.num_tokens

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp = False, None
        if self.draft_vllm_config.parallel_config.data_parallel_size > 1:
            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.draft_vllm_config.parallel_config,
                    allow_microbatching=False,
                    num_tokens_padded=num_tokens_padded,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )
            assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.dp_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct
                # batch_descriptor
                cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                    num_tokens_padded,
                    valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
                )
                # Assert to make sure the agreed-upon token count is correct;
                # otherwise num_tokens_across_dp will no longer be valid
                assert batch_desc.num_tokens == num_tokens_padded
                num_tokens_across_dp[dp_rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
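
# --- Illustrative sketch (assumption-laden, not the code above): the
# two-phase dispatch pattern. `dispatch` is a stand-in for
# cudagraph_dispatcher.dispatch; it returns (mode, batch_descriptor).
def dispatch_with_dp_padding(dispatch, num_tokens: int, dp_padded_tokens):
    mode, desc = dispatch(num_tokens)                # local decision
    if dp_padded_tokens is not None and dp_padded_tokens != desc.num_tokens:
        mode, desc = dispatch(dp_padded_tokens)      # re-dispatch at DP size
        assert desc.num_tokens == dp_padded_tokens   # descriptor must agree
    return mode, desc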


class EagleProposer(SpecDecodeBaseProposer):

395 vllm/v1/spec_decode/extract_hidden_states.py Normal file
@@ -0,0 +1,395 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.kv_cache_interface import KVCacheConfig

PADDING_SLOT_ID = -1


class ExtractHiddenStatesProposer:
    def __init__(self, vllm_config: VllmConfig, device):
        assert vllm_config.speculative_config is not None

        assert vllm_config.speculative_config.num_speculative_tokens == 1
        if vllm_config.speculative_config.disable_padded_drafter_batch:
            raise ValueError(
                "disable_padded_drafter_batch is not supported with "
                "extract_hidden_states method"
            )
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank

        # Model and attention layer tracking (initialized in load_model)
        self.model: nn.Module | None = None
        self.attn_layer_names: list[str] = []
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None

        # Maximum number of tokens for buffers
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )

        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
        if not layer_ids:
            raise ValueError(
                "eagle_aux_hidden_state_layer_ids must be set in the draft "
                "model config for extract_hidden_states method"
            )
        self.num_hidden_states = len(layer_ids)
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.num_hidden_states, self.hidden_size),
            dtype=self.dtype,
            device=device,
        )
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

    def propose(
        self,
        sampled_token_ids: torch.Tensor,
        target_hidden_states: list[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        scheduler_output: SchedulerOutput,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> tuple[torch.Tensor, KVConnectorOutput | None]:
        """Propose draft tokens by calling the ExtractHiddenStatesModel model.

        The ExtractHiddenStatesModel caches the hidden states in the KV cache
        without performing actual attention computation. This allows us to
        extract and store hidden states for later use (e.g., KV transfer).

        This proposer doesn't actually perform speculation - it returns the
        sampled tokens as "draft" tokens, ensuring they always verify (match).
        The main purpose is to cache hidden states, not to speculate.

        Args:
            sampled_token_ids: Sampled token IDs from the target model
            target_hidden_states: List of hidden state tensors from target model
                (one per aux hidden state layer)
            common_attn_metadata: Attention metadata
            scheduler_output: Scheduler output for KV connector
            slot_mappings: Slot mappings for KV cache (unused, provided for
                interface compatibility)

        Returns:
            Tuple of:
            - Draft tokens matching sampled tokens, shape [batch_size, 1]
            - KV connector output (if KV transfer is active), else None
        """
        assert self.model is not None and isinstance(target_hidden_states, list)

        # target_hidden_states is a list of tensors (one per layer)
        # Each tensor has shape [num_tokens, hidden_size]
        # Stack to shape: [num_tokens, num_hidden_states, hidden_size]
        stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
        num_tokens = stacked_hidden_states.shape[0]

        # Copy hidden states to buffer
        self.hidden_states[:num_tokens] = stacked_hidden_states

        assert self.attn_metadata_builder is not None
        attn_metadata = self.attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )

        # We assume all cache-only layers belong to the same KV cache group,
        # thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        with (
            set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, common_attn_metadata.slot_mapping
                ),
            ),
            (
                KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
                if has_kv_transfer_group()
                else nullcontext()
            ) as kv_connector_output,
        ):
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

        # Return the sampled tokens as "draft" tokens
        # Shape: [batch_size, 1] to match num_speculative_tokens=1
        return sampled_token_ids.unsqueeze(-1), kv_connector_output
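
# --- Illustrative sketch: why these drafts always verify. The proposer
# echoes the target's sampled tokens, so exact-match verification accepts
# every draft token (tensor values below are made up).
sampled_token_ids = torch.tensor([5, 9, 2])
draft = sampled_token_ids.unsqueeze(-1)        # shape [batch_size, 1]
assert torch.equal(draft[:, 0], sampled_token_ids)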

    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for cache-only attention layers.

        If slot_mapping is provided, copies it into the buffer first.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names}

    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
        )
        num_tokens_padded = batch_desc.num_tokens

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp = False, None
        if self.vllm_config.parallel_config.data_parallel_size > 1:
            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.vllm_config.parallel_config,
                    allow_microbatching=False,
                    num_tokens_padded=num_tokens_padded,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )
            assert not should_ubatch, (
                "DBO ubatching not implemented for extract_hidden_states"
            )

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.dp_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct
                # batch_descriptor
                cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                    num_tokens_padded,
                    valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
                )
                # Assert to make sure the agreed-upon token count is correct;
                # otherwise num_tokens_across_dp will no longer be valid
                assert batch_desc.num_tokens == num_tokens_padded
                num_tokens_across_dp[dp_rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys.

        Only supports PIECEWISE cudagraphs (via mixed_mode).
        Should be called after adjust_cudagraph_sizes_for_spec_decode.
        """
        assert self.vllm_config.speculative_config is not None
        if (
            not self.vllm_config.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
        ):
            proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
        else:
            proposer_cudagraph_mode = CUDAGraphMode.NONE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        assert self.model is not None, "Model must be initialized before dummy_run"
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # Use our own slot mapping buffer during cudagraph capture.
        if (
            self.attn_layer_names
            and slot_mappings is not None
            and self.attn_layer_names[0] in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

    def _build_attn_metadata_builder(
        self, draft_attn_layers: dict[str, AttentionLayerBase]
    ) -> AttentionMetadataBuilder:
        """Build the attention metadata builder from draft attention layers."""
        if not draft_attn_layers:
            raise ValueError("No attention layers found for ExtractHiddenStatesModel")
        layer = next(iter(draft_attn_layers.values()))
        attn_backend = layer.get_attn_backend()
        return attn_backend.get_builder_cls()(
            layer.get_kv_cache_spec(self.vllm_config),
            self.attn_layer_names,
            self.vllm_config,
            self.device,
        )

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare next token IDs for speculative decoding.

        Since num_speculative_tokens == 1, sampled_token_ids has shape
        (batch_size, 1). For each request we either use the sampled token
        (if valid and not discarded) or a backup token from the request state.
        """
        num_reqs = gpu_input_batch.num_reqs
        device = sampled_token_ids.device

        # Compute backup tokens for discarded / invalid requests
        backup_tokens_gpu = torch.tensor(
            [
                requests[gpu_input_batch.req_ids[i]].get_token_id(
                    common_attn_metadata.seq_lens_cpu[i].item()
                )
                for i in range(num_reqs)
            ],
            dtype=torch.int32,
            device=device,
        )

        assert discard_request_mask.dtype == torch.bool

        # With num_speculative_tokens == 1, there is exactly one token
        sampled = sampled_token_ids[:, 0]
        is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
        valid_sampled_tokens_count = is_valid.to(torch.int32)

        use_sampled = is_valid & ~discard_request_mask[:num_reqs]
        next_token_ids = torch.where(
            use_sampled, sampled.to(torch.int32), backup_tokens_gpu
        )

        return next_token_ids, valid_sampled_tokens_count
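
# --- Illustrative toy run of the selection above (all values made up,
# vocab_size assumed to be 32000): invalid or discarded slots fall back to
# the per-request backup token.
sampled = torch.tensor([7, -1, 3])
backup = torch.tensor([100, 101, 102])
discard = torch.tensor([False, False, True])
valid = (sampled >= 0) & (sampled < 32000)
out = torch.where(valid & ~discard, sampled, backup)
assert out.tolist() == [7, 101, 102]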

    def load_model(self, target_model: nn.Module) -> None:
        """Load the ExtractHiddenStatesModel model.

        This method instantiates the ExtractHiddenStatesModel model which is used
        to cache hidden states during speculative decoding. The model uses
        cache-only attention (no computation, just caching KV states).

        Args:
            target_model: The target model (passed for compatibility with
                EagleProposer interface, but not used here)
        """
        # Get the target model's attention layers before loading draft model
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()  # type: ignore[type-abstract]
        )

        assert self.vllm_config.speculative_config is not None
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("extract_hidden_states"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Identify draft model's attention layers (difference from target)
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        draft_attn_layers = {
            name: layer
            for name, layer in all_attn_layers.items()
            if name not in target_attn_layer_names
        }
        self.attn_layer_names = list(draft_attn_layers.keys())
        assert len(draft_attn_layers) == 1, (
            "ExtractHiddenStatesModel should have exactly one "
            f"attention layer, found {len(draft_attn_layers)}"
        )
        self.attn_metadata_builder = self._build_attn_metadata_builder(
            draft_attn_layers
        )

    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """Validate all drafting layers belong to the same KV cache group.

        With exactly one attention layer (asserted in load_model), this is
        trivially satisfied.
        """
        assert len(self.attn_layer_names) == 1

@@ -64,3 +64,45 @@ class SpecDecodeMetadata:
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )


@dataclass
class MultiLayerEagleMetadata:
    # [batch_size]
    cached_len: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_token_ids: torch.Tensor | None = None
    # [batch_size, layer_num, hidden_size]
    cached_hidden_states: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_slot_mappings: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_positions: torch.Tensor | None = None

    @classmethod
    def make_dummy(
        cls,
        layer_num: int,
        hidden_size: int,
        device: torch.device,
    ) -> "MultiLayerEagleMetadata":
        cached_len = torch.zeros((1), dtype=torch.int64, device=device)
        cached_token_ids = torch.zeros(
            (1, layer_num), dtype=torch.int32, device=device
        )
        cached_hidden_states = torch.zeros(
            (1, layer_num, hidden_size), dtype=torch.float32, device=device
        )
        cached_slot_mappings = torch.zeros(
            (1, layer_num), dtype=torch.int64, device=device
        )
        cached_positions = torch.zeros(
            (1, layer_num), dtype=torch.int64, device=device
        )
        return cls(
            cached_len=cached_len,
            cached_token_ids=cached_token_ids,
            cached_hidden_states=cached_hidden_states,
            cached_slot_mappings=cached_slot_mappings,
            cached_positions=cached_positions,
        )
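
# --- Illustrative usage of make_dummy (layer_num/hidden_size/device are
# made-up values): each buffer is a single-request placeholder keyed by the
# drafter's layer count.
meta = MultiLayerEagleMetadata.make_dummy(
    layer_num=3, hidden_size=4096, device=torch.device("cpu")
)
assert meta.cached_hidden_states.shape == (1, 3, 4096)
assert int(meta.cached_len.item()) == 0   # nothing cached yet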

504 vllm/v1/spec_decode/multi_layer_eagle.py Normal file
@@ -0,0 +1,504 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import numpy as np
import torch

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
    CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata

BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128


class MultiLayerEagleProposer(EagleProposer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(vllm_config, device, runner)

        self.layer_num: int = getattr(
            self.speculative_config.draft_model_config.hf_text_config,
            "n_predict", 0
        )
        self.num_speculative_tokens: int = (
            self.speculative_config.num_speculative_tokens
        )

    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        assert multi_layer_eagle_metadata is not None
        if token_indices_to_sample is None:
            token_indices_to_sample = (
                common_attn_metadata.query_start_loc[1:] - 1
            )

        MAX_SHIFT = self.layer_num
        assert MAX_SHIFT > 0

        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping

        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        pos_for_shift = (
            target_positions[0]
            if target_positions.dim() == 2
            else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]

        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)

        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)

        cached_lens = multi_layer_eagle_metadata.cached_len
        shift = torch.minimum(shift, cached_lens)

        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=MAX_SHIFT,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=(
                multi_layer_eagle_metadata.cached_token_ids
            ),
            cached_prev_positions=(
                multi_layer_eagle_metadata.cached_positions
            ),
            cached_prev_hidden_states=(
                multi_layer_eagle_metadata.cached_hidden_states
            ),
            cached_slot_mappings=(
                multi_layer_eagle_metadata.cached_slot_mappings
            ),
            common_attn_metadata=common_attn_metadata,
        )

        return (
            prev_token_ids,
            prev_positions,
            prev_hidden_states,
            common_attn_metadata,
        )
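
# --- Worked example of the shift above (numbers are made up). A window of
# five tokens ends at index 4; we sample at index 2 and the request starts
# at position 10, so shift = min(end - sample_idx, start_pos) = min(2, 10)
# = 2: the window moves right by two and the two evicted slots are refilled
# from the per-request cache (subject to the later min with cached_len).
end_idx, sample_idx, start_pos = 4, 2, 10
shift_val = max(0, min(end_idx - sample_idx, start_pos))
assert shift_val == 2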

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        raise Exception(
            "speculative_config.disable_padded_drafter_batch"
            " is not supported now for MultiLayerEagleProposer."
        )

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1],
                dtype=torch.int32,
                device=self.device,
            ),
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device="cpu",
                ),
                key_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=self.num_speculative_tokens + 1,
                max_seq_len=self.max_model_len,
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        self.adjust_input(**adjust_input_kwargs)

        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.draft_vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }

                self.model(**model_kwargs)


def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    if batch_size == 0:
        return

    assert max_shift > 0
    assert cached_prev_positions.is_contiguous()
    assert cached_prev_token_ids.is_contiguous()
    assert cached_prev_hidden_states.is_contiguous()
    assert cached_slot_mappings.is_contiguous()
    assert src_hidden_states.is_contiguous()
    assert dst_hidden_states.is_contiguous()

    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()

    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )

    max_window_len = int(
        (
            common_attn_metadata.query_start_loc_cpu[1:]
            - common_attn_metadata.query_start_loc_cpu[:-1]
        )
        .max()
        .item()
    )
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_token_ids,
        dst_token_ids,
        cached_prev_token_ids,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_slot_mapping,
        dst_slot_mapping,
        cached_slot_mappings,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_positions,
        dst_positions,
        cached_prev_positions,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    hidden_size = int(dst_hidden_states.shape[1])
    num_hidden_blocks = max(
        1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
    )

    _shift_and_gather_hidden_kernel[
        (batch_size, num_blocks, num_hidden_blocks)
    ](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )

    cached_lens.copy_(store_lens)
    return


@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
    # slot mappings, ...).
    #
    # For a single sequence (0-based index i within its window):
    #   - Prefix (i < shift):
    #       dst[start + i] = cached[cached_len - shift + i]
    #   - Body (i >= shift):
    #       dst[start + i] = src[start + i - shift]
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)

    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)

    assert cached_len >= shift

    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    offs = base + k
    dst_idx = start + offs

    window_len = end - start + 1
    mask = offs < window_len

    base_cached = cached_ptr + pid_seq * MAX_SHIFT
    cached_idx = cached_len - shift + offs
    cached_mask = offs < shift
    val_cached = tl.load(
        base_cached + cached_idx, mask=mask & cached_mask, other=0
    )

    src_idx = start + offs - shift
    val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)

    val = tl.where(cached_mask, val_cached, val_src)
    tl.store(dst_ptr + dst_idx, val, mask=mask)

    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    store_mask = m < MAX_SHIFT
    dst_idx = store_start + m
    val = tl.load(
        dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0
    )
    tl.store(base_cached + m, val, mask=store_mask)
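
# --- Plain-PyTorch reference for the kernel's per-sequence semantics
# (a sketch for a single sequence; not used by the diff): the first `shift`
# outputs come from the cache tail, the rest from the shifted source, and
# the last `store_len` outputs refill the cache.
def shift_and_cache_1d_ref(src, cached, start, end, shift, cached_len,
                           store_start, store_len):
    dst = src.clone()
    for i in range(end - start + 1):
        if i < shift:
            dst[start + i] = cached[cached_len - shift + i]
        else:
            dst[start + i] = src[start + i - shift]
    cached[:store_len] = dst[store_start : store_start + store_len]
    return dst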


@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    # Layout:
    #   - src_ptr / dst_ptr: [num_tokens, hidden_size]
    #   - cached_ptr: [batch_size, MAX_SHIFT, hidden_size]
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    pid_hid = tl.program_id(2)

    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)

    assert cached_len >= shift

    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    tok_offs = base + k
    dst_tok = start + tok_offs
    n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
    dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1

    window_len = end - start + 1
    tok_mask = tok_offs < window_len
    n_mask = n < HIDDEN_SIZE
    mask = tok_mask[:, None] & n_mask[None, :]

    base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
    cached_tok = cached_len - shift + tok_offs
    cached_ptrs = (
        base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    cached_mask = tok_offs < shift
    val_cached = tl.load(
        cached_ptrs, mask=mask & cached_mask[:, None], other=0
    )

    src_tok = start + tok_offs - shift
    src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)

    val = tl.where(cached_mask[:, None], val_cached, val_src)
    tl.store(dst_ptrs, val, mask=mask)

    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    m_mask = (m < MAX_SHIFT) & (m < store_len)
    store_tok = store_start + m
    dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    store_ptrs = (
        base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    mask = m_mask[:, None] & n_mask[None, :]
    val = tl.load(dst_ptrs, mask=mask, other=0)
    tl.store(store_ptrs, val, mask=mask)
@@ -157,12 +157,23 @@ def create_vllm_config_for_draft_model(
    quantized differently, and has potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the drafter.
    The vllm_config is useful when loading the draft model with get_model().

    This helper returns the original target config for the common case and only
    rewrites rank/parallel info when the drafter is configured to run locally
    on the last target PP stage. This keeps runtime behavior unchanged for the
    common case while still handling PP rank remapping.
    """
    old = target_model_vllm_config
    assert old.speculative_config is not None, "speculative_config is not set"
    old_spec_config = old.speculative_config
    needs_rank_remap = old_spec_config.needs_partial_pp_draft_remap(old.parallel_config)
    if not needs_rank_remap:
        return old

    draft_rank = old_spec_config.resolve_partial_pp_draft_rank(old.parallel_config)

    new_parallel_config = replace(
        old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
        old_spec_config.draft_parallel_config, rank=draft_rank
    )
    new: VllmConfig = replace(
        old,
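
# --- Sketch of the remap above using dataclasses.replace (the tiny config
# type here is hypothetical): only the drafter's rank is rewritten, the rest
# of the parallel config is shared with the target.
from dataclasses import dataclass, replace as dc_replace

@dataclass(frozen=True)
class _TinyParallelConfig:
    rank: int
    world_size: int

_cfg = _TinyParallelConfig(rank=3, world_size=4)
_draft_cfg = dc_replace(_cfg, rank=0)   # same world_size, remapped rank
assert _draft_cfg.world_size == _cfg.world_size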

@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
            v.gpu = v.cpu

    @instrument(span_name="Loading (CPU)")
    def load_model(self, eep_scale_up: bool = False) -> None:
    def load_model(self, load_dummy_weights: bool = False) -> None:
        if load_dummy_weights:
            raise ValueError(
                "Loading dummy weights (needed for elastic EP scale-up) "
                "is not supported by the CPU Model Runner."
            )
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)


@@ -85,7 +85,7 @@ class CPUWorker(Worker):
        self.local_omp_cpuid = omp_cpuids_list[self.rank]

        if self.local_omp_cpuid != "nobind":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

@@ -118,11 +118,12 @@ class CPUWorker(Worker):
    def determine_available_memory(self) -> int:
        return self.cache_config.cpu_kvcache_space_bytes or 0

    def compile_or_warm_up_model(self) -> None:
    def compile_or_warm_up_model(self) -> float:
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()
        return self.compilation_config.compilation_time

    def _get_autobind_cpu_ids(
        self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]

@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):

def _run_ar(
    should_ubatch: bool,
    should_dp_pad: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    cudagraph_mode: int,
@@ -46,12 +45,11 @@ def _run_ar(
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank
    device, group = _get_device_and_group(parallel_config)
    tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
    tensor[0][dp_rank] = orig_num_tokens_per_ubatch
    tensor[1][dp_rank] = padded_num_tokens_per_ubatch
    tensor[2][dp_rank] = 1 if should_ubatch else 0
    tensor[3][dp_rank] = 1 if should_dp_pad else 0
    tensor[4][dp_rank] = cudagraph_mode
    tensor[3][dp_rank] = cudagraph_mode
    dist.all_reduce(tensor, group=group)
    return tensor
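
# --- Sketch of the 4 x dp_size exchange above without a process group
# (values are made up). Each rank writes its own zero-initialized column;
# summing the contributions reproduces what all_reduce(SUM) yields, so any
# rank can read every rank's values locally.
dp_size = 2
per_rank = [(7, 8, 0, 1), (5, 8, 0, 2)]   # (orig, padded, ubatch, cg_mode)
contribs = []
for rank, vals in enumerate(per_rank):
    t = torch.zeros(4, dp_size, dtype=torch.int32)
    for row, v in enumerate(vals):
        t[row, rank] = v
    contribs.append(t)
summed = torch.stack(contribs).sum(0)      # == all_reduce(SUM) result
assert int(summed[3].min()) == 1           # synced cudagraph mode (min)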

@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
    If any rank has NONE (0), all ranks use NONE.
    This ensures all ranks send consistent values (all padded or all unpadded).
    """
    return int(tensor[4, :].min().item())
    return int(tensor[3, :].min().item())


def _synchronize_dp_ranks(
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    should_attempt_ubatching: bool,
    should_attempt_dp_padding: bool,
    cudagraph_mode: int,
    parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
    run with microbatching or none of them do.

    2. Determines the total number of tokens that each rank will run.
    When running microbatched or if should_attempt_dp_padding is True, all
    ranks will be padded out so that the run with the same number of tokens
    When running microbatched or if cudagraph is enabled (synced across ranks),
    all ranks will be padded out so that they run with the same number of tokens.

    3. Synchronizes cudagraph_mode across ranks by taking the minimum.

@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
    # will run and if we are using ubatching or not.
    tensor = _run_ar(
        should_ubatch=should_attempt_ubatching,
        should_dp_pad=should_attempt_dp_padding,
        orig_num_tokens_per_ubatch=num_tokens_unpadded,
        padded_num_tokens_per_ubatch=num_tokens_padded,
        cudagraph_mode=cudagraph_mode,
        parallel_config=parallel_config,
    )

    should_dp_pad = bool(torch.all(tensor[3] == 1).item())

    # DP ranks should all have the same value for should_attempt_dp_padding.
    assert should_attempt_dp_padding == should_dp_pad
    # Synchronize cudagraph_mode across ranks first (take min).
    # This is needed before the DP padding decision since we use the synced
    # cudagraph mode to determine whether DP padding is needed.
    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

    # Check conditions for microbatching
    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

    if should_ubatch and not should_dp_pad:
        logger.debug_once(
            "Microbatching has been triggered and requires DP padding. "
            "Enabling DP padding even though it has been explicitly "
            "disabled.",
            scope="global",
        )
        should_dp_pad = True
    # DP padding is needed when cudagraph is enabled (synced across ranks)
    # or when ubatching/DBO is active (ubatching requires uniform batch
    # sizes across DP ranks currently).
    # Use the synced runtime cudagraph mode rather than the compilation config
    # so we can avoid padding when cudagraph is not enabled for this step.
    should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch

    # Pad all DP ranks up to the maximum token count across ranks if
    # should_dp_pad is True
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
        should_dp_pad,
    )

    # Synchronize cudagraph_mode across ranks (take min)
    synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

    return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode


def coordinate_batch_across_dp(
    num_tokens_unpadded: int,
    allow_microbatching: bool,
    allow_dp_padding: bool,
    parallel_config: ParallelConfig,
    num_tokens_padded: int | None = None,
    uniform_decode: bool | None = None,
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
    Args:
        num_tokens_unpadded: Number of tokens without accounting for padding
        allow_microbatching: If microbatching should be attempted
        allow_dp_padding: If all DP ranks should be padded up to the same value
        parallel_config: The parallel config
        num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
            TP, etc)
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
            only contains single token decodes
        num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
            number of tokens per request.
        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
        cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
            DP padding is enabled when synced cudagraph mode across ranks is not NONE.

    Returns: tuple[
        ubatch_slices: if this is set then all DP ranks have agreed to
            microbatch
        num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            padded up to the max value across all DP ranks when allow_dp_padding
            is True.
            padded up to the max value across all DP ranks when cudagraph is enabled.
        synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
    ]

@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
        num_tokens_unpadded,
        num_tokens_padded,
        should_attempt_ubatching,
        allow_dp_padding,
        cudagraph_mode,
        parallel_config,
    )

@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
        return self.model_runner_output


class AsyncPoolingOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        pooler_output: torch.Tensor,
        is_valid: torch.Tensor | None,
        main_stream: torch.cuda.Stream,
        copy_stream: torch.cuda.Stream,
        copy_event: torch.cuda.Event,
    ):
        self.model_runner_output = model_runner_output
        self.pooler_output = pooler_output
        self.is_valid = is_valid
        self.copy_event = copy_event

        with stream(copy_stream, main_stream):
            copy_stream.wait_stream(main_stream)
            self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
            if self.is_valid is not None:
                self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
            else:
                self.is_valid_cpu = None
            self.copy_event.record(copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()
        # unbind() returns an immutable tuple; convert to a list so invalid
        # entries can be masked out below.
        pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
        if self.is_valid_cpu is not None:
            is_valid_cpu = self.is_valid_cpu.tolist()
            for i, is_valid in enumerate(is_valid_cpu):
                if not is_valid:
                    pooler_output[i] = None
        self.model_runner_output.pooler_output = pooler_output
        return self.model_runner_output
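
# --- Sketch of the stream/event pattern above (assumes a CUDA device; the
# tensors are illustrative): the copy stream waits on the main stream,
# issues non_blocking D2H copies, records an event, and the consumer blocks
# on that event before touching the CPU tensors.
main = torch.cuda.current_stream()
side = torch.cuda.Stream()
done = torch.cuda.Event()
gpu_t = torch.ones(4, device="cuda")
with torch.cuda.stream(side):
    side.wait_stream(main)
    cpu_t = gpu_t.to("cpu", non_blocking=True)
    done.record(side)
done.synchronize()   # cpu_t is now safe to read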


def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    return x.to("cpu", non_blocking=True).numpy()


@@ -119,6 +119,10 @@ class BlockTables:
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
@@ -150,7 +154,14 @@ class BlockTables:
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
        # This is because the padding logic is complex and kernels may access beyond
        # the requested range.
        self.slot_mappings.fill_(PAD_SLOT_ID)
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return self.slot_mappings[:, :num_tokens]


@@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup


@@ -29,13 +27,11 @@ class CudaGraphManager:
    def __init__(
        self,
        vllm_config: VllmConfig,
        uses_mrope: bool,
        use_aux_hidden_state_outputs: bool,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.uses_mrope = uses_mrope
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
        self.device = device

@@ -88,9 +84,8 @@ class CudaGraphManager:
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
@@ -113,24 +108,23 @@ class CudaGraphManager:
            )
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)
        input_ids = input_buffers.input_ids[:num_tokens]
        positions = input_buffers.positions[:num_tokens]
        if self.uses_mrope:
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[:num_tokens]

        model_inputs = {
            "input_ids": input_buffers.input_ids[:num_tokens],
            "positions": input_buffers.positions[:num_tokens],
            # NOTE: Values returned by `prepare_dummy_inputs` will override the
            # default values above.
            **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
        }

        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=(
                self.uniform_decode_query_len if uniform_decode else 0
            ),
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -143,11 +137,7 @@ class CudaGraphManager:
            num_tokens_across_dp=num_tokens_across_dp,
            slot_mapping=slot_mappings,
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model_output = model(**model_inputs)
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
@@ -164,9 +154,7 @@ class CudaGraphManager:
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            model=model,
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            model_inputs=model_inputs,
            num_tokens_across_dp=num_tokens_across_dp,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings,
@@ -178,9 +166,7 @@ class CudaGraphManager:
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
@@ -206,11 +192,8 @@ class CudaGraphManager:
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            model_output = model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model_output = model(**model_inputs)

        # Join offloader's copy stream after forward to avoid unjoined
        # stream error. The last layer's start_prefetch forks copy_stream,
        # but wait_prefetch only happens in the next forward pass.
@@ -235,9 +218,7 @@ class CudaGraphManager:
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        model_inputs: dict[str, torch.Tensor | None],
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
@@ -256,19 +237,14 @@ class CudaGraphManager:
            batch_descriptor=batch_descriptor,
            slot_mapping=slot_mappings,
        ):
            model(
                input_ids=input_ids,
                positions=positions,
                inputs_embeds=inputs_embeds,
            )
            model(**model_inputs)

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        mrope_positions: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
@@ -278,9 +254,8 @@ class CudaGraphManager:
            device=self.device,
            capture_fn=self.capture_graph,
            model=model,
            model_state=model_state,
            input_buffers=input_buffers,
            mrope_positions=mrope_positions,
            inputs_embeds=inputs_embeds,
            block_tables=block_tables,
            attn_groups=attn_groups,
            kv_cache_config=kv_cache_config,
@@ -412,51 +387,36 @@ def capture_graphs(
def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
    uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=num_tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
    # HACK(woosuk): Special handling for DCP.
    if block_tables.cp_size > 1:
        prepare_dcp_local_seq_lens(
            input_buffers.dcp_local_seq_lens,
            input_batch.seq_lens,
            num_reqs,
            block_tables.cp_size,
            block_tables.cp_rank,
            block_tables.cp_interleave,
        )
        input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]

    attn_metadata = model_state.prepare_attn(
        input_batch,
        input_block_tables,
        slot_mappings,
        attn_groups,
        kv_cache_config,
    )
    return attn_metadata, slot_mappings_by_layer
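
As a quick sanity check of the query_start_loc construction above: for a uniform-decode capture batch the offsets are evenly spaced, with the final entry clamped to num_tokens. A standalone sketch (the sizes are illustrative):

    import numpy as np

    # E.g. capturing 4 requests with 2 decode tokens each (num_tokens = 8).
    num_reqs, num_tokens, num_tokens_per_req = 4, 8, 2
    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
    print(query_start_loc_np)  # [0 2 4 6 8]
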
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
@@ -60,20 +59,13 @@ class InputBatch:
    query_start_loc_np: np.ndarray
    # [num_reqs]
    seq_lens: torch.Tensor
    # [num_reqs]
    dcp_local_seq_lens: torch.Tensor | None

    # [num_tokens_after_padding]
    input_ids: torch.Tensor
    # [num_tokens_after_padding]
    positions: torch.Tensor
    # [3, num_tokens_after_padding]
    mrope_positions: torch.Tensor | None
    # [num_tokens_after_padding, hidden_size]
    inputs_embeds: torch.Tensor | None

    # layer_name -> Metadata
    attn_metadata: dict[str, Any]
    # layer_name -> slot_mapping
    slot_mappings: dict[str, torch.Tensor]

    # [total_num_logits]
    logits_indices: torch.Tensor
@@ -90,14 +82,16 @@ class InputBatch:
        num_reqs: int,
        num_tokens: int,
        input_buffers: InputBuffers,
        device: torch.device,
    ) -> "InputBatch":
        assert 0 < num_reqs <= num_tokens
        device = input_buffers.device

        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
        expanded_idx_mapping = idx_mapping
        expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)

        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        assert int(num_scheduled_tokens.sum()) == num_tokens
@@ -123,7 +117,6 @@ class InputBatch:
        input_ids = input_buffers.input_ids[:num_tokens].zero_()
        positions = input_buffers.positions[:num_tokens].zero_()

        # attn_metadata = defaultdict(lambda: None)
        logits_indices = query_start_loc[1:] - 1
        cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
        cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
@@ -141,12 +134,9 @@ class InputBatch:
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=None,
            input_ids=input_ids,
            positions=positions,
            mrope_positions=None,
            inputs_embeds=None,
            attn_metadata=None,  # type: ignore
            slot_mappings=None,  # type: ignore
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
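
For reference, make_dummy now takes the device from the buffers instead of an explicit argument; a call-site sketch (the buffer object is assumed to exist already):

    # Hypothetical call site: dummy tensors land on input_buffers.device.
    input_batch = InputBatch.make_dummy(
        num_reqs=8,
        num_tokens=64,
        input_buffers=input_buffers,
    )
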
@@ -507,6 +497,38 @@ def post_update(
    )


@triton.jit
def _post_update_pool_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    batch_id = tl.program_id(0)
    query_start = tl.load(query_start_loc_ptr + batch_id)
    query_end = tl.load(query_start_loc_ptr + batch_id + 1)
    query_len = query_end - query_start

    req_state_idx = tl.load(idx_mapping_ptr + batch_id)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)


def post_update_pool(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
) -> None:
    num_reqs = idx_mapping.shape[0]
    _post_update_pool_kernel[(num_reqs,)](
        idx_mapping,
        num_computed_tokens,
        query_start_loc,
    )

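The Triton kernel above is equivalent to this eager PyTorch update (a minimal sketch for clarity, not the code path used at runtime):

    import torch

    def post_update_pool_eager(
        idx_mapping: torch.Tensor,          # [num_reqs]
        num_computed_tokens: torch.Tensor,  # [max_num_reqs]
        query_start_loc: torch.Tensor,      # [num_reqs + 1]
    ) -> None:
        # Per-request query lengths, scatter-added into the per-request state.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        num_computed_tokens.index_add_(0, idx_mapping.long(), query_lens)
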
@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,

40  vllm/v1/worker/gpu/mm/encoder_cache.py  Normal file
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.multimodal.inputs import MultiModalFeatureSpec


class EncoderCache:
    def __init__(self):
        # req_id -> MM features
        self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
        # MM hash -> encoder outputs
        self.encoder_outputs: dict[str, torch.Tensor] = {}

    def add_request(
        self, req_id: str, mm_features: list[MultiModalFeatureSpec]
    ) -> None:
        self.mm_features[req_id] = mm_features

    def remove_request(self, req_id: str) -> None:
        self.mm_features.pop(req_id, None)

    def reset_mm_cache(self) -> None:
        """
        Clear the multi-modal cache that was used during profiling
        but is no longer needed during inference.
        """
        # TODO: Implement MM budget for encoder dummy run
        pass

    def reset_encoder_cache(self) -> None:
        """Clear the GPU-side encoder cache storing vision embeddings.

        This should be called when model weights are updated to ensure
        stale embeddings computed with old weights are not reused.
        """
        self.encoder_outputs.clear()

    def free_encoder_cache(self, mm_hash: str) -> None:
        self.encoder_outputs.pop(mm_hash, None)
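
A lifecycle sketch for the cache (the request ID and hash are illustrative):

    import torch

    cache = EncoderCache()
    cache.add_request("req_0", mm_features=[])            # track a request's MM features
    cache.encoder_outputs["hash_a"] = torch.zeros(4, 16)  # cache one encoder output
    cache.free_encoder_cache("hash_a")                    # drop a single cached output
    cache.remove_request("req_0")                         # forget the request
    cache.reset_encoder_cache()                           # e.g. after a weight update
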
@@ -4,54 +4,32 @@ import numpy as np
import torch

from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs


class EncoderRunner:
    def __init__(
        self,
        model: SupportsMultiModal,
        max_num_tokens: int,
        hidden_size: int,
        encoder_cache: EncoderCache,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.model = model
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.encoder_cache = encoder_cache
        self.dtype = dtype
        self.device = device

        self.inputs_embeds = torch.zeros(
            max_num_tokens, hidden_size, dtype=dtype, device=device
        )
        self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
        self.encoder_cache: dict[str, torch.Tensor] = {}

    def reset_mm_cache(self) -> None:
        """
        Clear the multi-modal cache that was used during profiling
        but is no longer needed during inference.
        """
        # TODO: Implement MM budget for encoder dummy run
        pass

    def reset_encoder_cache(self) -> None:
        """Clear the GPU-side encoder cache storing vision embeddings.

        This should be called when model weights are updated to ensure
        stale embeddings computed with old weights are not reused.
        """
        self.encoder_cache.clear()

    def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
        self.req_id_to_mm_features[req_id] = mm_features

    def free_encoder_cache(self, mm_hash: str) -> None:
        self.encoder_cache.pop(mm_hash, None)

    def remove_request(self, req_id: str) -> None:
        self.req_id_to_mm_features.pop(req_id, None)

    def prepare_mm_inputs(
        self, scheduled_encoder_inputs: dict[str, list[int]]
@@ -59,7 +37,7 @@ class EncoderRunner:
        mm_hashes: list[str] = []
        mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            mm_features = self.req_id_to_mm_features[req_id]
            mm_features = self.encoder_cache.mm_features[req_id]
            for mm_input_id in encoder_input_ids:
                mm_feature = mm_features[mm_input_id]
                if mm_feature.data is None:
@@ -72,25 +50,17 @@ class EncoderRunner:
    @torch.inference_mode()
    def execute_mm_encoder(
        self,
        model: SupportsMultiModal,
        mm_hashes: list[str],
        mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
    ) -> list[torch.Tensor]:
        if not mm_hashes:
            return []

        encoder_outputs: list[torch.Tensor] = []
        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs, device=self.device, pin_memory=False
        ):
            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
            curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
            sanity_check_mm_encoder_outputs(
                curr_group_outputs, expected_num_items=num_items
            )
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs by mm_hash
        self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
        return encoder_outputs

    def gather_mm_embeddings(
@@ -122,7 +92,7 @@ class EncoderRunner:
                # OPTIMIZATION: Skip decode requests.
                continue

            mm_features = self.req_id_to_mm_features[req_id]
            mm_features = self.encoder_cache.mm_features[req_id]
            for mm_feature in mm_features:
                pos_info = mm_feature.mm_position
                start_pos = pos_info.offset
@@ -148,7 +118,7 @@ class EncoderRunner:
                    continue

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash, None)
                encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
@@ -170,12 +140,11 @@ class EncoderRunner:
    @torch.inference_mode()
    def get_inputs_embeds(
        self,
        model: SupportsMultiModal,
        input_ids: torch.Tensor,
        mm_embeds: list[torch.Tensor],
        is_mm_embed: torch.Tensor,
    ) -> torch.Tensor:
        x = model.embed_input_ids(
        x = self.model.embed_input_ids(
            input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
        )
        # Copy to the pre-allocated buffer for CUDA graphs.
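
Conceptually, embed_input_ids with is_multimodal splices encoder embeddings into the text-token embeddings. A standalone sketch of that merge (the in-order mask semantics are an assumption of this sketch, not the model API):

    import torch

    def splice_mm_embeds(
        text_embeds: torch.Tensor,  # [num_tokens, hidden]
        mm_embeds: torch.Tensor,    # [num_mm_tokens, hidden]
        is_mm_embed: torch.Tensor,  # [num_tokens], bool
    ) -> torch.Tensor:
        out = text_embeds.clone()
        # Positions flagged as multimodal take rows from mm_embeds, in order.
        out[is_mm_embed] = mm_embeds
        return out
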
@@ -38,15 +38,16 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    build_slot_mappings_by_layer,
    get_kv_cache_spec,
    init_attn_backend,
@@ -56,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import (
    get_cudagraph_and_dp_padding,
    make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
@@ -67,6 +65,7 @@ from vllm.v1.worker.gpu.input_batch import (
    expand_idx_mapping,
    get_num_sampled_and_rejected,
    post_update,
    post_update_pool,
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
)
@@ -76,8 +75,9 @@ from vllm.v1.worker.gpu.kv_connector import (
    get_kv_connector,
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import init_model_state
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -120,34 +120,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            self.cache_config.cache_dtype
        ]
        self.is_pooling_model = False

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()

        # Multimodal
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        if self.supports_mm_inputs:
            self.encoder_runner = EncoderRunner(
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                dtype=self.dtype,
                device=self.device,
            )
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_states = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -169,6 +146,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size

        # Multimodal
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        self.encoder_cache = None
        if self.supports_mm_inputs and self.is_first_pp_rank:
            self.encoder_cache = EncoderCache()

        self.speculator = None
        self.num_speculative_steps = 0
        self.use_aux_hidden_state_outputs = False
@@ -212,7 +198,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config,
            self.uses_mrope,
            self.use_aux_hidden_state_outputs,
            self.device,
        )
@@ -227,13 +212,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # KV Connector if configured.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

        # Pooling models.
        self.is_pooling_model = self.model_config.runner_type == "pooling"
        self.pooling_runner: PoolingRunner | None = None

        # For transferring state from execute_model to the subsequent sample_tokens call.
        self.execute_model_state: tuple | None = None

    def update_max_model_len(self, max_model_len: int) -> None:
        self.max_model_len = max_model_len
        self.req_states.max_model_len = max_model_len

    @staticmethod
    def get_supported_tasks() -> tuple[str]:
        return ("generate",)
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        tasks: list[SupportedTask] = []
        if self.model_config.runner_type == "generate":
            tasks.append("generate")
        if self.pooling_runner is not None:
            tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
        return tuple(tasks)
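
For example, a plain text-generation model now reports ("generate",), while a pooling model whose PoolingRunner supports embeddings reports ("embed",).
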
    def load_model(self, *args, **kwargs) -> None:
        time_before_load = time.perf_counter()
@@ -266,7 +262,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        prepare_communication_buffer_for_model(self.model)
        if self.speculator is not None:
            prepare_communication_buffer_for_model(self.speculator)
            prepare_communication_buffer_for_model(self.speculator.model)

        # Initialize the components that require the model.
        self.model_state = init_model_state(
            self.vllm_config, self.model, self.encoder_cache, self.device
        )
        if self.is_pooling_model:
            self.pooling_runner = PoolingRunner(self.model)

    def get_model(self) -> nn.Module:
        return self.model
@@ -305,6 +308,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.speculator is not None:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.model_state,
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
@@ -320,39 +324,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
        input_batch.attn_metadata = attn_metadata
        input_batch.slot_mappings = slot_mappings_by_layer

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
        self,
        num_tokens: int,
        *args,
        skip_attn: bool = True,
        uniform_decode: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        if uniform_decode:
            # Align tokens to uniform_decode_query_len for cudagraph
            # compatibility across DP ranks.
            query_len = self.cudagraph_manager.uniform_decode_query_len
            num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
            num_tokens = num_reqs * query_len
            num_tokens_per_request = [query_len] * num_reqs
        else:
            num_reqs = min(num_tokens, self.max_num_reqs)
            num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
            num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
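
A worked example of the uniform-decode rounding above, assuming uniform_decode_query_len = 4 and a large max_num_reqs:

    from math import ceil

    num_tokens, query_len = 10, 4
    num_reqs = ceil(num_tokens / query_len)          # cdiv(10, 4) == 3
    num_tokens = num_reqs * query_len                # padded up to 12
    num_tokens_per_request = [query_len] * num_reqs  # [4, 4, 4]
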
@@ -387,7 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            return None, None

        assert self.execute_model_state is not None
        hidden_states, _, input_batch, _ = self.execute_model_state
        (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
            num_tokens_across_dp,
        ) = self.execute_model_state
        self.execute_model_state = None

        # Dummy-run the eagle speculator's propose() to ensure DP/EP sync.
        if self.speculator is not None:
            self.speculator.propose(
                input_batch=input_batch,
                attn_metadata=attn_metadata,
                slot_mappings=slot_mappings_by_layer,
                last_hidden_states=hidden_states,
                aux_hidden_states=aux_hidden_states,
                num_sampled=torch.ones(
                    input_batch.num_reqs, dtype=torch.int32, device=self.device
                ),
                num_rejected=torch.zeros(
                    input_batch.num_reqs, dtype=torch.int32, device=self.device
                ),
                last_sampled=self.req_states.last_sampled_tokens,
                next_prefill_tokens=self.req_states.next_prefill_tokens,
                temperature=self.sampler.sampling_states.temperature.gpu,
                seeds=self.sampler.sampling_states.seeds.gpu,
                num_tokens_across_dp=num_tokens_across_dp,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )

        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states
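
The eight-element execute_model_state tuple is unpacked positionally at several sites; a NamedTuple would make the handoff self-documenting. A sketch of that alternative (purely a suggestion, not part of this change):

    from typing import Any, NamedTuple

    class ExecuteModelState(NamedTuple):
        input_batch: Any
        model_inputs: dict[str, Any]
        attn_metadata: dict[str, Any] | None
        slot_mappings_by_layer: dict[str, Any] | None
        hidden_states: Any
        aux_hidden_states: Any | None
        kv_connector_output: Any
        num_tokens_across_dp: Any | None
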
@@ -416,39 +442,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            expanded_local_pos,
        )

    @torch.inference_mode()
    def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        assert self.pooling_runner is not None
        self.pooling_runner.dummy_pooler_run(hidden_states)

    @torch.inference_mode()
    def profile_run(self) -> None:
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Only run sampler on last PP rank (non-last ranks return None).
        # Only run sampler/pooler on last PP rank (non-last ranks return None).
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            self._dummy_sampler_run(sample_hidden_states)

            if self.speculator is not None:
                num_tokens_across_dp = make_num_tokens_across_dp(
                    self.parallel_config.data_parallel_size, self.max_num_tokens
                )
                self.speculator.run_model(
                    self.max_num_tokens,
                    attn_metadata=None,
                    slot_mappings=None,
                    num_tokens_across_dp=num_tokens_across_dp,
                )
            if self.pooling_runner is None:
                self._dummy_sampler_run(sample_hidden_states)
            else:
                self._dummy_pooler_run(hidden_states)

        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        if self.supports_mm_inputs:
            self.encoder_runner.reset_mm_cache()
        if self.encoder_cache is not None:
            self.encoder_cache.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        if self.supports_mm_inputs:
            self.encoder_runner.reset_encoder_cache()
        if self.encoder_cache is not None:
            self.encoder_cache.reset_encoder_cache()

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        # SP is not supported yet.
@@ -477,17 +500,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            mrope_positions = None
            if self.uses_mrope:
                mrope_positions = self.mrope_states.mrope_positions
            inputs_embeds = None
            if self.supports_mm_inputs:
                inputs_embeds = self.encoder_runner.inputs_embeds
            self.cudagraph_manager.capture(
                model=self.model,
                model_state=self.model_state,
                input_buffers=self.input_buffers,
                mrope_positions=mrope_positions,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
                kv_cache_config=self.kv_cache_config,
@@ -522,21 +538,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        finished_req_ids = finished_req_ids.union(preempted_req_ids)
        for req_id in finished_req_ids:
            self.req_states.remove_request(req_id)
            if self.supports_mm_inputs:
                self.encoder_runner.remove_request(req_id)
            if self.encoder_cache is not None:
                self.encoder_cache.remove_request(req_id)
            self.prompt_logprobs_worker.remove_request(req_id)
            self.lora_state.remove_request(req_id)

    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        if self.supports_mm_inputs:
        if self.encoder_cache is not None:
            for mm_hash in scheduler_output.free_encoder_mm_hashes:
                self.encoder_runner.free_encoder_cache(mm_hash)
                self.encoder_cache.free_encoder_cache(mm_hash)

    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            assert new_req_data.sampling_params is not None
            req_id = new_req_data.req_id
            prompt_len = len(new_req_data.prompt_token_ids)
            self.req_states.add_request(
@@ -547,34 +562,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            )
            req_index = self.req_states.req_id_to_index[req_id]

            if self.supports_mm_inputs:
                self.encoder_runner.add_request(req_id, new_req_data.mm_features)

            # Pre-compute M-RoPE positions for prefill.
            if self.uses_mrope:
                self.mrope_states.init_prefill_mrope_positions(
                    req_index,
                    self.model,  # type: ignore
                    new_req_data.prefill_token_ids,
                    mm_features=new_req_data.mm_features,
                )
            if self.encoder_cache is not None:
                self.encoder_cache.add_request(req_id, new_req_data.mm_features)

            self.model_state.add_request(req_index, new_req_data)
            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
            self.sampler.add_request(
                req_index, prompt_len, new_req_data.sampling_params
            )
            self.prompt_logprobs_worker.add_request(
                req_id, req_index, new_req_data.sampling_params
            )
            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

            if new_req_data.sampling_params is not None:
                self.sampler.add_request(
                    req_index, prompt_len, new_req_data.sampling_params
                )
                self.prompt_logprobs_worker.add_request(
                    req_id, req_index, new_req_data.sampling_params
                )

        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
            self.sampler.apply_staged_writes()
            if self.uses_mrope:
                self.mrope_states.apply_staged_writes()
            self.model_state.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        # Add new blocks for the existing requests.
@@ -637,9 +645,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            idx_mapping, total_num_logits, cu_num_logits, max_expand_len
        )

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        # Get query_start_loc.
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
@@ -648,11 +653,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        query_start_loc_np[num_reqs + 1 :] = num_tokens
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
        max_query_len = num_scheduled_tokens.max().item()

        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
@@ -676,6 +678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

        dcp_local_seq_lens = None
        if self.use_dcp:
            # Prepare dcp local seq_lens.
            prepare_dcp_local_seq_lens(
@@ -686,16 +689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                self.dcp_rank,
                self.cp_interleave,
            )
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]

        # Prepare M-RoPE positions.
        if self.uses_mrope:
            self.mrope_states.prepare_mrope_positions(
                idx_mapping,
                query_start_loc,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
@@ -711,39 +705,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            total_num_logits,
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            idx_mapping,
            query_start_loc,
            self.input_buffers.positions[:num_tokens],
        )
        # Layer name -> slot mapping.
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=self.input_buffers.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=dcp_local_seq_lens,
        )

        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
        positions = self.input_buffers.positions[:num_tokens_after_padding]
        mrope_positions = None
        if self.uses_mrope:
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
@@ -758,37 +719,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            input_ids=input_ids,
            positions=positions,
            mrope_positions=mrope_positions,
            inputs_embeds=None,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings_by_layer,
            dcp_local_seq_lens=dcp_local_seq_lens,
            input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
            positions=self.input_buffers.positions[:num_tokens_after_padding],
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
        )

    @torch.inference_mode()
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
    def prepare_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            input_batch.positions,
        )
        self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            self.req_states.prefill_len.np[input_batch.idx_mapping_np],
            self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
        return block_tables, slot_mappings

    def prepare_dummy_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        return mm_embeds, is_mm_embed
        return block_tables, slot_mappings
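
Taken together, the two helpers split attention preparation in half: the runner gathers block tables and computes slot mappings, and the model state turns them into per-layer metadata. A call-flow sketch (the runner variable is illustrative):

    # Real batch:
    block_tables, slot_mappings = runner.prepare_attn(input_batch)
    # A dummy batch (profiling / cudagraph capture) would instead use:
    #     runner.prepare_dummy_attn(input_batch)

    attn_metadata = runner.model_state.prepare_attn(
        input_batch,
        block_tables,
        slot_mappings,
        runner.attn_groups,
        runner.kv_cache_config,
    )
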
    def sample(
        self,
@@ -926,6 +886,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            input_batch = self.prepare_inputs(
                scheduler_output, num_tokens_after_padding
            )
            block_tables, slot_mappings = self.prepare_attn(input_batch)

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
@@ -934,35 +896,61 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)

            # Only first PP rank prepares multimodal embeddings.
            if self.supports_mm_inputs and self.is_first_pp_rank:
                mm_embeds, is_mm_embed = self.get_mm_embeddings(
                    scheduler_output.scheduled_encoder_inputs, input_batch
                )
                inputs_embeds = self.encoder_runner.get_inputs_embeds(
                    self.model, input_batch.input_ids, mm_embeds, is_mm_embed
                )
                input_batch.inputs_embeds = inputs_embeds[
                    : input_batch.num_tokens_after_padding
                ]
        else:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
            input_batch = InputBatch.make_dummy(
                num_reqs=num_reqs,
                num_tokens=num_tokens_after_padding,
                input_buffers=self.input_buffers,
                device=self.device,
                num_reqs, num_tokens_after_padding, self.input_buffers
            )
            if self.uses_mrope:
                input_batch.mrope_positions = self.mrope_states.mrope_positions[
                    :, :num_tokens_after_padding
                ]
            if not skip_attn_for_dummy_run:
                self.prepare_dummy_attn_metadata(input_batch)
                block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
            else:
                block_tables = None
                slot_mappings = None
            # FIXME(woosuk): Fix warmup for LoRA.

        attn_metadata = None
        slot_mappings_by_layer = None
        if not (dummy_run and skip_attn_for_dummy_run):
            assert slot_mappings is not None
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )
            assert block_tables is not None
            attn_metadata = self.model_state.prepare_attn(
                input_batch,
                block_tables,
                slot_mappings,
                self.attn_groups,
                self.kv_cache_config,
            )

        inputs_embeds = None
        if self.supports_mm_inputs and self.is_first_pp_rank:
            # Run MM encoder (if needed) and get multimodal embeddings.
            # Only first PP rank prepares multimodal embeddings.
            # NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
            # to obtain inputs_embeds, because the compiled model expects this input.
            inputs_embeds = self.model_state.get_mm_embeddings(
                scheduler_output.scheduled_encoder_inputs,
                input_batch,
                self.req_states,
            )

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": inputs_embeds,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
        }
        if not self.is_first_pp_rank:
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None
            model_inputs["intermediate_tensors"] = intermediate_tensors

        # Run model.
        if cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
@@ -979,41 +967,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            aux_hidden_states = None
        else:
            # For piecewise and eager mode, just call model().
            positions = input_batch.positions
            if self.uses_mrope:
                assert input_batch.mrope_positions is not None
                positions = input_batch.mrope_positions

            if self.is_first_pp_rank:
                input_ids = input_batch.input_ids
                inputs_embeds = input_batch.inputs_embeds
                assert intermediate_tensors is None
            else:
                input_ids = None
                inputs_embeds = None
                assert intermediate_tensors is not None

            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                input_batch.attn_metadata,
                attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=input_batch.slot_mappings,
                slot_mapping=slot_mappings_by_layer,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors,
                )
                model_output = self.model(**model_inputs)
                if self.use_aux_hidden_state_outputs:
                    hidden_states, aux_hidden_states = model_output
                else:
@@ -1021,33 +990,44 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    aux_hidden_states = None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)
        self.execute_model_state = (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
            num_tokens_across_dp,
        )

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            self.execute_model_state = (None, None, input_batch, kv_connector_output)
            return hidden_states

        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        self.execute_model_state = (
            hidden_states,
            aux_hidden_states,
            input_batch,
            kv_connector_output,
        )  # type: ignore
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        assert self.execute_model_state is not None
        hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
            self.execute_model_state
        )
        self.execute_model_state = None  # type: ignore
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None
        (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
            num_tokens_across_dp,
        ) = self.execute_model_state
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            # Non-last PP rank: hidden_states is None because this rank produced
@@ -1109,6 +1089,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if self.speculator is not None:
            draft_tokens = self.speculator.propose(
                input_batch,
                attn_metadata,
                slot_mappings_by_layer,
                hidden_states,
                aux_hidden_states,
                num_sampled,
@@ -1117,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                self.req_states.next_prefill_tokens,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
                num_tokens_across_dp=num_tokens_across_dp,
            )
            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
@@ -1127,3 +1110,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        return self.draft_tokens_handler.get_draft_tokens()

    @torch.inference_mode()
    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
            self.execute_model_state
        )
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            self.postprocess_pool(input_batch)
            return None

        assert self.pooling_runner is not None
        pooler_output, is_valid = self.pooling_runner.pool(
            hidden_states, input_batch, self.req_states
        )
        self.postprocess_pool(input_batch)

        # Build the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncPoolingOutput(
            model_runner_output=model_runner_output,
            pooler_output=pooler_output,
            is_valid=is_valid,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def postprocess_pool(self, input_batch: InputBatch) -> None:
        # Update the number of computed tokens.
        post_update_pool(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens.gpu,
            input_batch.query_start_loc,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(
            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
        )
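
A worked example of the clamped prefill-counter update above:

    import numpy as np

    computed_prefill = np.array([5, 9], dtype=np.int32)  # per-request state
    prefill_len = np.array([10, 10], dtype=np.int32)
    idx = np.array([0, 1])
    num_scheduled = np.array([4, 4], dtype=np.int32)

    computed_prefill[idx] += num_scheduled                           # [9, 13]
    np.minimum(computed_prefill, prefill_len, out=computed_prefill)  # [9, 10]
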
18  vllm/v1/worker/gpu/model_states/__init__.py  Normal file
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache


def init_model_state(
    vllm_config: VllmConfig,
    model: nn.Module,
    encoder_cache: EncoderCache | None,
    device: torch.device,
):
    from vllm.v1.worker.gpu.model_states.default import DefaultModelState

    return DefaultModelState(vllm_config, model, encoder_cache, device)
161  vllm/v1/worker/gpu/model_states/default.py  Normal file
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class DefaultModelState(ModelState):
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.supports_mm_inputs = encoder_cache is not None
        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
        self.dtype = self.model_config.dtype

        if self.supports_mm_inputs:
            assert encoder_cache is not None
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for prefill.
            assert new_req_data.prefill_token_ids is not None
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                new_req_data.prefill_token_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
        )
        if mm_kwargs:
            # Execute the multimodal encoder.
            encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
            # Cache the encoder outputs by mm_hash.
            self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))

        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            req_states.prefill_len.np[input_batch.idx_mapping_np],
            req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
        )
        inputs_embeds = self.encoder_runner.get_inputs_embeds(
            input_batch.input_ids, mm_embeds, is_mm_embed
        )
        return inputs_embeds[: input_batch.num_tokens_after_padding]

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # Prepare M-RoPE positions.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        mrope_positions = self.mrope_state.mrope_positions[
            :, : input_batch.num_tokens_after_padding
        ]
        return {"positions": mrope_positions}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        model_inputs = {}
        if self.supports_mm_inputs:
            inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
            model_inputs["inputs_embeds"] = inputs_embeds
        if self.uses_mrope:
            mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
            model_inputs["positions"] = mrope_positions
        return model_inputs
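
The override mechanism relies on later keys winning in a dict literal; a minimal illustration:

    defaults = {"input_ids": "buffer_ids", "positions": "buffer_pos"}
    overrides = {"positions": "mrope_pos"}  # e.g. from prepare_dummy_inputs
    model_inputs = {**defaults, **overrides}
    assert model_inputs["positions"] == "mrope_pos"
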
    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )
        return attn_metadata
67  vllm/v1/worker/gpu/model_states/interface.py  Normal file
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class ModelState(ABC):
@abstractmethod
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
raise NotImplementedError

@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError

@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError

@abstractmethod
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError

@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError

@abstractmethod
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
raise NotImplementedError
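For orientation, a minimal no-op subclass that satisfies this ABC could look like the sketch below (illustrative only; the concrete implementations live in the sibling model_states modules):

class NoopModelState(ModelState):
    # Trivial state for a text-only model: no staged writes, no MM embeddings.
    def __init__(self, vllm_config, model, encoder_cache, device):
        self.device = device

    def add_request(self, req_index, new_req_data):
        pass

    def apply_staged_writes(self):
        pass

    def get_mm_embeddings(self, scheduled_encoder_inputs, input_batch, req_states):
        raise NotImplementedError  # no multimodal inputs in this sketch

    def prepare_inputs(self, input_batch, req_states):
        return {}

    def prepare_dummy_inputs(self, num_reqs, num_tokens):
        return {}

    def prepare_attn(
        self, input_batch, block_tables, slot_mappings, attn_groups, kv_cache_config
    ):
        return {}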
vllm/v1/worker/gpu/pool/__init__.py (new file, 0 lines)
vllm/v1/worker/gpu/pool/pooling_runner.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState


# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model)

def get_supported_pooling_tasks(self) -> list[PoolingTask]:
if not is_pooling_model(self.model):
return []
assert "embed" in self.model.pooler.get_supported_tasks()
return ["embed"]

def pool(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
req_states: RequestState,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO(woosuk): Support different types of pooling tasks.
last_hidden_states = hidden_states[input_batch.logits_indices]
# TODO(woosuk): Make normalization optional.
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)

prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
is_valid = input_batch.seq_lens == prompt_len
return last_hidden_states, is_valid

def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
F.normalize(hidden_states, p=2, dim=-1)
return
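A standalone toy check of the "LAST" pooling above; shapes and indices here are made up:

import torch
import torch.nn.functional as F

hidden = torch.randn(10, 8)       # [num_tokens, hidden]
last_idx = torch.tensor([3, 9])   # last-token index of each of two requests
emb = F.normalize(hidden[last_idx], p=2, dim=-1)
assert torch.allclose(emb.norm(dim=-1), torch.ones(2))  # unit-norm embeddings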
@@ -72,7 +72,7 @@ class BadWordsState:
def apply_bad_words(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -84,7 +84,7 @@ class BadWordsState:

apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
@@ -114,17 +114,17 @@ def _bad_words_kernel(
input_ids_ptr,
expanded_local_pos_ptr,
):
logit_idx = tl.program_id(0)
token_idx = tl.program_id(0)
bw_idx = tl.program_id(1)

req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)

if bw_idx >= num_bad_words:
return

pos = tl.load(expanded_local_pos_ptr + logit_idx)
cur_req_first_pos = logit_idx - pos
pos = tl.load(expanded_local_pos_ptr + token_idx)
cur_req_first_pos = token_idx - pos

prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx)
@@ -159,7 +159,7 @@ def _bad_words_kernel(
match = match & (expected == actual)

if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))


def apply_bad_words(
@@ -175,8 +175,8 @@ def apply_bad_words(
expanded_local_pos: torch.Tensor,
max_num_bad_words: int,
) -> None:
total_num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
num_tokens = logits.shape[0]
_bad_words_kernel[(num_tokens, max_num_bad_words)](
logits,
logits.stride(0),
expanded_idx_mapping,

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
@@ -25,24 +25,24 @@ def _temperature_kernel(
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size

logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
_temperature_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
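A plain PyTorch reference for the temperature kernel's semantics (illustrative; real inputs come from the sampler state). Rows whose temperature is 0 (greedy) or 1 are skipped, matching the kernel's early return:

import torch

logits = torch.randn(3, 5)                   # [num_tokens, vocab]
temperature = torch.tensor([0.0, 1.0, 0.7])  # already gathered per token
skip = (temperature == 0.0) | (temperature == 1.0)
logits[~skip] = logits[~skip] / temperature[~skip, None]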
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + batch_idx * logits_stride + block,
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
gumbel_seed = tl.randint(seed, pos)

# Generate gumbel noise in FP32.
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(

value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)


def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs]
logits: torch.Tensor, # [num_tokens, vocab_size]
expanded_idx_mapping: torch.Tensor, # [num_tokens]
temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
_gumbel_sample_kernel[(num_tokens, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
seed,
pos,
temperature,
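The kernel above implements the Gumbel-max trick: argmax(logits + g) with g ~ Gumbel(0, 1) draws exactly from softmax(logits). A tiny standalone sketch:

import torch

logits = torch.tensor([1.0, 2.0, 0.5])
gumbel = -torch.log(-torch.log(torch.rand(3)))  # g ~ Gumbel(0, 1)
token_id = int(torch.argmax(logits + gumbel))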
@@ -121,7 +121,7 @@ class LogitBiasState:
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
@@ -131,7 +131,7 @@ class LogitBiasState:

apply_logit_bias(
logits,
idx_mapping,
expanded_idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

block = tl.arange(0, BLOCK_SIZE)

@@ -186,21 +186,21 @@ def _bias_kernel(
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
)

# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
logits_ptr + token_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)

# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits_ptr + token_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
@@ -214,13 +214,13 @@ def _bias_kernel(
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)

# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
logits_ptr + token_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
@@ -237,7 +237,7 @@ def _bias_kernel(

def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
_bias_kernel[(num_tokens,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
expanded_idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),
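A dense-tensor reference for the allowed-token-ids branch of the kernel (illustrative shapes): save the allowed entries, blanket the row with -inf, then restore them.

import torch

logits = torch.randn(4, 10)
row, allowed = 0, torch.tensor([2, 5, 7])
saved = logits[row, allowed].clone()
logits[row, :] = float("-inf")
logits[row, allowed] = saved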
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
@@ -25,7 +25,9 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32)  # type: ignore
@@ -35,21 +37,23 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
_min_p_kernel[(num_tokens,)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
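min-p keeps only tokens whose probability is at least min_p times the max probability; since softmax is monotone, that is the same as masking logits below max_logit + log(min_p), which a reference version can do directly (illustrative values):

import math

import torch

min_p = 0.1
logits = torch.randn(1, 10)
threshold = logits.max(dim=-1, keepdim=True).values + math.log(min_p)
logits = logits.masked_fill(logits < threshold, float("-inf"))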
@@ -82,7 +82,7 @@ class PenaltiesState:
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:

apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(

def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(

@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
token_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(


def bincount(
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@ def bincount(
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
prompt_bin_mask[expanded_idx_mapping] = 0
output_bin_counts[expanded_idx_mapping] = 0
num_tokens = expanded_idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
_bincount_kernel[(num_tokens, num_blocks)](
expanded_idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,
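The per-token bincounts feed the standard frequency/presence penalty formulation; a CPU reference with a made-up token history (the repetition penalty additionally rescales already-seen tokens' logits):

import torch

vocab_size = 8
history = torch.tensor([1, 1, 3])  # tokens already generated for one request
counts = torch.bincount(history, minlength=vocab_size).float()
logits = torch.randn(vocab_size)
freq_penalty, pres_penalty = 0.2, 0.5
logits = logits - freq_penalty * counts - pres_penalty * (counts > 0).float()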
@@ -56,7 +56,7 @@ class Sampler:
def __call__(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
@@ -68,7 +68,7 @@ class Sampler:
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
@@ -101,7 +101,7 @@ class Sampler:
def sample(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
@@ -111,12 +111,14 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
self.logit_bias_state.apply_logit_bias(
logits, expanded_idx_mapping, idx_mapping_np, pos
)

# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
@@ -126,27 +128,29 @@ class Sampler:
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)

# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_temperature(
logits, expanded_idx_mapping, idx_mapping_np
)

# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)

# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
logits, expanded_idx_mapping, idx_mapping_np
)

# Sample the next token.
sampled = gumbel_sample(
logits,
idx_mapping,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,

@@ -64,7 +64,7 @@ class SamplingStates:
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
@@ -72,23 +72,23 @@ class SamplingStates:
# No request requires temperature. Skip the kernel launch.
return

apply_temperature(logits, idx_mapping, self.temperature.gpu)
apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)

def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)

def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
@@ -96,8 +96,8 @@ class SamplingStates:
if not (do_top_k or do_top_p):
return logits

top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)

def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
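The recurring idx_mapping to expanded_idx_mapping change throughout this commit swaps a per-request index for a per-token one, so request-level sampling params are gathered once per sampled token (needed when speculative decoding samples several tokens per request). A tiny sketch:

import torch

top_p = torch.tensor([0.9, 0.5])                # per-request param
expanded_idx_mapping = torch.tensor([0, 0, 1])  # req 0 samples 2 tokens
per_token_top_p = top_p[expanded_idx_mapping]   # tensor([0.9000, 0.9000, 0.5000])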
@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup


@@ -54,11 +55,32 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)

def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode

if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size

def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
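Sketch of how a caller would use the runtime-mode fallback above; `manager` and `run_eager` are placeholders for illustration, not names from this change:

mode, size = manager.get_cudagraph_runtime_mode(num_tokens)
if mode == CUDAGraphMode.FULL:
    manager.run_fullgraph(size)  # replay the captured full graph
else:
    run_eager(num_tokens)        # hypothetical eager/piecewise path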
@@ -76,12 +98,11 @@ class EagleCudaGraphManager:
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

@@ -158,6 +179,7 @@ class EagleCudaGraphManager:
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
@@ -173,6 +195,7 @@ class EagleCudaGraphManager:
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
@@ -16,7 +16,9 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
@@ -44,10 +46,13 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype

# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank

self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
@@ -77,10 +82,12 @@ class EagleSpeculator:

def set_attn(
self,
model_state: ModelState,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables,
) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups
self.block_tables = block_tables
@@ -120,8 +127,8 @@ class EagleSpeculator:
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
@@ -162,9 +169,10 @@ class EagleSpeculator:
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)

def capture_model(self) -> None:
if self.num_speculative_steps == 1:
@@ -172,6 +180,7 @@ class EagleSpeculator:
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
@@ -182,6 +191,8 @@ class EagleSpeculator:
def propose(
self,
input_batch: InputBatch,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
@@ -198,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -229,9 +243,9 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME
attn_metadata,
slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -277,48 +291,64 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)

cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)

cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]

# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]

# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)

# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]
vllm/v1/worker/gpu/warmup.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner


@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
"""Run two execute_model + sample_tokens iterations to JIT compile
triton kernels.

The first iteration simulates a prefill with requests of 2 prompt
tokens each. The second iteration simulates a decode step with all
requests generating 1 token each.
"""
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
num_reqs = min(
model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
)

num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]

# SamplingParams exercising all sampling features.
if model_runner.is_pooling_model:
sampling_params = None
pooling_params = PoolingParams()
else:
sampling_params = SamplingParams.for_sampler_warmup()
pooling_params = None

# Step 1: Prefill all requests with 2 prompt tokens each.
new_reqs = [
NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
# Each request uses a distinct block per KV cache group.
block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
prefill_token_ids=prompt_token_ids,
)
for i in range(num_reqs)
]

prefill_output = SchedulerOutput.make_empty()
prefill_output.scheduled_new_reqs = new_reqs
prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

# Disable KV connector for warmup run.
model_runner.kv_connector.set_disabled(True)
model_runner.execute_model(prefill_output)

if not model_runner.is_pooling_model:
# Warm up sampler and perform a decode step for non-pooling models.

grammar_output = None
if model_runner.is_last_pp_rank:
# Build a GrammarOutput to exercise the structured output bitmask
# kernel during the prefill step.
vocab_size = model_runner.model_config.get_vocab_size()
bitmask_width = (vocab_size + 31) // 32
grammar_bitmask = np.full(
(len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
)
grammar_output = GrammarOutput(
structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
)

model_runner.sample_tokens(grammar_output)

# Step 2: Decode all requests with 1 token each.
cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids)
cached_req_data.new_block_ids = [None] * num_reqs
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
cached_req_data.num_output_tokens = [1] * num_reqs

decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data
decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
decode_output.total_num_scheduled_tokens = num_reqs
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

model_runner.execute_model(decode_output)
model_runner.sample_tokens(None)

# Clean up - process finished_req_ids.
cleanup_output = SchedulerOutput.make_empty()
cleanup_output.finished_req_ids = set(req_ids)
model_runner.execute_model(cleanup_output)
model_runner.kv_connector.set_disabled(False)
torch.cuda.synchronize()
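Worked example of the bitmask width used above (the vocab size is made up for illustration): a 128256-token vocabulary packs into (128256 + 31) // 32 == 4008 int32 words per request, and fill_value=-1 sets every bit, i.e., all tokens allowed, so the kernel path is exercised without actually constraining the warmup.

vocab_size = 128256  # example only
assert (vocab_size + 31) // 32 == 4008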
@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None

# For the multi-layer eagle proposer.
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None

def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -211,6 +220,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)

# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
@@ -425,6 +474,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1

if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions

# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@@ -623,6 +679,24 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)

if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[
[i2, i1], ...
]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[
[i2, i1], ...
]

def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.

@@ -745,6 +819,21 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids

if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]

# Decrement last_req_index since it is now empty.
last_req_index -= 1
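The paired fancy-index assignment used in the swap above exchanges two rows in a single statement, because advanced indexing on the right-hand side materializes a copy before the scatter; a standalone demo:

import torch

t = torch.arange(6).reshape(3, 2)
t[[0, 2], ...] = t[[2, 0], ...]  # rows 0 and 2 are now swapped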
File diff suppressed because it is too large.
@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
@@ -32,14 +31,13 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
get_world_group
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
@@ -49,7 +47,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
@@ -61,6 +58,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager

from ...model_executor.model_loader import TensorizerLoader
from .gpu.warmup import warmup_kernels
from .utils import request_memory

logger = init_logger(__name__)
@@ -123,6 +122,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)

from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

self.elastic_ep_executor = ElasticEPScalingExecutor(self)

# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

@@ -316,12 +319,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)

with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)

if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
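Worked example of the redundant-expert arithmetic above (sizes made up): with 64 logical experts replicated onto 72 physical slots, num_redundant_experts = 72 - 64 = 8.

num_physical_experts, num_logical_experts = 72, 64  # example sizes
assert num_physical_experts - num_logical_experts == 8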
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -421,9 +441,10 @@ class Worker(WorkerBase):
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None

tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}

# tp_rank = get_tp_group().rank_in_group
global_rank = get_world_group().rank_in_group
return {global_rank: metadata}

def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@@ -461,8 +482,16 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)

# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()

@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []

if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
@@ -558,12 +587,15 @@ class Worker(WorkerBase):

logger.debug(msg)

# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner)
elif get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
@@ -584,6 +616,8 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)

return self.compilation_config.compilation_time

def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()

@@ -696,6 +730,12 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if (
self.use_v2_model_runner
and self.model_runner.is_pooling_model
and output is None
):
output = self.model_runner.pool() # type: ignore
if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
):
@@ -744,7 +784,8 @@ class Worker(WorkerBase):

# Create the profiler wrapper only on the first start call
if self.profiler is None:
if self.profiler_config.profiler == "torch":
profiler_type = self.profiler_config.profiler
if profiler_type == "torch":
self.profiler = TorchProfilerWrapper(
self.profiler_config,
worker_name=trace_name,
@@ -754,14 +795,18 @@ class Worker(WorkerBase):
logger.debug(
"Starting torch profiler with trace name: %s", trace_name
)
elif self.profiler_config.profiler == "cuda":
elif profiler_type == "cuda":
self.profiler = CudaProfilerWrapper(self.profiler_config)
logger.debug("Starting CUDA profiler")
self.profiler.start()
else:
# Profiler already initialized. Restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
# Config validation should prevent this code being reached
raise ValueError(
f"Invalid profiler value of {self.profiler_config.profiler}"
)

# If profiler already initialized, restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
if self.profiler is None:
logger.warning("Profiler was not started, nothing to stop.")
@@ -787,227 +832,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info(
|
||||
"[Elastic EP] Starting expert resharding before scaling down..."
|
||||
)
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=None,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _eplb_after_scale_up(
|
||||
self,
|
||||
old_ep_size: int,
|
||||
new_ep_size: int,
|
||||
global_expert_loads: list[torch.Tensor] | None,
|
||||
) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=global_expert_loads,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _reconfigure_parallel_config(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
"""
|
||||
Update parallel config with provided reconfig_request
|
||||
"""
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)

    def _reconfigure_moe(
        self, old_ep_size: int, new_ep_size: int
    ) -> list[torch.Tensor] | None:
        """
        Reconfigure MoE modules for the new EP size.

        Return the global expert loads if new_ep_size > old_ep_size,
        otherwise None.
        """
        from vllm.distributed.parallel_state import (
            get_dp_group,
            get_ep_group,
            prepare_communication_buffer_for_model,
        )
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoE,
            FusedMoEParallelConfig,
        )

        parallel_config = self.vllm_config.parallel_config

        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
            return [
                module
                for module in model.modules()
                if (
                    module.__class__.__name__ == "FusedMoE"
                    or module.__class__.__name__ == "SharedFusedMoE"
                )
            ]

        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
            assert all(
                module.moe_config.num_local_experts == num_local_experts
                for module in moe_modules
            ), "All MoE modules must have the same number of local experts"
            for module in moe_modules:
                module.moe_config.num_experts = num_local_experts * new_ep_size
                module.global_num_experts = module.moe_config.num_experts
                tp_size = get_tp_group().world_size
                is_sequence_parallel = parallel_config.use_sequence_parallel_moe
                sp_size = tp_size if is_sequence_parallel else 1
                module.moe_parallel_config = FusedMoEParallelConfig.make(
                    tp_size_=tp_size,
                    pcp_size_=get_pcp_group().world_size,
                    dp_size_=get_dp_group().world_size,
                    sp_size_=sp_size,
                    vllm_parallel_config=parallel_config,
                )
                module.moe_config.moe_parallel_config = module.moe_parallel_config
            return moe_modules

        model_moe_modules = get_moe_modules(self.model_runner.model)
        num_local_experts = model_moe_modules[0].moe_config.num_local_experts

        update_moe_modules(model_moe_modules, num_local_experts)
        drafter_model = None
        if hasattr(self.model_runner, "drafter") and hasattr(
            self.model_runner.drafter, "model"
        ):
            drafter_model = self.model_runner.drafter.model
        if drafter_model is not None and is_mixture_of_experts(drafter_model):
            drafter_moe_modules = get_moe_modules(drafter_model)
            # Check if drafter and model have matching configs
            assert (
                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
            ), "Drafter and model configs should be the same"
            update_moe_modules(drafter_moe_modules, num_local_experts)

        if new_ep_size < old_ep_size:
            num_local_physical_experts = num_local_experts
            assert self.model_runner.eplb_state is not None
            new_physical_experts = (
                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
            )
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts
                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
            )
            global_expert_loads = None
        else:
            num_local_physical_experts_tensor = torch.tensor(
                [num_local_experts], dtype=torch.int32, device="cpu"
            )
            torch.distributed.broadcast(
                num_local_physical_experts_tensor,
                group=get_ep_group().cpu_group,
                group_src=0,
            )
            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
            new_physical_experts = num_local_physical_experts * new_ep_size
            assert self.model_runner.eplb_state is not None
            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
                execute_shuffle=False
            )
            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
            parallel_config.eplb_config.num_redundant_experts = (
                new_physical_experts - global_expert_loads[0].shape[1]
            )
        prepare_communication_buffer_for_model(self.model_runner.model)
        if drafter_model is not None:
            prepare_communication_buffer_for_model(drafter_model)
        self.model_runner.model.update_physical_experts_metadata(
            num_physical_experts=new_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
        return global_expert_loads
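
To make the redundancy bookkeeping concrete (all numbers hypothetical): with num_local_physical_experts = 8 and new_ep_size = 4 there are 32 physical expert slots; if the model defines 30 logical experts (global_expert_loads[0].shape[1]), then

    new_physical_experts = 8 * 4       # 32 physical slots
    num_redundant_experts = 32 - 30    # 2 replicated experts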

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        from vllm.config import set_current_vllm_config
        from vllm.distributed.parallel_state import (
            cleanup_dist_env_and_memory,
            get_ep_group,
        )

        old_ep_size = get_ep_group().world_size
        old_ep_rank = get_ep_group().rank
        new_ep_size = (
            reconfig_request.new_data_parallel_size
            * get_tp_group().world_size
            * get_pp_group().world_size
        )
        if new_ep_size < old_ep_size:
            self._eplb_before_scale_down(old_ep_size, new_ep_size)

        cleanup_dist_env_and_memory()

        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            assert old_ep_rank >= new_ep_size
            # shutdown
            return

        self._reconfigure_parallel_config(reconfig_request)

        with set_current_vllm_config(self.vllm_config):
            init_worker_distributed_environment(
                self.vllm_config,
                self.rank,
                self.distributed_init_method,
                self.local_rank,
            )

        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

        if new_ep_size > old_ep_size:
            assert global_expert_loads is not None
            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
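
Note that the EP world size computed above is simply the product of the DP, TP, and PP sizes; with hypothetical sizes DP=4, TP=2, PP=1:

    new_ep_size = 4 * 2 * 1  # = 8 expert-parallel ranks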

    def save_sharded_state(
        self,
        path: str,
@@ -1023,12 +847,11 @@ class Worker(WorkerBase):
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        self.model_runner.save_tensorized_model(
    def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
        TensorizerLoader.save_model(
            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

    def init_weight_transfer_engine(self, init_info: dict) -> None:
@@ -1104,6 +927,9 @@ class Worker(WorkerBase):
        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
            weight_transfer_engine.shutdown()

    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
        return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)


def init_worker_distributed_environment(
    vllm_config: VllmConfig,

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any

import torch
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch

@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
    return mamba_group_ids, mamba_specs[0]


@dataclasses.dataclass
class MambaCopyBuffers:
    src_ptrs: CpuGpuBuffer
    dst_ptrs: CpuGpuBuffer
    sizes: CpuGpuBuffer
    offset: int = 0

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        copy_funcs: tuple[MambaStateCopyFunc, ...],
        make_buffer: Callable[..., CpuGpuBuffer],
    ) -> "MambaCopyBuffers":
        mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
        entries_per_req = sum(
            len(kv_cache_config.kv_cache_groups[gid].layer_names)
            for gid in mamba_group_ids
        ) * len(copy_funcs)
        n = max_num_reqs * entries_per_req
        return cls(
            src_ptrs=make_buffer(n, dtype=torch.int64),
            dst_ptrs=make_buffer(n, dtype=torch.int64),
            sizes=make_buffer(n, dtype=torch.int32),
        )
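
A quick capacity check under assumed values: with two mamba layers across the mamba groups, three copy functions, and max_num_reqs = 256, each buffer is sized as

    entries_per_req = 2 * 3       # layers x copy funcs
    n = 256 * entries_per_req     # 1536 pointer/size slots per buffer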


def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    copy_bufs: MambaCopyBuffers,
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
):
) -> None:
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    src_ptrs_np = copy_bufs.src_ptrs.np
    dst_ptrs_np = copy_bufs.dst_ptrs.np
    sizes_np = copy_bufs.sizes.np
    offset = copy_bufs.offset

    for mamba_group_id in mamba_group_ids:
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
            state, block_ids, src_block_idx, accept_token_bias + 1
        )

        src_state_list.append(copy_spec.start_addr)
        dest_state_list.append(state[dest_block_id].data_ptr())
        num_elements_list.append(copy_spec.num_elements * state.element_size())
        src_ptrs_np[offset] = copy_spec.start_addr
        dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
        sizes_np[offset] = copy_spec.num_elements * state.element_size()
        offset += 1

    copy_bufs.offset = offset


def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
):
    if len(src_state_list) == 0:
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
    n = copy_bufs.offset
    if n == 0:
        return
    assert len(src_state_list) == len(dest_state_list)
    assert len(src_state_list) == len(num_elements_list)
    src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
    dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
    num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

    batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
    batch_memcpy(
        copy_bufs.src_ptrs.copy_to_gpu(n),
        copy_bufs.dst_ptrs.copy_to_gpu(n),
        copy_bufs.sizes.copy_to_gpu(n),
    )
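
This rewrite replaces the per-step torch.tensor allocations from Python lists with reusable CpuGpuBuffer views and a single host-to-device copy per flush. The intended per-step pattern, sketched with names from this file:

    copy_bufs.offset = 0                          # reset before collecting
    collect_mamba_copy_meta(..., copy_bufs, ...)  # fill the pinned numpy views
    do_mamba_copy_block(copy_bufs)                # upload pointers, run batch_memcpy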


def preprocess_mamba(
@@ -117,6 +149,7 @@ def preprocess_mamba(
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    Copy the mamba state of previous step to the last
@@ -138,9 +171,7 @@ def preprocess_mamba(
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
        mamba_state_idx.pop(req_id, None)

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    copy_bufs.offset = 0
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
@@ -169,9 +200,7 @@ def preprocess_mamba(
        mamba_state_idx[req_id] = curr_state_idx
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                copy_bufs,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
@@ -182,7 +211,7 @@ def preprocess_mamba(
                forward_context,
            )
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
    do_mamba_copy_block(copy_bufs)


def postprocess_mamba(
@@ -193,6 +222,7 @@ def postprocess_mamba(
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    If a block is converted from a partial block to a full block in this step, copy the
@@ -203,9 +233,7 @@ def postprocess_mamba(
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    copy_bufs.offset = 0
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
@@ -225,9 +253,7 @@ def postprocess_mamba(
        src_block_idx = mamba_state_idx[req_id]
        dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
        collect_mamba_copy_meta(
            src_state_list,
            dest_state_list,
            num_elements_list,
            copy_bufs,
            kv_cache_config,
            mamba_state_copy_funcs,
            mamba_group_ids,
@@ -239,4 +265,4 @@ def postprocess_mamba(
        )
        if src_block_idx == dest_block_idx:
            num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
    do_mamba_copy_block(copy_bufs)
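
For a concrete feel of the destination index (values hypothetical): with mamba_spec.block_size = 64 and aligned_new_computed_tokens = 128,

    dest_block_idx = 128 // 64 - 1  # = 1, i.e. the second block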

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any

import torch

@@ -12,13 +15,208 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadataBuilder,
    MultipleOf,
)
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    EncoderOnlyAttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheSpec,
    MambaSpec,
    UniformTypeKVCacheSpecs,
)

logger = init_logger(__name__)


@triton.jit
def _zero_kv_blocks_kernel(
    seg_addrs_ptr,
    block_ids_ptr,
    n_blocks,
    N_SEGS: tl.constexpr,
    PAGE_SIZE_EL: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Zero KV cache blocks across all segments in a single launch.

    Each segment is a contiguous region of one block's data. For backends
    where blocks are outermost (block_dim=0) there is one segment per
    buffer. For backends where K/V is outermost (block_dim=1) there are
    two segments per buffer (one for K, one for V).

    seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
    allowing segments to live in different CUDA allocations.

    Programs are mapped as (block_index, seg_index, chunk_index).
    """
    pid = tl.program_id(0)
    chunks = PAGE_SIZE_EL // BLOCK_SIZE
    work_per_block = N_SEGS * chunks
    block_index = pid // work_per_block
    if block_index >= n_blocks:
        return
    remainder = pid % work_per_block
    seg_index = remainder // chunks
    chunk_index = remainder % chunks
    block_id = tl.load(block_ids_ptr + block_index)
    seg_addr = tl.load(seg_addrs_ptr + seg_index)
    ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
    offset = (
        block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
    )
    cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
    tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
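
The flat grid decomposition above can be checked in plain Python with hypothetical sizes (PAGE_SIZE_EL = 4096 int32 elements, BLOCK_SIZE = 1024, hence chunks = 4, and N_SEGS = 2):

    def decode_pid(pid: int, n_segs: int = 2, chunks: int = 4):
        # Mirrors the index math in _zero_kv_blocks_kernel.
        work_per_block = n_segs * chunks
        block_index = pid // work_per_block
        remainder = pid % work_per_block
        return block_index, remainder // chunks, remainder % chunks

    assert decode_pid(0) == (0, 0, 0)
    assert decode_pid(5) == (0, 1, 1)  # block 0, segment 1, chunk 1
    assert decode_pid(8) == (1, 0, 0)  # first chunk of block 1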


class KVBlockZeroer:
    """Manages efficient zeroing of KV cache blocks via a Triton kernel.

    Call :meth:`init_meta` once after KV caches are allocated to precompute
    segment addresses, then call :meth:`zero_block_ids` each step to zero
    newly-allocated blocks.
    """

    def __init__(self, device: torch.device, pin_memory: bool):
        self.device = device
        self.pin_memory = pin_memory
        self._meta: tuple[torch.Tensor, int, int, int] | None = None
        self._id_cap: int = 0
        self._ids_pinned: torch.Tensor | None = None
        self._ids_gpu: torch.Tensor | None = None

    def init_meta(
        self,
        attn_groups_iter: Iterable["AttentionGroup"],
        kernel_block_sizes: list[int],
        cache_dtype: str,
        runner_only_attn_layers: set[str],
        static_forward_context: dict[str, Any],
    ) -> None:
        """One-time precomputation for zero_block_ids.

        Builds the absolute-address table for the Triton zeroing kernel.
        Each entry is the absolute byte address of a segment start on the
        GPU, so segments in different CUDA allocations work correctly.

        Block IDs from the scheduler reference logical blocks whose size
        may differ from the kernel block size (virtual block splitting).
        PAGE_SIZE_EL accounts for this ratio so that
        ``block_id * PAGE_SIZE_EL`` lands at the correct offset.

        Only AttentionSpec layers are processed; Mamba layers are skipped.
        """
        seen_ptrs: set[int] = set()
        seg_addrs: list[int] = []
        page_size_el: int | None = None

        for group in attn_groups_iter:
            spec = group.kv_cache_spec
            if type(spec) is not FullAttentionSpec:
                continue
            if group.kv_cache_group_id >= len(kernel_block_sizes):
                continue
            kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
            ratio = spec.block_size // kernel_bs
            block_dim = group.backend.get_kv_cache_block_dim(
                kernel_bs,
                spec.num_kv_heads,
                spec.head_size,
                cache_dtype_str=cache_dtype,
            )

            for layer_name in group.layer_names:
                if layer_name in runner_only_attn_layers:
                    continue
                kv = static_forward_context[layer_name].kv_cache[0]
                if isinstance(kv, list):
                    continue
                dp = kv.data_ptr()
                if dp in seen_ptrs:
                    continue
                seen_ptrs.add(dp)

                el = kv.element_size()
                cur_bytes = kv.stride(block_dim) * el
                assert cur_bytes % 4 == 0
                kernel_block_el = cur_bytes // 4
                cur_page_el = kernel_block_el * ratio
                if page_size_el is None:
                    page_size_el = cur_page_el
                else:
                    assert page_size_el == cur_page_el, (
                        f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
                    )

                block_stride_bytes = cur_bytes
                outer_dims = [
                    d
                    for d in range(block_dim)
                    if kv.stride(d) * el > block_stride_bytes
                ]
                outer_strides = [kv.stride(d) * el for d in outer_dims]
                for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
                    off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
                    seg_addrs.append(dp + off_bytes)

        if not seg_addrs or page_size_el is None:
            self._meta = None
            return

        blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
        self._id_cap = 8192
        self._ids_pinned = torch.empty(
            self._id_cap,
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
        self._meta = (
            torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
            page_size_el,
            blk_size,
            len(seg_addrs),
        )

    def zero_block_ids(self, block_ids: list[int]) -> None:
        """Zero the KV cache memory for the given block IDs."""
        if not block_ids or self._meta is None:
            return
        seg_addrs, page_size_el, blk_size, n_segs = self._meta
        n_blocks = len(block_ids)
        if n_blocks > self._id_cap:
            self._id_cap = n_blocks * 2
            self._ids_pinned = torch.empty(
                self._id_cap,
                dtype=torch.int64,
                pin_memory=self.pin_memory,
            )
            self._ids_gpu = torch.empty(
                self._id_cap, dtype=torch.int64, device=self.device
            )
        assert self._ids_pinned is not None and self._ids_gpu is not None
        self._ids_pinned[:n_blocks].numpy()[:] = block_ids
        idx = self._ids_gpu[:n_blocks]
        idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
        grid = (n_blocks * n_segs * (page_size_el // blk_size),)
        _zero_kv_blocks_kernel[grid](
            seg_addrs,
            idx,
            n_blocks,
            N_SEGS=n_segs,
            PAGE_SIZE_EL=page_size_el,
            BLOCK_SIZE=blk_size,
        )
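
A minimal usage sketch, assuming the caller already holds the model runner's attention groups and forward context (variable names hypothetical):

    zeroer = KVBlockZeroer(device=torch.device("cuda:0"), pin_memory=True)
    zeroer.init_meta(
        attn_groups_iter=flat_attn_groups,      # iterable of AttentionGroup
        kernel_block_sizes=kernel_block_sizes,
        cache_dtype="auto",
        runner_only_attn_layers=set(),
        static_forward_context=forward_ctx,     # layer_name -> Attention
    )
    zeroer.zero_block_ids([12, 57, 301])        # zero freshly allocated blocks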


@dataclass
class AttentionGroup:
    backend: type[AttentionBackend]
@@ -36,7 +234,7 @@ class AttentionGroup:
        self,
        vllm_config,
        device,
        kernel_block_size: int | None,
        kernel_block_size: int | None = None,
        num_metadata_builders: int = 1,
    ):
        kv_cache_spec_builder = (
@@ -59,6 +257,119 @@ class AttentionGroup:
        return self.metadata_builders[ubatch_id]


def select_common_block_size(
    kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
    """
    Select a block size that is supported by all backends and is a factor of
    kv_manager_block_size.

    If kv_manager_block_size is supported by all backends, return it directly.
    Otherwise, return the largest supported size that divides it.

    Args:
        kv_manager_block_size: Block size of KV cache.
        attn_groups: List of attention groups.

    Returns:
        The selected block size.

    Raises:
        ValueError: If no valid block size is found.
    """

    def block_size_is_supported(
        backends: list[type[AttentionBackend]], block_size: int
    ) -> bool:
        """Check if the block size is supported by all backends."""
        for backend in backends:
            is_supported = False
            for supported_size in backend.get_supported_kernel_block_sizes():
                if isinstance(supported_size, int):
                    if block_size == supported_size:
                        is_supported = True
                elif isinstance(supported_size, MultipleOf):
                    if block_size % supported_size.base == 0:
                        is_supported = True
                else:
                    raise ValueError(f"Unknown supported size: {supported_size}")
            if not is_supported:
                return False
        return True

    backends = [group.backend for group in attn_groups]

    # Case 1: if the block_size of the kv cache manager is supported by all
    # backends, return it directly.
    if block_size_is_supported(backends, kv_manager_block_size):
        return kv_manager_block_size

    # Case 2: otherwise, the block_size must be an `int`-format supported size of
    # at least one backend. Iterate over all `int`-format supported sizes in
    # descending order and return the first one that is supported by all backends.
    # Simple proof:
    # If the supported size b is in MultipleOf(x_i) format for all attention
    # backends i, and b is a factor of kv_manager_block_size, then
    # kv_manager_block_size also satisfies MultipleOf(x_i) for all i, and we
    # would have returned kv_manager_block_size in case 1.
    all_int_supported_sizes = set(
        supported_size
        for backend in backends
        for supported_size in backend.get_supported_kernel_block_sizes()
        if isinstance(supported_size, int)
    )

    for supported_size in sorted(all_int_supported_sizes, reverse=True):
        if kv_manager_block_size % supported_size != 0:
            continue
        if block_size_is_supported(backends, supported_size):
            return supported_size
    raise ValueError(f"No common block size for {kv_manager_block_size}.")


def prepare_kernel_block_sizes(
    kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
) -> list[int]:
    """
    Generate the kernel block sizes matching each cache group's block_size.

    For attention backends that support virtual block splitting,
    use the supported block sizes from the backend.
    For other backends (like Mamba), use the same block size (no splitting).

    Args:
        kv_cache_config: The KV cache configuration.
        attn_groups: Attention groups indexed by KV cache group id.

    Returns:
        List of kernel block sizes for each cache group.
    """
    kernel_block_sizes = []
    for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        kv_cache_spec = kv_cache_group.kv_cache_spec
        if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
            # All layers in the UniformTypeKVCacheSpecs have the same type,
            # pick an arbitrary one to dispatch.
            kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
        if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
            continue
        if isinstance(kv_cache_spec, AttentionSpec):
            # This is an attention backend that supports virtual block splitting.
            kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
            selected_kernel_size = select_common_block_size(
                kv_manager_block_size, attn_groups[kv_cache_gid]
            )
            kernel_block_sizes.append(selected_kernel_size)
        elif isinstance(kv_cache_spec, MambaSpec):
            # Mamba and other non-attention caches are not split.
            kernel_block_sizes.append(kv_cache_spec.block_size)
        else:
            raise NotImplementedError(
                f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
            )
    return kernel_block_sizes


def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,
    expected_num_items: int,
@@ -201,6 +512,55 @@ def bind_kv_cache(
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


def bind_kv_cache_scale(
    kv_caches_scale: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches_scale: list[torch.Tensor],
    num_attn_module: int | None = 1,
) -> None:
    """
    Bind the allocated KV cache scales to both ModelRunner and forward
    context so that they can be used in the forward pass.

    This function:
    1) Fills the ModelRunner's scale list (`runner_kv_caches_scale`) with
       kv_caches_scale.
    2) Associates each attention layer in the `forward_context` with its
       corresponding KV cache scale in kv_caches_scale.

    Args:
        kv_caches_scale: The allocated KV cache scales with layer names as
            keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches_scale: The scale list declared by ModelRunner.
    """
    # Bind KV cache scales to ModelRunner
    assert len(runner_kv_caches_scale) == 0

    # Convert the kv_caches_scale dict to a list of tensors in the order of
    # layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches_scale:
        index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is an encoder-decoder model, e.g. BART.
            # The cross attention and self attention in the same decoder
            # layer have different layer_names but the same layer_index.
            if current_platform.is_cuda() or current_platform.is_xpu():
                pass
            else:
                raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches_scale.append(kv_caches_scale[layer_name])

    # Bind KV cache scales to forward context
    for layer_name, kv_cache_scale in kv_caches_scale.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache_scale = [kv_cache_scale]


def is_residual_scattered_for_sp(

@@ -87,8 +87,12 @@ class WorkerBase:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare model for execution through compilation/warmup."""
    def compile_or_warm_up_model(self) -> float:
        """Prepare model for execution through compilation/warmup.

        Returns:
            The accumulated compilation time in seconds.
        """
        raise NotImplementedError

    def check_health(self) -> None:
@@ -213,13 +217,8 @@ class WorkerWrapperBase:
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        """
        # if self.rpc_rank in rank_mapping:
        #     self.rpc_rank = rank_mapping[self.rpc_rank]
        old_rank = self.rpc_rank
        if old_rank in rank_mapping:
            self.rpc_rank = rank_mapping[old_rank]
            if self.global_rank == old_rank:
                self.global_rank = rank_mapping[old_rank]
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(
        self,

@@ -66,6 +66,23 @@ class WorkspaceManager:
        ],
    )

    def unlock(self) -> None:
        """Unlock the workspace to allow growth.

        This is used during elastic EP scaling when the workspace size
        needs to grow due to changes in the number of experts.
        """
        self._locked = False
        if envs.VLLM_DEBUG_WORKSPACE:
            logger.info(
                "[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
                [
                    self._workspace_size_bytes(ws) / _MB
                    for ws in self._current_workspaces
                    if ws is not None
                ],
            )

    def is_locked(self) -> bool:
        """Check if workspace is locked."""
        return self._locked
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
    current_workspace_manager().lock()


def unlock_workspace() -> None:
    """Unlock the workspace to allow growth.

    This is used during elastic EP scaling when the workspace size
    needs to grow due to changes in the number of experts.
    After scaling operations complete, lock_workspace() should be
    called again to prevent unexpected allocations.
    """
    current_workspace_manager().unlock()
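
The intended scaling pattern, as a sketch (the middle call is illustrative):

    unlock_workspace()                         # allow buffers to grow
    worker.reinitialize_distributed(request)   # elastic EP resize
    lock_workspace()                           # freeze sizes again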


def reset_workspace_manager() -> None:
    """Reset the workspace manager to uninitialized state.