Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)

View File

@@ -85,7 +85,7 @@ class CPUWorker(Worker):
self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
@@ -118,11 +118,12 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
return self.compilation_config.compilation_time
def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]

View File

@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):
def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
@@ -46,12 +45,11 @@ def _run_ar(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
tensor[3][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
return int(tensor[3, :].min().item())
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.
2. Determines the total number of tokens that each rank will run.
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens
When running microbatched or if cudagraph is enabled (synced across ranks),
all ranks will be padded out so that they run with the same number of tokens.
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Synchronize cudagraph_mode across ranks first (take min).
# This is needed before DP padding decision since we use the synced
# cudagraph mode to determine whether DP padding is needed.
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# DP padding is needed when cudagraph is enabled (synced across ranks)
# or when ubatching/DBO is active (ubatching requires uniform batch
# sizes across DP ranks currently).
# Use the synced runtime cudagraph mode rather than the compilation config
# so we can avoid padding when cudagraph is not enabled for this step.
should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch
# Pad all DP ranks up to the maximum token count across ranks if
# should_dp_pad is True
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
should_dp_pad,
)
# Synchronize cudagraph_mode across ranks (take min)
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
def coordinate_batch_across_dp(
num_tokens_unpadded: int,
allow_microbatching: bool,
allow_dp_padding: bool,
parallel_config: ParallelConfig,
num_tokens_padded: int | None = None,
uniform_decode: bool | None = None,
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
Args:
num_tokens_unpadded: Number of tokens without accounting for padding
allow_microbatching: If microbatching should be attempted
allow_dp_padding: If all DP ranks should be padded up to the same value
parallel_config: The parallel config
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
TP, etc)
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
only contains single token decodes
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
number of tokens per request.
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
DP padding is enabled when synced cudagraph mode across ranks is not NONE.
Returns: tuple[
ubatch_slices: if this is set then all DP ranks have agreed to
microbatch
num_tokens_after_padding: A tensor containing the total number of
tokens per-microbatch for each DP rank including padding. Will be
padded up to the max value across all DP ranks when allow_dp_padding
is True.
padded up to the max value across all DP ranks when cudagraph is enabled.
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
]
@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
num_tokens_unpadded,
num_tokens_padded,
should_attempt_ubatching,
allow_dp_padding,
cudagraph_mode,
parallel_config,
)
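Note on the synchronization change above: DP padding is no longer a separate flag carried through the all-reduce. Each rank fills its own column of a [4, dp_size] int32 tensor (original tokens, padded tokens, ubatch request, cudagraph mode), the all-reduce makes every column visible everywhere, microbatching requires unanimity, the cudagraph mode is the minimum across ranks, and padding is applied only when that synced mode is non-NONE or microbatching is on. A minimal single-process sketch of that decision logic, with the collective replaced by stacking per-rank rows (sync_dp_decision and the toy values are illustrative, not part of the diff):

import torch

# Toy stand-in for the all-reduce in `_run_ar`: each row of `per_rank` is what
# one DP rank writes into its own column before dist.all_reduce sums them.
# Columns: [orig_tokens, padded_tokens, should_ubatch, cudagraph_mode].
def sync_dp_decision(per_rank: torch.Tensor) -> tuple[bool, bool, int, int]:
    tensor = per_rank.T.contiguous()  # [4, dp_size], like the reduced tensor
    # Microbatching only happens if every rank asked for it.
    should_ubatch = bool((tensor[2] == 1).all().item())
    # Cudagraph mode is synchronized by taking the min across ranks
    # (any rank at NONE=0 forces NONE everywhere).
    synced_cg_mode = int(tensor[3].min().item())
    # DP padding follows from the synced mode or from microbatching.
    should_dp_pad = synced_cg_mode != 0 or should_ubatch
    # When padding, every rank runs the max padded token count across ranks.
    num_tokens = int(tensor[1].max().item()) if should_dp_pad else -1
    return should_ubatch, should_dp_pad, synced_cg_mode, num_tokens

# Example: rank 0 wants FULL cudagraphs with 96 padded tokens, rank 1 is eager,
# so the synced mode drops to NONE and no DP padding is applied this step.
per_rank = torch.tensor([[90, 96, 0, 2],
                         [40, 48, 0, 0]], dtype=torch.int32)
print(sync_dp_decision(per_rank))  # (False, False, 0, -1)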

View File

@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
return self.model_runner_output
class AsyncPoolingOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
pooler_output: torch.Tensor,
is_valid: torch.Tensor | None,
main_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
self.model_runner_output = model_runner_output
self.pooler_output = pooler_output
self.is_valid = is_valid
self.copy_event = copy_event
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
if self.is_valid is not None:
self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
else:
self.is_valid_cpu = None
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu):
if not is_valid:
pooler_output[i] = None
self.model_runner_output.pooler_output = pooler_output
return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
return x.to("cpu", non_blocking=True).numpy()

View File

@@ -119,6 +119,10 @@ class BlockTables:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
# Therefore, this method must return the persistent tensor
# with the same memory address as that used during the model's forward pass,
# rather than allocating a new tensor.
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
@@ -150,7 +154,14 @@ class BlockTables:
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
# This is because the padding logic is complex and kernels may access beyond
# the requested range.
self.slot_mappings.fill_(PAD_SLOT_ID)
# NOTE(woosuk): The output may be used for CUDA graph capture.
# Therefore, this method must return the persistent tensor
# with the same memory address as that used during the model's forward pass,
# rather than allocating a new tensor.
return self.slot_mappings[:, :num_tokens]
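The NOTE(woosuk) comments above explain why get_dummy_block_tables and get_dummy_slot_mappings return slices of the persistent buffers rather than freshly allocated tensors: a captured CUDA graph bakes device addresses into the replay, so later real forward passes must write into the very same tensors the graph reads. A small self-contained illustration of that address-binding behaviour (a hedged sketch assuming a CUDA device, not code from the diff):

import torch

if torch.cuda.is_available():
    # Persistent input buffer: its address is baked into the captured graph.
    persistent = torch.zeros(8, device="cuda")
    out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as CUDA graphs require.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(persistent * 2)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out.copy_(persistent * 2)

    # Replays only observe writes made into the SAME tensor (same address);
    # this mirrors why the dummy getters must reuse self.input_block_tables
    # and self.slot_mappings instead of allocating new tensors.
    persistent.fill_(3.0)
    graph.replay()
    print(out[:3])  # tensor([6., 6., 6.], device='cuda:0')

    fresh = torch.full((8,), 5.0, device="cuda")  # new allocation, new address
    graph.replay()                                # still reads `persistent`
    print(out[:3])  # unchanged: tensor([6., 6., 6.], device='cuda:0')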

View File

@@ -3,7 +3,6 @@
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
@@ -29,13 +27,11 @@ class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
@@ -88,9 +84,8 @@ class CudaGraphManager:
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
@@ -113,24 +108,23 @@ class CudaGraphManager:
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
assert mrope_positions is not None
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -143,11 +137,7 @@ class CudaGraphManager:
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -164,9 +154,7 @@ class CudaGraphManager:
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
model_inputs=model_inputs,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
@@ -178,9 +166,7 @@ class CudaGraphManager:
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
@@ -206,11 +192,8 @@ class CudaGraphManager:
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model_output = model(**model_inputs)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
@@ -235,9 +218,7 @@ class CudaGraphManager:
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
@@ -256,19 +237,14 @@ class CudaGraphManager:
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
model(**model_inputs)
@torch.inference_mode()
def capture(
self,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
@@ -278,9 +254,8 @@ class CudaGraphManager:
device=self.device,
capture_fn=self.capture_graph,
model=model,
model_state=model_state,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
@@ -412,51 +387,36 @@ def capture_graphs(
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
# HACK(woosuk): Special handling for DCP.
if block_tables.cp_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_batch.seq_lens,
num_reqs,
block_tables.cp_size,
block_tables.cp_rank,
block_tables.cp_interleave,
)
input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
attn_metadata = model_state.prepare_attn(
input_batch,
input_block_tables,
slot_mappings,
attn_groups,
kv_cache_config,
)
return attn_metadata, slot_mappings_by_layer
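The capture path above now assembles a model_inputs dict from the default input_ids/positions buffer slices and lets model_state.prepare_dummy_inputs override entries (for example M-RoPE positions or multimodal inputs_embeds) before calling model(**model_inputs). A tiny sketch of the override semantics of the dict-merge, with an illustrative M-RoPE-style override (names and values are hypothetical, not from the diff):

import torch

def prepare_dummy_inputs_mrope(num_tokens: int) -> dict[str, torch.Tensor]:
    # An M-RoPE-style model state swaps the 1D positions for a
    # [3, num_tokens] tensor; values here are placeholders.
    return {"positions": torch.zeros(3, num_tokens, dtype=torch.long)}

num_tokens = 16
model_inputs = {
    "input_ids": torch.zeros(num_tokens, dtype=torch.long),
    "positions": torch.arange(num_tokens),
    # Later keys win: prepare_dummy_inputs overrides the defaults above.
    **prepare_dummy_inputs_mrope(num_tokens),
}
print(model_inputs["positions"].shape)  # torch.Size([3, 16])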

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
@@ -60,20 +59,13 @@ class InputBatch:
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits]
logits_indices: torch.Tensor
@@ -90,14 +82,16 @@ class InputBatch:
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
device = input_buffers.device
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
@@ -123,7 +117,6 @@ class InputBatch:
input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_()
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
@@ -141,12 +134,9 @@ class InputBatch:
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
dcp_local_seq_lens=None,
input_ids=input_ids,
positions=positions,
mrope_positions=None,
inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
@@ -507,6 +497,38 @@ def post_update(
)
@triton.jit
def _post_update_pool_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
query_start_loc_ptr,
):
batch_id = tl.program_id(0)
query_start = tl.load(query_start_loc_ptr + batch_id)
query_end = tl.load(query_start_loc_ptr + batch_id + 1)
query_len = query_end - query_start
req_state_idx = tl.load(idx_mapping_ptr + batch_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
def post_update_pool(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_pool_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
query_start_loc,
)
@triton.jit
def _expand_idx_mapping_kernel(
idx_mapping_ptr,
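The new _post_update_pool_kernel adds each pooling request's query length onto its slot in the persistent num_computed_tokens buffer, routed through idx_mapping (one Triton program per batch entry). A plain-PyTorch sketch of the same update for reference (illustrative helper name, assumes unique idx_mapping entries, as in the real batch):

import torch

def post_update_pool_reference(idx_mapping: torch.Tensor,
                               num_computed_tokens: torch.Tensor,
                               query_start_loc: torch.Tensor) -> None:
    # query_len per scheduled request, scattered into the per-request state
    # through idx_mapping: the same arithmetic the Triton kernel performs.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_computed_tokens[idx_mapping.long()] += query_lens

idx_mapping = torch.tensor([3, 0], dtype=torch.int32)          # batch -> req state slot
num_computed_tokens = torch.zeros(8, dtype=torch.int32)        # [max_num_reqs]
query_start_loc = torch.tensor([0, 5, 9], dtype=torch.int32)   # 2 reqs: 5 and 4 tokens
post_update_pool_reference(idx_mapping, num_computed_tokens, query_start_loc)
print(num_computed_tokens[:4])  # tensor([4, 0, 0, 5], dtype=torch.int32)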

View File

@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec
class EncoderCache:
def __init__(self):
# req_id -> MM features
self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
# MM hash -> encoder outputs
self.encoder_outputs: dict[str, torch.Tensor] = {}
def add_request(
self, req_id: str, mm_features: list[MultiModalFeatureSpec]
) -> None:
self.mm_features[req_id] = mm_features
def remove_request(self, req_id: str) -> None:
self.mm_features.pop(req_id, None)
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_outputs.clear()
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_outputs.pop(mm_hash, None)
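EncoderCache splits the state previously held inside EncoderRunner into two dicts: per-request multimodal feature specs keyed by req_id, and encoder outputs keyed by mm_hash so embeddings can be shared across requests and freed when the scheduler asks. A usage sketch under the assumption that this branch of vLLM is installed (FakeFeature is a simplified stand-in for MultiModalFeatureSpec, keeping only the field touched here):

import torch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache

class FakeFeature:
    def __init__(self, identifier: str):
        self.identifier = identifier  # the mm_hash

cache = EncoderCache()
cache.add_request("req-0", [FakeFeature("hash-a"), FakeFeature("hash-b")])

# Encoder outputs are cached per mm_hash, so requests sharing an image
# reuse the same embedding tensor.
cache.encoder_outputs["hash-a"] = torch.zeros(16, 32)

# Scheduler-driven cleanup:
cache.free_encoder_cache("hash-a")   # drop one cached embedding
cache.remove_request("req-0")        # drop per-request feature specs
cache.reset_encoder_cache()          # e.g. after a weight update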

View File

@@ -4,54 +4,32 @@ import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
model: SupportsMultiModal,
max_num_tokens: int,
hidden_size: int,
encoder_cache: EncoderCache,
dtype: torch.dtype,
device: torch.device,
):
self.model = model
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.encoder_cache = encoder_cache
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self, scheduled_encoder_inputs: dict[str, list[int]]
@@ -59,7 +37,7 @@ class EncoderRunner:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
@@ -72,25 +50,17 @@ class EncoderRunner:
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
@@ -122,7 +92,7 @@ class EncoderRunner:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
mm_features = self.encoder_cache.mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
@@ -148,7 +118,7 @@ class EncoderRunner:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
@@ -170,12 +140,11 @@ class EncoderRunner:
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
x = self.model.embed_input_ids(
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.

View File

@@ -38,15 +38,16 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec,
init_attn_backend,
@@ -56,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import (
get_cudagraph_and_dp_padding,
make_num_tokens_across_dp,
)
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
@@ -67,6 +65,7 @@ from vllm.v1.worker.gpu.input_batch import (
expand_idx_mapping,
get_num_sampled_and_rejected,
post_update,
post_update_pool,
prepare_pos_seq_lens,
prepare_prefill_inputs,
)
@@ -76,8 +75,9 @@ from vllm.v1.worker.gpu.kv_connector import (
get_kv_connector,
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import init_model_state
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -120,34 +120,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype
]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
if self.supports_mm_inputs:
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_states = MRopeState(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -169,6 +146,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
self.encoder_cache = None
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
@@ -212,7 +198,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
self.vllm_config,
self.uses_mrope,
self.use_aux_hidden_state_outputs,
self.device,
)
@@ -227,13 +212,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
@staticmethod
def get_supported_tasks() -> tuple[str]:
return ("generate",)
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.append("generate")
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
return tuple(tasks)
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
@@ -266,7 +262,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.model)
if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator)
prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model.
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
self.pooling_runner = PoolingRunner(self.model)
def get_model(self) -> nn.Module:
return self.model
@@ -305,6 +308,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
# HACK(woosuk)
self.speculator.set_attn(
self.model_state,
self.kv_cache_config,
self.attn_groups,
self.block_tables,
@@ -320,39 +324,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
max_query_len=input_batch.num_scheduled_tokens.max().item(),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
)
input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer
@torch.inference_mode()
def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
self,
num_tokens: int,
*args,
skip_attn: bool = True,
uniform_decode: bool = False,
**kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph
# compatibility across DP ranks.
query_len = self.cudagraph_manager.uniform_decode_query_len
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
num_tokens = num_reqs * query_len
num_tokens_per_request = [query_len] * num_reqs
else:
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
@@ -387,7 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -416,39 +442,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
expanded_local_pos,
)
@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
assert self.pooling_runner is not None
self.pooling_runner.dummy_pooler_run(hidden_states)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True
)
# Only run sampler on last PP rank (non-last ranks return None).
# Only run sampler/pooler on last PP rank (non-last ranks return None).
if self.is_last_pp_rank:
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
if self.pooling_runner is None:
self._dummy_sampler_run(sample_hidden_states)
else:
self._dummy_pooler_run(hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def reset_mm_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_mm_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_mm_cache()
def reset_encoder_cache(self) -> None:
if self.supports_mm_inputs:
self.encoder_runner.reset_encoder_cache()
if self.encoder_cache is not None:
self.encoder_cache.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.
@@ -477,17 +500,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
inputs_embeds = None
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture(
model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=self.block_tables,
attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config,
@@ -522,21 +538,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs:
if self.encoder_cache is not None:
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash)
self.encoder_cache.free_encoder_cache(mm_hash)
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None
assert new_req_data.sampling_params is not None
req_id = new_req_data.req_id
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request(
@@ -547,34 +562,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
req_index = self.req_states.req_id_to_index[req_id]
if self.supports_mm_inputs:
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
# Pre-compute M-RoPE positions for prefill.
if self.uses_mrope:
self.mrope_states.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features,
)
if self.encoder_cache is not None:
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if new_req_data.sampling_params is not None:
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
self.model_state.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
@@ -637,9 +645,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
@@ -648,11 +653,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
@@ -676,6 +678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_local_seq_lens = None
if self.use_dcp:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens(
@@ -686,16 +689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
@@ -711,39 +705,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=dcp_local_seq_lens,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -758,37 +719,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)
@torch.inference_mode()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
return block_tables, slot_mappings
def prepare_dummy_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
return mm_embeds, is_mm_embed
return block_tables, slot_mappings
def sample(
self,
@@ -926,6 +886,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding
)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.lora_state.make_lora_inputs(
@@ -934,35 +896,61 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and self.is_first_pp_rank:
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
input_batch.inputs_embeds = inputs_embeds[
: input_batch.num_tokens_after_padding
]
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
num_reqs, num_tokens_after_padding, self.input_buffers
)
if self.uses_mrope:
input_batch.mrope_positions = self.mrope_states.mrope_positions[
:, :num_tokens_after_padding
]
if not skip_attn_for_dummy_run:
self.prepare_dummy_attn_metadata(input_batch)
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
else:
block_tables = None
slot_mappings = None
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata = None
slot_mappings_by_layer = None
if not (dummy_run and skip_attn_for_dummy_run):
assert slot_mappings is not None
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
block_tables,
slot_mappings,
self.attn_groups,
self.kv_cache_config,
)
inputs_embeds = None
if self.supports_mm_inputs and self.is_first_pp_rank:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
# NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
# to obtain inputs_embeds, because the compiled model expects this input.
inputs_embeds = self.model_state.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs,
input_batch,
self.req_states,
)
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
@@ -979,41 +967,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
positions = input_batch.positions
if self.uses_mrope:
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
if self.is_first_pp_rank:
input_ids = input_batch.input_ids
inputs_embeds = input_batch.inputs_embeds
assert intermediate_tensors is None
else:
input_ids = None
inputs_embeds = None
assert intermediate_tensors is not None
batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None,
)
with set_forward_context(
input_batch.attn_metadata,
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings,
slot_mapping=slot_mappings_by_layer,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -1021,33 +990,44 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
)
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
) # type: ignore
return None
@torch.inference_mode()
def sample_tokens(
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None # type: ignore
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
(
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
num_tokens_across_dp,
) = self.execute_model_state
self.execute_model_state = None
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced
@@ -1109,6 +1089,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
num_sampled,
@@ -1117,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
@@ -1127,3 +1110,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.draft_tokens_handler.get_draft_tokens()
@torch.inference_mode()
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state
)
self.execute_model_state = None
if not self.is_last_pp_rank:
self.postprocess_pool(input_batch)
return None
assert self.pooling_runner is not None
pooler_output, is_valid = self.pooling_runner.pool(
hidden_states, input_batch, self.req_states
)
self.postprocess_pool(input_batch)
# Build the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
kv_connector_output=kv_connector_output,
)
async_output = AsyncPoolingOutput(
model_runner_output=model_runner_output,
pooler_output=pooler_output,
is_valid=is_valid,
main_stream=self.main_stream,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def postprocess_pool(self, input_batch: InputBatch) -> None:
# Update the number of computed tokens.
post_update_pool(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
input_batch.query_start_loc,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
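With this file's changes, execute_model stashes the full eight-element tuple (input batch, model inputs, attention metadata, slot mappings, hidden states, aux hidden states, KV-connector output, DP token counts) in self.execute_model_state, and sample_tokens() or pool() later consumes and clears it, returning None if the prior execute_model failed. A hypothetical driver-side sketch of that two-phase contract, with simplified signatures that are not taken from the diff:

# Illustrative only: the runner defers output generation in execute_model(),
# and exactly one of sample_tokens() or pool() consumes execute_model_state.
def step(runner, scheduler_output, grammar_output=None):
    intermediate = runner.execute_model(scheduler_output)
    if intermediate is not None:
        # Non-last PP rank: forward IntermediateTensors to the next rank.
        return intermediate
    if runner.is_pooling_model:
        return runner.pool()                      # AsyncPoolingOutput / ModelRunnerOutput
    return runner.sample_tokens(grammar_output)   # AsyncOutput / ModelRunnerOutput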

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
def init_model_state(
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
return DefaultModelState(vllm_config, model, encoder_cache, device)

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class DefaultModelState(ModelState):
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = vllm_config.scheduler_config
self.model = model
self.device = device
self.supports_mm_inputs = encoder_cache is not None
self.max_model_len = self.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dtype = self.model_config.dtype
if self.supports_mm_inputs:
assert encoder_cache is not None
self.encoder_cache = encoder_cache
self.encoder_runner = EncoderRunner(
model=self.model,
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
encoder_cache=encoder_cache,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_state = MRopeState(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
if self.uses_mrope:
# Pre-compute M-RoPE positions for prefill.
assert new_req_data.prefill_token_ids is not None
self.mrope_state.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features,
)
def apply_staged_writes(self) -> None:
if self.uses_mrope:
self.mrope_state.apply_staged_writes()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
if mm_kwargs:
# Execute the multimodal encoder.
encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
# Cache the encoder outputs by mm_hash
self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
req_states.prefill_len.np[input_batch.idx_mapping_np],
req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
input_batch.input_ids, mm_embeds, is_mm_embed
)
return inputs_embeds[: input_batch.num_tokens_after_padding]
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
if not self.uses_mrope:
# Common case (1D positions).
return {}
# Prepare M-RoPE positions.
self.mrope_state.prepare_mrope_positions(
input_batch.idx_mapping,
input_batch.query_start_loc,
req_states.prefill_len.gpu,
req_states.num_computed_tokens.gpu,
)
mrope_positions = self.mrope_state.mrope_positions[
:, : input_batch.num_tokens_after_padding
]
return {"positions": mrope_positions}
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
model_inputs = {}
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
model_inputs["inputs_embeds"] = inputs_embeds
if self.uses_mrope:
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
model_inputs["positions"] = mrope_positions
return model_inputs
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
)
return attn_metadata
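
The caching step in get_mm_embeddings above keys encoder outputs by mm_hash, so identical multimodal inputs seen again (within or across requests) skip the encoder entirely. A minimal sketch of that pattern, with a made-up hash and a stand-in "encoder" rather than the real EncoderCache/EncoderRunner API:

import torch

encoder_outputs: dict[str, torch.Tensor] = {}  # stand-in for EncoderCache.encoder_outputs

def fake_encoder(pixels: torch.Tensor) -> torch.Tensor:
    # Pretend encoder: mean-pool the pixels into one embedding row.
    return pixels.mean(dim=0, keepdim=True)

def get_embedding(mm_hash: str, pixels: torch.Tensor) -> torch.Tensor:
    # Only run the encoder for hashes that are not cached yet.
    if mm_hash not in encoder_outputs:
        encoder_outputs[mm_hash] = fake_encoder(pixels)
    return encoder_outputs[mm_hash]

pixels = torch.randn(16, 8)
first = get_embedding("img-abc", pixels)   # encoder executes
second = get_embedding("img-abc", pixels)  # served from the cache
assert first is second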

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState(ABC):
@abstractmethod
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
raise NotImplementedError
@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError
@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError
@abstractmethod
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
raise NotImplementedError

View File

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState
# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model)
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
if not is_pooling_model(self.model):
return []
assert "embed" in self.model.pooler.get_supported_tasks()
return ["embed"]
def pool(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
req_states: RequestState,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# TODO(woosuk): Support different types of pooling tasks.
last_hidden_states = hidden_states[input_batch.logits_indices]
# TODO(woosuk): Make normalization optional.
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
is_valid = input_batch.seq_lens == prompt_len
return last_hidden_states, is_valid
def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
F.normalize(hidden_states, p=2, dim=-1)
return
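
A self-contained sketch of what pool() above computes, with synthetic shapes in place of the real InputBatch/RequestState objects: gather the hidden state at each request's last scheduled token, L2-normalize it, and mark the embedding valid only when the request's full prompt has been processed.

import torch
import torch.nn.functional as F

# Two requests, 5 scheduled tokens in this step, hidden size 4 (all synthetic).
hidden_states = torch.randn(5, 4)
logits_indices = torch.tensor([2, 4])  # last token position of each request
seq_lens = torch.tensor([3, 2])        # tokens processed so far per request
prompt_len = torch.tensor([3, 6])      # full prompt length per request

last_hidden = hidden_states[logits_indices]         # [2, 4]
embeddings = F.normalize(last_hidden, p=2, dim=-1)  # unit-norm "embed" outputs
is_valid = seq_lens == prompt_len                   # only request 0 finished its prompt
print(embeddings.shape, is_valid)  # torch.Size([2, 4]) tensor([ True, False])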

View File

@@ -72,7 +72,7 @@ class BadWordsState:
def apply_bad_words(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -84,7 +84,7 @@ class BadWordsState:
apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
@@ -114,17 +114,17 @@ def _bad_words_kernel(
input_ids_ptr,
expanded_local_pos_ptr,
):
logit_idx = tl.program_id(0)
token_idx = tl.program_id(0)
bw_idx = tl.program_id(1)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
if bw_idx >= num_bad_words:
return
pos = tl.load(expanded_local_pos_ptr + logit_idx)
cur_req_first_pos = logit_idx - pos
pos = tl.load(expanded_local_pos_ptr + token_idx)
cur_req_first_pos = token_idx - pos
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx)
@@ -159,7 +159,7 @@ def _bad_words_kernel(
match = match & (expected == actual)
if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))
def apply_bad_words(
@@ -175,8 +175,8 @@ def apply_bad_words(
expanded_local_pos: torch.Tensor,
max_num_bad_words: int,
) -> None:
total_num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
num_tokens = logits.shape[0]
_bad_words_kernel[(num_tokens, max_num_bad_words)](
logits,
logits.stride(0),
expanded_idx_mapping,

View File

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
@@ -25,24 +25,24 @@ def _temperature_kernel(
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
_temperature_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
local_max_stride,
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + batch_idx * logits_stride + block,
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32.
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs]
logits: torch.Tensor, # [num_tokens, vocab_size]
expanded_idx_mapping: torch.Tensor, # [num_tokens]
temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_tokens,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs, num_blocks)](
_gumbel_sample_kernel[(num_tokens, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
seed,
pos,
temperature,

View File

@@ -121,7 +121,7 @@ class LogitBiasState:
def apply_logit_bias(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
@@ -131,7 +131,7 @@ class LogitBiasState:
apply_logit_bias(
logits,
idx_mapping,
expanded_idx_mapping,
pos,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr,
logits_stride,
vocab_size,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
# Allowed token IDs.
num_allowed_token_ids_ptr,
allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block = tl.arange(0, BLOCK_SIZE)
@@ -186,21 +186,21 @@ def _bias_kernel(
mask=mask,
)
logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
)
# Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store(
logits_ptr + batch_idx * logits_stride + offset,
logits_ptr + token_idx * logits_stride + offset,
-float("inf"),
mask=offset < vocab_size,
)
# Restore logits for allowed token IDs.
tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
logits_ptr + token_idx * logits_stride + allowed_token_ids,
logits,
mask=mask,
)
@@ -214,13 +214,13 @@ def _bias_kernel(
mask=mask,
)
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx)
pos = tl.load(pos_ptr + token_idx)
min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
mask=mask,
)
tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids,
logits_ptr + token_idx * logits_stride + stop_token_ids,
-float("inf"),
mask=mask,
)
@@ -237,7 +237,7 @@ def _bias_kernel(
def apply_logit_bias(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
_bias_kernel[(num_tokens,)](
logits,
logits.stride(0),
vocab_size,
idx_mapping,
expanded_idx_mapping,
num_allowed_token_ids,
allowed_token_ids,
allowed_token_ids.stride(0),

View File

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
min_p_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0:
return
@@ -25,7 +25,9 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
@@ -35,21 +37,23 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
num_reqs, vocab_size = logits.shape
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)](
_min_p_kernel[(num_tokens,)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
min_p,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,

View File

@@ -82,7 +82,7 @@ class PenaltiesState:
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:
apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
expanded_idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(
def apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel[(num_tokens, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
expanded_idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
expanded_idx_mapping_ptr,
all_token_ids_ptr,
all_token_ids_stride,
prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
token_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(
def bincount(
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@ def bincount(
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
prompt_bin_mask[expanded_idx_mapping] = 0
output_bin_counts[expanded_idx_mapping] = 0
num_tokens = expanded_idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
_bincount_kernel[(num_tokens, num_blocks)](
expanded_idx_mapping,
all_token_ids,
all_token_ids.stride(0),
prompt_len,

View File

@@ -56,7 +56,7 @@ class Sampler:
def __call__(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
@@ -68,7 +68,7 @@ class Sampler:
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
@@ -101,7 +101,7 @@ class Sampler:
def sample(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
@@ -111,12 +111,14 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
self.logit_bias_state.apply_logit_bias(
logits, expanded_idx_mapping, idx_mapping_np, pos
)
# Apply penalties in place.
self.penalties_state.apply_penalties(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
@@ -126,27 +128,29 @@ class Sampler:
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
expanded_idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_temperature(
logits, expanded_idx_mapping, idx_mapping_np
)
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
logits, expanded_idx_mapping, idx_mapping_np
)
# Sample the next token.
sampled = gumbel_sample(
logits,
idx_mapping,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,

View File

@@ -64,7 +64,7 @@ class SamplingStates:
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
@@ -72,23 +72,23 @@ class SamplingStates:
# No request requires temperature. Skip the kernel launch.
return
apply_temperature(logits, idx_mapping, self.temperature.gpu)
apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)
def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
@@ -96,8 +96,8 @@ class SamplingStates:
if not (do_top_k or do_top_p):
return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:

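The renames running through these sampling kernels (idx_mapping to expanded_idx_mapping, batch_idx/num_reqs to token_idx/num_tokens) all point at the same change: the kernels are now launched with one program per logit row rather than one per request, and the mapping carries one entry per row pointing back at the owning request's slot in the persistent state tensors. A toy illustration with synthetic slot numbers; the per-request row counts are only an assumption about how the expansion is built (e.g. for speculative verification), not something the diff states:

import torch

# Two requests occupying persistent slots 3 and 7. Suppose they contribute
# 1 and 3 logit rows respectively in this step (4 rows total).
idx_mapping = torch.tensor([3, 7])            # one entry per request
rows_per_req = torch.tensor([1, 3])
expanded_idx_mapping = torch.repeat_interleave(idx_mapping, rows_per_req)
print(expanded_idx_mapping)  # tensor([3, 7, 7, 7]) -> one entry per logit row

# Per-token parameters are then gathered through the expanded mapping:
temperature = torch.ones(16)  # indexed by request-state slot
temperature[3] = 0.7
print(temperature[expanded_idx_mapping])  # tensor([0.7000, 1.0000, 1.0000, 1.0000])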
View File

@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
@@ -54,11 +55,32 @@ class EagleCudaGraphManager:
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
@@ -76,12 +98,11 @@ class EagleCudaGraphManager:
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -158,6 +179,7 @@ class EagleCudaGraphManager:
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
@@ -173,6 +195,7 @@ class EagleCudaGraphManager:
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,

View File

@@ -16,7 +16,9 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
@@ -44,10 +46,13 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.dtype = vllm_config.model_config.dtype
# DP configuration
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
@@ -77,10 +82,12 @@ class EagleSpeculator:
def set_attn(
self,
model_state: ModelState,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables,
) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups
self.block_tables = block_tables
@@ -120,8 +127,8 @@ class EagleSpeculator:
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
@@ -162,9 +169,10 @@ class EagleSpeculator:
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
@@ -172,6 +180,7 @@ class EagleSpeculator:
logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture(
self.generate_draft,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
@@ -182,6 +191,8 @@ class EagleSpeculator:
def propose(
self,
input_batch: InputBatch,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
@@ -198,6 +209,9 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -229,9 +243,9 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model(
num_tokens,
input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME
attn_metadata,
slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -277,48 +291,64 @@ class EagleSpeculator:
self.max_model_len,
self.max_num_reqs,
)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
num_reqs,
cudagraph_size,
cudagraph_mode.value,
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# FIXME(woosuk): This is UNSAFE!!
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]
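
The propose() path above now asks the cudagraph manager for a runtime mode and size instead of assuming FULL whenever a size is registered: a FULL replay is only used if a graph of that size was actually captured, otherwise the call falls back to eager (mode NONE, no padding). A toy sketch of that lookup, with stand-in constants rather than the real CUDAGraphMode enum:

NONE, FULL = 0, 2                           # stand-ins for CUDAGraphMode values
cudagraph_sizes = {1: 1, 2: 2, 4: 4, 8: 8}  # hypothetical token-count -> graph size
captured = {1, 2, 4}                        # size 8 has not been captured yet

def get_runtime_mode(num_tokens: int, configured_mode: int = FULL):
    size = cudagraph_sizes.get(num_tokens)
    if size is None:
        return NONE, None
    mode = configured_mode
    if mode == FULL and size not in captured:
        # Graph not captured yet (e.g. a dummy run before capture): go eager.
        return NONE, None
    return mode, size

print(get_runtime_mode(4))  # (2, 4)     -> replay the captured FULL graph
print(get_runtime_mode(8))  # (0, None)  -> eager fallback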

View File

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
"""Run two execute_model + sample_tokens iterations to JIT compile
triton kernels.
The first iteration simulates a prefill with requests of 2 prompt
tokens each. The second iteration simulates a decode step with all
requests generating 1 token each.
"""
prompt_token_ids = [0, 1]
prompt_len = len(prompt_token_ids)
num_reqs = min(
model_runner.scheduler_config.max_num_seqs,
model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
)
num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]
# SamplingParams exercising all sampling features.
if model_runner.is_pooling_model:
sampling_params = None
pooling_params = PoolingParams()
else:
sampling_params = SamplingParams.for_sampler_warmup()
pooling_params = None
# Step 1: Prefill all requests with 2 prompt tokens each.
new_reqs = [
NewRequestData.from_request(
Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
# Each request uses a distinct block per KV cache group.
block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
prefill_token_ids=prompt_token_ids,
)
for i in range(num_reqs)
]
prefill_output = SchedulerOutput.make_empty()
prefill_output.scheduled_new_reqs = new_reqs
prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
# Disable KV connector for warmup run.
model_runner.kv_connector.set_disabled(True)
model_runner.execute_model(prefill_output)
if not model_runner.is_pooling_model:
# Warm up sampler and perform a decode step for non-pooling models.
grammar_output = None
if model_runner.is_last_pp_rank:
# Build a GrammarOutput to exercise the structured output bitmask
# kernel during the prefill step.
vocab_size = model_runner.model_config.get_vocab_size()
bitmask_width = (vocab_size + 31) // 32
grammar_bitmask = np.full(
(len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
)
grammar_output = GrammarOutput(
structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
)
model_runner.sample_tokens(grammar_output)
# Step 2: Decode all requests with 1 token each.
cached_req_data = CachedRequestData.make_empty()
cached_req_data.req_ids = list(req_ids)
cached_req_data.new_block_ids = [None] * num_reqs
cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
cached_req_data.num_output_tokens = [1] * num_reqs
decode_output = SchedulerOutput.make_empty()
decode_output.scheduled_cached_reqs = cached_req_data
decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
decode_output.total_num_scheduled_tokens = num_reqs
decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups
model_runner.execute_model(decode_output)
model_runner.sample_tokens(None)
# Clean up - process finish_req_ids.
cleanup_output = SchedulerOutput.make_empty()
cleanup_output.finished_req_ids = set(req_ids)
model_runner.execute_model(cleanup_output)
model_runner.kv_connector.set_disabled(False)
torch.cuda.synchronize()
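
A short worked example of the grammar bitmask arithmetic used in the warmup above: with 32 token bits per int32 word, a vocabulary of 32000 needs (32000 + 31) // 32 = 1000 words per request, and filling the mask with -1 (all bits set) allows every token, which is exactly what the warmup wants. The LSB-first bit layout in the helper below is an assumption about the bitmask convention, not something this diff states.

import numpy as np

vocab_size = 32_000                      # hypothetical vocabulary size
bitmask_width = (vocab_size + 31) // 32  # -> 1000 int32 words per request
row = np.full(bitmask_width, -1, dtype=np.int32)  # -1 == all 32 bits set

def token_allowed(bitmask_row: np.ndarray, token_id: int) -> bool:
    # Assumed layout: bit (token_id % 32) of word (token_id // 32).
    word = int(bitmask_row[token_id // 32])
    return bool((word >> (token_id % 32)) & 1)

print(bitmask_width)              # 1000
print(token_allowed(row, 31999))  # True: the all-ones warmup mask allows everything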

View File

@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -211,6 +220,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
@@ -425,6 +474,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@@ -623,6 +679,24 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[
[i2, i1], ...
]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[
[i2, i1], ...
]
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
@@ -745,6 +819,21 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
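
The swap and condense paths above move whole rows of the cached eagle buffers with a single advanced-indexing assignment. A quick check (toy tensor, not the real buffers) that this idiom is safe: PyTorch materializes the right-hand-side read before performing the write, so the two rows swap without an explicit temporary.

import torch

cached_token_ids = torch.arange(12).reshape(4, 3)  # 4 request slots, toy payload
i1, i2 = 0, 2
before = cached_token_ids.clone()
cached_token_ids[[i1, i2], ...] = cached_token_ids[[i2, i1], ...]
assert torch.equal(cached_token_ids[i1], before[i2])
assert torch.equal(cached_token_ids[i2], before[i1])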

File diff suppressed because it is too large

View File

@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
@@ -32,14 +31,13 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
get_world_group
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
@@ -49,7 +47,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
@@ -61,6 +58,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
from ...model_executor.model_loader import TensorizerLoader
from .gpu.warmup import warmup_kernels
from .utils import request_memory
logger = init_logger(__name__)
@@ -123,6 +122,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
@@ -316,12 +319,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -421,9 +441,10 @@ class Worker(WorkerBase):
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
# tp_rank = get_tp_group().rank_in_group
global_rank = get_world_group().rank_in_group
return {global_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@@ -461,8 +482,16 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
@@ -558,12 +587,15 @@ class Worker(WorkerBase):
logger.debug(msg)
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner)
elif get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
@@ -584,6 +616,8 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
@@ -696,6 +730,12 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if (
self.use_v2_model_runner
and self.model_runner.is_pooling_model
and output is None
):
output = self.model_runner.pool() # type: ignore
if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
):
@@ -744,7 +784,8 @@ class Worker(WorkerBase):
# Create the profiler wrapper only on the first start call
if self.profiler is None:
if self.profiler_config.profiler == "torch":
profiler_type = self.profiler_config.profiler
if profiler_type == "torch":
self.profiler = TorchProfilerWrapper(
self.profiler_config,
worker_name=trace_name,
@@ -754,14 +795,18 @@ class Worker(WorkerBase):
logger.debug(
"Starting torch profiler with trace name: %s", trace_name
)
elif self.profiler_config.profiler == "cuda":
elif profiler_type == "cuda":
self.profiler = CudaProfilerWrapper(self.profiler_config)
logger.debug("Starting CUDA profiler")
self.profiler.start()
else:
# Profiler already initialized. Restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
# Config validation should prevent this code being reached
raise ValueError(
f"Invalid profiler value of {self.profiler_config.profiler}"
)
# If profiler already initialized, restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
if self.profiler is None:
logger.warning("Profiler was not started, nothing to stop.")
@@ -787,227 +832,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state(
self,
path: str,
@@ -1023,12 +847,11 @@ class Worker(WorkerBase):
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
self.model_runner.save_tensorized_model(
def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
TensorizerLoader.save_model(
self.get_model(),
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
def init_weight_transfer_engine(self, init_info: dict) -> None:
@@ -1104,6 +927,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
def init_worker_distributed_environment(
vllm_config: VllmConfig,

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any
import torch
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return mamba_group_ids, mamba_specs[0]
@dataclasses.dataclass
class MambaCopyBuffers:
src_ptrs: CpuGpuBuffer
dst_ptrs: CpuGpuBuffer
sizes: CpuGpuBuffer
offset: int = 0
@classmethod
def create(
cls,
max_num_reqs: int,
kv_cache_config: KVCacheConfig,
copy_funcs: tuple[MambaStateCopyFunc, ...],
make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaCopyBuffers":
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
entries_per_req = sum(
len(kv_cache_config.kv_cache_groups[gid].layer_names)
for gid in mamba_group_ids
) * len(copy_funcs)
n = max_num_reqs * entries_per_req
return cls(
src_ptrs=make_buffer(n, dtype=torch.int64),
dst_ptrs=make_buffer(n, dtype=torch.int64),
sizes=make_buffer(n, dtype=torch.int32),
)
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
copy_bufs: MambaCopyBuffers,
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
) -> None:
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
src_ptrs_np = copy_bufs.src_ptrs.np
dst_ptrs_np = copy_bufs.dst_ptrs.np
sizes_np = copy_bufs.sizes.np
offset = copy_bufs.offset
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
src_ptrs_np[offset] = copy_spec.start_addr
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
sizes_np[offset] = copy_spec.num_elements * state.element_size()
offset += 1
copy_bufs.offset = offset
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
n = copy_bufs.offset
if n == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
batch_memcpy(
copy_bufs.src_ptrs.copy_to_gpu(n),
copy_bufs.dst_ptrs.copy_to_gpu(n),
copy_bufs.sizes.copy_to_gpu(n),
)
def preprocess_mamba(
@@ -117,6 +149,7 @@ def preprocess_mamba(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
Copy the mamba state of previous step to the last
@@ -138,9 +171,7 @@ def preprocess_mamba(
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
@@ -169,9 +200,7 @@ def preprocess_mamba(
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -182,7 +211,7 @@ def preprocess_mamba(
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)
def postprocess_mamba(
@@ -193,6 +222,7 @@ def postprocess_mamba(
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
If a block is converted from a partial block to a full block in this step, copy the
@@ -203,9 +233,7 @@ def postprocess_mamba(
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
@@ -225,9 +253,7 @@ def postprocess_mamba(
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -239,4 +265,4 @@ def postprocess_mamba(
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)

View File

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any
import torch
@@ -12,13 +15,208 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadataBuilder,
MultipleOf,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
UniformTypeKVCacheSpecs,
)
logger = init_logger(__name__)
@triton.jit
def _zero_kv_blocks_kernel(
seg_addrs_ptr,
block_ids_ptr,
n_blocks,
N_SEGS: tl.constexpr,
PAGE_SIZE_EL: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Zero KV cache blocks across all segments in a single launch.
Each segment is a contiguous region of block data within one CUDA
allocation. For backends where blocks are outermost (block_dim=0) there
is one segment per buffer. For backends where K/V is outermost
(block_dim=1) there are two segments per buffer (one for K, one for V).
seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
allowing segments to live in different CUDA allocations.
Programs are mapped as (block_index, seg_index, chunk_index).
"""
pid = tl.program_id(0)
chunks = PAGE_SIZE_EL // BLOCK_SIZE
work_per_block = N_SEGS * chunks
block_index = pid // work_per_block
if block_index >= n_blocks:
return
remainder = pid % work_per_block
seg_index = remainder // chunks
chunk_index = remainder % chunks
block_id = tl.load(block_ids_ptr + block_index)
seg_addr = tl.load(seg_addrs_ptr + seg_index)
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
offset = (
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
)
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
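To make the program mapping concrete, a small worked example with hypothetical sizes, following the same decomposition as the kernel above:
PAGE_SIZE_EL, BLOCK_SIZE, N_SEGS, n_blocks = 8192, 1024, 2, 3  # hypothetical sizes
chunks = PAGE_SIZE_EL // BLOCK_SIZE           # 8 chunks per segment per block
work_per_block = N_SEGS * chunks              # 16 programs per block
grid = (n_blocks * work_per_block,)           # (48,) programs in total
pid = 21
block_index = pid // work_per_block           # 1 -> block id read from block_ids_ptr[1]
remainder = pid % work_per_block              # 5
seg_index = remainder // chunks               # 0 -> first segment (e.g. K)
chunk_index = remainder % chunks              # 5
assert (block_index, seg_index, chunk_index) == (1, 0, 5)
# This program zeroes BLOCK_SIZE int32 elements starting at
# block_id * PAGE_SIZE_EL + chunk_index * BLOCK_SIZE inside segment 0,
# where block_id is loaded from block_ids_ptr[block_index].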
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""
def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
def init_meta(
self,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
) -> None:
"""One-time precomputation for zero_block_ids.
Builds the absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.
Block IDs from the scheduler reference logical blocks whose size
may differ from the kernel block size (virtual block splitting).
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
Only FullAttentionSpec layers are processed; other specs (e.g. Mamba) are skipped.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None
for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
block_dim = group.backend.get_kv_cache_block_dim(
kernel_bs,
spec.num_kv_heads,
spec.head_size,
cache_dtype_str=cache_dtype,
)
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
continue
seen_ptrs.add(dp)
el = kv.element_size()
cur_bytes = kv.stride(block_dim) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
if page_size_el is None:
page_size_el = cur_page_el
else:
assert page_size_el == cur_page_el, (
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)
block_stride_bytes = cur_bytes
outer_dims = [
d
for d in range(block_dim)
if kv.stride(d) * el > block_stride_bytes
]
outer_strides = [kv.stride(d) * el for d in outer_dims]
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
seg_addrs.append(dp + off_bytes)
if not seg_addrs or page_size_el is None:
self._meta = None
return
blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
self._id_cap = 8192
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
self._meta = (
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
page_size_el,
blk_size,
len(seg_addrs),
)
def zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if not block_ids or self._meta is None:
return
seg_addrs, page_size_el, blk_size, n_segs = self._meta
n_blocks = len(block_ids)
if n_blocks > self._id_cap:
self._id_cap = n_blocks * 2
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(
self._id_cap, dtype=torch.int64, device=self.device
)
assert self._ids_pinned is not None and self._ids_gpu is not None
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
idx = self._ids_gpu[:n_blocks]
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
_zero_kv_blocks_kernel[grid](
seg_addrs,
idx,
n_blocks,
N_SEGS=n_segs,
PAGE_SIZE_EL=page_size_el,
BLOCK_SIZE=blk_size,
)
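A rough sketch of the intended call pattern; the runner-side attribute names below are illustrative assumptions, not the actual model-runner code:
# During model-runner initialization, after KV caches are allocated and bound:
#   zeroer = KVBlockZeroer(device=self.device, pin_memory=self.pin_memory)
#   zeroer.init_meta(
#       attn_groups_iter=self._attn_group_iterator(),      # illustrative accessor
#       kernel_block_sizes=self.kernel_block_sizes,
#       cache_dtype=self.cache_config.cache_dtype,
#       runner_only_attn_layers=self.runner_only_attn_layers,
#       static_forward_context=self.compilation_config.static_forward_context,
#   )
#
# Each scheduling step, before the forward pass:
#   zeroer.zero_block_ids(newly_allocated_block_ids)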
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
@@ -36,7 +234,7 @@ class AttentionGroup:
self,
vllm_config,
device,
kernel_block_size: int | None,
kernel_block_size: int | None = None,
num_metadata_builders: int = 1,
):
kv_cache_spec_builder = (
@@ -59,6 +257,119 @@ class AttentionGroup:
return self.metadata_builders[ubatch_id]
def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.
If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the largest size that all backends support and that divides kv_manager_block_size.
Args:
kv_manager_block_size: Block size of KV cache.
attn_groups: List of attention groups.
Returns:
The selected block size.
Raises:
ValueError: If no valid block size found.
"""
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""Check if the block size is supported by all backends."""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True
backends = [group.backend for group in attn_groups]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly.
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size
# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If a supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b is a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i, so we
# would have returned kv_manager_block_size already in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
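A small worked example of the selection order above, using toy backends and hypothetical numbers (MultipleOfToy is a stand-in for the MultipleOf hint, and supported() is a simplified re-implementation for illustration only):
from dataclasses import dataclass
@dataclass
class MultipleOfToy:
    base: int
BACKEND_A = [MultipleOfToy(16)]   # accepts any multiple of 16
BACKEND_B = [64, 256]             # accepts exactly 64 or 256
def supported(sizes, block_size: int) -> bool:
    return any(
        block_size == s if isinstance(s, int) else block_size % s.base == 0
        for s in sizes
    )
kv_manager_block_size = 128
# Case 1 fails: 128 is fine for backend A but not an int-format size of B.
assert supported(BACKEND_A, 128) and not supported(BACKEND_B, 128)
# Case 2: int-format candidates {256, 64} in descending order.
# 256 does not divide 128 -> skipped; 64 divides 128 and both backends
# accept it, so 64 would be returned.
assert 128 % 256 != 0
assert 128 % 64 == 0 and supported(BACKEND_A, 64) and supported(BACKEND_B, 64)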
def prepare_kernel_block_sizes(
kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
) -> list[int]:
"""
Generate the kernel_block_sizes that match each group's block_size.
For attention backends that support virtual block splitting,
use the supported block sizes from the backend.
For other backends (like Mamba), use the same block size (no splitting).
Args:
kv_cache_config: The KV cache configuration.
attn_groups: Attention groups indexed by KV cache group id.
Returns:
List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# pick an arbitrary one to dispatch.
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
if isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual block splitting.
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = select_common_block_size(
kv_manager_block_size, attn_groups[kv_cache_gid]
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache, no splitting.
kernel_block_sizes.append(kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
)
return kernel_block_sizes
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
@@ -201,6 +512,55 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def bind_kv_cache_scale(
kv_caches_scale: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches_scale: list[torch.Tensor],
num_attn_module: int | None = 1,
) -> None:
"""
Bind the allocated KV cache scales to both ModelRunner and forward context
so that the scales can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache scale list (`runner_kv_caches_scale`)
with kv_caches_scale.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache scale in kv_caches_scale.
Args:
kv_caches_scale: The allocated kv cache scales with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches_scale: The kv cache scale list declared by ModelRunner.
"""
# Bind kv_caches_scale to ModelRunner
assert len(runner_kv_caches_scale) == 0
# Convert the kv_caches_scale dict to a list of tensors ordered by layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches_scale:
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is an encoder-decoder model, e.g., BART.
# The cross-attention and self-attention in the same decoder layer
# have different layer_names but the same layer_index.
if current_platform.is_cuda() or current_platform.is_xpu():
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches_scale.append(kv_caches_scale[layer_name])
# Bind kv_caches_scale to forward context
for layer_name, kv_cache_scale in kv_caches_scale.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache_scale = [kv_cache_scale]
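A hedged sketch of the expected call site; the attribute path used to obtain the forward context is an assumption, not taken from this diff:
# runner_kv_caches_scale: list[torch.Tensor] = []
# bind_kv_cache_scale(
#     kv_caches_scale,                                   # {layer_name: scale tensor}
#     self.compilation_config.static_forward_context,    # illustrative accessor
#     runner_kv_caches_scale,
# )
# Afterwards every Attention layer exposes layer.kv_cache_scale[0], and
# runner_kv_caches_scale holds one scale tensor per layer_index.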
def is_residual_scattered_for_sp(

View File

@@ -87,8 +87,12 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
def compile_or_warm_up_model(self) -> float:
"""Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time in seconds.
"""
raise NotImplementedError
def check_health(self) -> None:
@@ -213,13 +217,8 @@ class WorkerWrapperBase:
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
# if self.rpc_rank in rank_mapping:
# self.rpc_rank = rank_mapping[self.rpc_rank]
old_rank = self.rpc_rank
if old_rank in rank_mapping:
self.rpc_rank = rank_mapping[old_rank]
if self.global_rank == old_rank:
self.global_rank = rank_mapping[old_rank]
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(
self,

View File

@@ -66,6 +66,23 @@ class WorkspaceManager:
],
)
def unlock(self) -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
"""
self._locked = False
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock()
def unlock_workspace() -> None:
"""Unlock the workspace to allow growth.
This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
After scaling operations complete, lock_workspace() should be
called again to prevent unexpected allocations.
"""
current_workspace_manager().unlock()
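A minimal sketch of the unlock/lock pattern around elastic EP scaling, assuming these module-level helpers are used as shown; rebalance_experts is a hypothetical placeholder for whatever step reallocates fused-MoE workspaces after the expert count changes:
def scale_experts_with_workspace_growth(rebalance_experts) -> None:
    unlock_workspace()           # allow workspace buffers to grow during scaling
    try:
        rebalance_experts()      # may request larger workspaces
    finally:
        lock_workspace()         # freeze sizes again for steady-state serving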
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.