Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu

@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)

@@ -85,7 +85,7 @@ class CPUWorker(Worker):
self.local_omp_cpuid = omp_cpuids_list[self.rank]

if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)

@@ -118,11 +118,12 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0

def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
return self.compilation_config.compilation_time

def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]

@@ -37,7 +37,6 @@ def _get_device_and_group(parallel_config: ParallelConfig):

def _run_ar(
should_ubatch: bool,
should_dp_pad: bool,
orig_num_tokens_per_ubatch: int,
padded_num_tokens_per_ubatch: int,
cudagraph_mode: int,
@@ -46,12 +45,11 @@ def _run_ar(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
tensor[4][dp_rank] = cudagraph_mode
tensor[3][dp_rank] = cudagraph_mode
dist.all_reduce(tensor, group=group)
return tensor
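
For reference, a minimal sketch of the 4-row layout used after this change (the helper name pack_dp_row is illustrative, not part of the diff); summing the columns with all_reduce leaves every rank holding every rank's values:

import torch

def pack_dp_row(dp_size: int, dp_rank: int, orig_tokens: int, padded_tokens: int,
                want_ubatch: bool, cudagraph_mode: int) -> torch.Tensor:
    # row 0: unpadded tokens per ubatch, row 1: padded tokens per ubatch,
    # row 2: 1 if this rank wants microbatching, row 3: this rank's cudagraph mode
    t = torch.zeros(4, dp_size, dtype=torch.int32)
    t[0, dp_rank] = orig_tokens
    t[1, dp_rank] = padded_tokens
    t[2, dp_rank] = 1 if want_ubatch else 0
    t[3, dp_rank] = cudagraph_mode
    # dist.all_reduce(t, group=group) then sums across ranks, so column r ends up
    # holding rank r's values on every rank.
    return t
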
@@ -97,14 +95,13 @@ def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
If any rank has NONE (0), all ranks use NONE.
This ensures all ranks send consistent values (all padded or all unpadded).
"""
return int(tensor[4, :].min().item())
return int(tensor[3, :].min().item())
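
A quick worked example of the min semantics described above (hypothetical values; NONE=0, PIECEWISE=1, FULL=2):

modes = torch.tensor([2, 1, 2, 2], dtype=torch.int32)  # one entry per DP rank
synced = int(modes.min().item())  # -> 1: a single PIECEWISE rank downgrades everyone
# If any rank reported 0 (NONE), the min would be 0 and all ranks would run eager.
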
def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
should_attempt_ubatching: bool,
should_attempt_dp_padding: bool,
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> tuple[bool, torch.Tensor | None, int]:
@@ -113,8 +110,8 @@ def _synchronize_dp_ranks(
run with microbatching or none of them do.

2. Determines the total number of tokens that each rank will run.
When running microbatched or if should_attempt_dp_padding is True, all
ranks will be padded out so that the run with the same number of tokens
When running microbatched or if cudagraph is enabled (synced across ranks),
all ranks will be padded out so that they run with the same number of tokens.

3. Synchronizes cudagraph_mode across ranks by taking the minimum.

@@ -133,29 +130,26 @@ def _synchronize_dp_ranks(
# will run and if we are using ubatching or not.
tensor = _run_ar(
should_ubatch=should_attempt_ubatching,
should_dp_pad=should_attempt_dp_padding,
orig_num_tokens_per_ubatch=num_tokens_unpadded,
padded_num_tokens_per_ubatch=num_tokens_padded,
cudagraph_mode=cudagraph_mode,
parallel_config=parallel_config,
)

should_dp_pad = bool(torch.all(tensor[3] == 1).item())

# DP ranks should all have the same value for should_attempt_dp_padding.
assert should_attempt_dp_padding == should_dp_pad
# Synchronize cudagraph_mode across ranks first (take min).
# This is needed before DP padding decision since we use the synced
# cudagraph mode to determine whether DP padding is needed.
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)

# Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

if should_ubatch and not should_dp_pad:
logger.debug_once(
"Microbatching has been triggered and requires DP padding. "
"Enabling DP padding even though it has been explicitly "
"disabled.",
scope="global",
)
should_dp_pad = True
# DP padding is needed when cudagraph is enabled (synced across ranks)
# or when ubatching/DBO is active (ubatching requires uniform batch
# sizes across DP ranks currently).
# Use the synced runtime cudagraph mode rather than the compilation config
# so we can avoid padding when cudagraph is not enabled for this step.
should_dp_pad = synced_cudagraph_mode != 0 or should_ubatch
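
A sketch of the padding step this decision feeds into, assuming row 1 of the reduced tensor holds each rank's padded token count as packed in _run_ar (the helper name is illustrative):

def pad_to_max_across_ranks(padded_counts: torch.Tensor) -> torch.Tensor:
    # e.g. [96, 128, 64, 128] -> [128, 128, 128, 128]
    return torch.full_like(padded_counts, int(padded_counts.max().item()))
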
|
||||
|
||||
# Pad all DP ranks up to the maximum token count across ranks if
|
||||
# should_dp_pad is True
|
||||
@@ -164,16 +158,12 @@ def _synchronize_dp_ranks(
|
||||
should_dp_pad,
|
||||
)
|
||||
|
||||
# Synchronize cudagraph_mode across ranks (take min)
|
||||
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
|
||||
|
||||
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
|
||||
|
||||
|
||||
def coordinate_batch_across_dp(
|
||||
num_tokens_unpadded: int,
|
||||
allow_microbatching: bool,
|
||||
allow_dp_padding: bool,
|
||||
parallel_config: ParallelConfig,
|
||||
num_tokens_padded: int | None = None,
|
||||
uniform_decode: bool | None = None,
|
||||
@@ -187,7 +177,6 @@ def coordinate_batch_across_dp(
|
||||
Args:
|
||||
num_tokens_unpadded: Number of tokens without accounting for padding
|
||||
allow_microbatching: If microbatching should be attempted
|
||||
allow_dp_padding: If all DP ranks should be padded up to the same value
|
||||
parallel_config: The parallel config
|
||||
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
|
||||
TP, etc)
|
||||
@@ -195,15 +184,15 @@ def coordinate_batch_across_dp(
|
||||
only contains single token decodes
|
||||
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
|
||||
number of tokens per request.
|
||||
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
|
||||
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL).
|
||||
DP padding is enabled when synced cudagraph mode across ranks is not NONE.
|
||||
|
||||
Returns: tuple[
|
||||
ubatch_slices: if this is set then all DP ranks have agreed to
|
||||
microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
padded up to the max value across all DP ranks when allow_dp_padding
|
||||
is True.
|
||||
padded up to the max value across all DP ranks when cudagraph is enabled.
|
||||
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
|
||||
]
|
||||
|
||||
@@ -231,7 +220,6 @@ def coordinate_batch_across_dp(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
allow_dp_padding,
|
||||
cudagraph_mode,
|
||||
parallel_config,
|
||||
)
|
||||
|
||||
@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
|
||||
return self.model_runner_output
|
||||
|
||||
|
||||
class AsyncPoolingOutput(AsyncModelRunnerOutput):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
pooler_output: torch.Tensor,
|
||||
is_valid: torch.Tensor | None,
|
||||
main_stream: torch.cuda.Stream,
|
||||
copy_stream: torch.cuda.Stream,
|
||||
copy_event: torch.cuda.Event,
|
||||
):
|
||||
self.model_runner_output = model_runner_output
|
||||
self.pooler_output = pooler_output
|
||||
self.is_valid = is_valid
|
||||
self.copy_event = copy_event
|
||||
|
||||
with stream(copy_stream, main_stream):
|
||||
copy_stream.wait_stream(main_stream)
|
||||
self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
|
||||
if self.is_valid is not None:
|
||||
self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
|
||||
else:
|
||||
self.is_valid_cpu = None
|
||||
self.copy_event.record(copy_stream)
|
||||
|
||||
def get_output(self) -> ModelRunnerOutput:
|
||||
self.copy_event.synchronize()
|
||||
pooler_output = self.pooler_output_cpu.unbind(dim=0)
|
||||
if self.is_valid_cpu is not None:
|
||||
is_valid_cpu = self.is_valid_cpu.tolist()
|
||||
for i, is_valid in enumerate(is_valid_cpu):
|
||||
if not is_valid:
|
||||
pooler_output[i] = None
|
||||
self.model_runner_output.pooler_output = pooler_output
|
||||
return self.model_runner_output
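
A minimal usage sketch of the class above; model_runner_output and pooler_output_gpu are placeholders for values produced earlier on the main stream:

main_stream = torch.cuda.current_stream()
copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()

async_out = AsyncPoolingOutput(
    model_runner_output=model_runner_output,  # placeholder
    pooler_output=pooler_output_gpu,          # placeholder GPU tensor
    is_valid=None,
    main_stream=main_stream,
    copy_stream=copy_stream,
    copy_event=copy_event,
)
# The D2H copies were enqueued on copy_stream at construction time, so
# get_output() only blocks on copy_event rather than on the whole main stream.
result = async_out.get_output()
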
|
||||
|
||||
|
||||
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
|
||||
return x.to("cpu", non_blocking=True).numpy()
|
||||
|
||||
|
||||
@@ -119,6 +119,10 @@ class BlockTables:
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
|
||||
# NOTE(woosuk): The output may be used for CUDA graph capture.
|
||||
# Therefore, this method must return the persistent tensor
|
||||
# with the same memory address as that used during the model's forward pass,
|
||||
# rather than allocating a new tensor.
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
@@ -150,7 +154,14 @@ class BlockTables:
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
|
||||
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
|
||||
# This is because the padding logic is complex and kernels may access beyond
|
||||
# the requested range.
|
||||
self.slot_mappings.fill_(PAD_SLOT_ID)
|
||||
# NOTE(woosuk): The output may be used for CUDA graph capture.
|
||||
# Therefore, this method must return the persistent tensor
|
||||
# with the same memory address as that used during the model's forward pass,
|
||||
# rather than allocating a new tensor.
|
||||
return self.slot_mappings[:, :num_tokens]
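
The NOTE above is the usual CUDA Graphs constraint; a generic PyTorch illustration (not vLLM code) of why inputs must be updated in place rather than rebound to new tensors:

static_in = torch.zeros(8, device="cuda")
static_out = torch.empty_like(static_in)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

static_in.fill_(3.0)  # in-place write to the captured address: replay sees it
g.replay()            # static_out becomes all 6.0
static_in = torch.ones(8, device="cuda")  # rebinding to a new tensor is invisible to replay
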
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
@@ -29,13 +27,11 @@ class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
@@ -88,9 +84,8 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -113,24 +108,23 @@ class CudaGraphManager:
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
# NOTE: Values returned by `prepare_dummy_inputs` will override the
|
||||
# default values above.
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
self.uniform_decode_query_len if uniform_decode else 0
|
||||
),
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -143,11 +137,7 @@ class CudaGraphManager:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -164,9 +154,7 @@ class CudaGraphManager:
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
model_inputs=model_inputs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
@@ -178,9 +166,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -206,11 +192,8 @@ class CudaGraphManager:
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
@@ -235,9 +218,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -256,19 +237,14 @@ class CudaGraphManager:
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model(**model_inputs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -278,9 +254,8 @@ class CudaGraphManager:
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
@@ -412,51 +387,36 @@ def capture_graphs(
|
||||
def prepare_inputs_to_capture(
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
|
||||
if uniform_decode_query_len > 0:
|
||||
num_tokens_per_req = uniform_decode_query_len
|
||||
else:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
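# Worked example of the construction above (illustrative values):
# num_reqs=3, num_tokens=10 -> num_tokens_per_req = 10 // 3 = 3
#   np.arange(4, dtype=np.int32) * 3     -> [0, 3, 6, 9]
#   query_start_loc_np[-1] = num_tokens  -> [0, 3, 6, 10]  (last request absorbs the remainder)
# Entries past num_reqs in the persistent buffer are pinned to num_tokens so the
# sequence stays non-decreasing for backends that require it.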
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
|
||||
input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
|
||||
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, kv_cache_config
|
||||
)
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=num_tokens_per_req,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
|
||||
# HACK(woosuk): Special handling for DCP.
|
||||
if block_tables.cp_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
input_buffers.dcp_local_seq_lens,
|
||||
input_batch.seq_lens,
|
||||
num_reqs,
|
||||
block_tables.cp_size,
|
||||
block_tables.cp_rank,
|
||||
block_tables.cp_interleave,
|
||||
)
|
||||
input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata = model_state.prepare_attn(
|
||||
input_batch,
|
||||
input_block_tables,
|
||||
slot_mappings,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
return attn_metadata, slot_mappings_by_layer
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -60,20 +59,13 @@ class InputBatch:
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
# [num_reqs]
|
||||
dcp_local_seq_lens: torch.Tensor | None
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor | None
|
||||
# [num_tokens_after_padding, hidden_size]
|
||||
inputs_embeds: torch.Tensor | None
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
# layer_name -> slot_mapping
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
|
||||
# [total_num_logits]
|
||||
logits_indices: torch.Tensor
|
||||
@@ -90,14 +82,16 @@ class InputBatch:
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
device = input_buffers.device
|
||||
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
expanded_idx_mapping = idx_mapping
|
||||
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
|
||||
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
@@ -123,7 +117,6 @@ class InputBatch:
|
||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||
positions = input_buffers.positions[:num_tokens].zero_()
|
||||
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
|
||||
@@ -141,12 +134,9 @@ class InputBatch:
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
dcp_local_seq_lens=None,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=None,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=None, # type: ignore
|
||||
slot_mappings=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
@@ -507,6 +497,38 @@ def post_update(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _post_update_pool_kernel(
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
query_start_loc_ptr,
|
||||
):
|
||||
batch_id = tl.program_id(0)
|
||||
query_start = tl.load(query_start_loc_ptr + batch_id)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_id + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_id)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
|
||||
|
||||
|
||||
def post_update_pool(
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_post_update_pool_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
query_start_loc,
|
||||
)
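
# Reference-only sketch: a plain PyTorch equivalent of the Triton kernel above,
# assuming each request index appears at most once in idx_mapping.
def post_update_pool_reference(
    idx_mapping: torch.Tensor,          # [num_reqs], int32
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], int32
    query_start_loc: torch.Tensor,      # [num_reqs + 1], int32
) -> None:
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_computed_tokens.index_add_(0, idx_mapping.long(), query_lens)
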
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _expand_idx_mapping_kernel(
|
||||
idx_mapping_ptr,
|
||||
|
||||
vllm/v1/worker/gpu/mm/encoder_cache.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.multimodal.inputs import MultiModalFeatureSpec


class EncoderCache:
def __init__(self):
# req_id -> MM features
self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
# MM hash -> encoder outputs
self.encoder_outputs: dict[str, torch.Tensor] = {}

def add_request(
self, req_id: str, mm_features: list[MultiModalFeatureSpec]
) -> None:
self.mm_features[req_id] = mm_features

def remove_request(self, req_id: str) -> None:
self.mm_features.pop(req_id, None)

def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass

def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.

This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_outputs.clear()

def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_outputs.pop(mm_hash, None)
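
A hypothetical usage sketch of this class from the runner (identifiers below are illustrative):

cache = EncoderCache()
cache.add_request("req-0", mm_features)               # mm_features: list[MultiModalFeatureSpec]
cache.encoder_outputs["mm-hash-0"] = vision_embeds    # store one encoder output by MM hash
embeds = cache.encoder_outputs.get("mm-hash-0")
cache.free_encoder_cache("mm-hash-0")                 # drop a single cached embedding
cache.remove_request("req-0")                         # forget the request's features
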
|
||||
@@ -4,54 +4,32 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
|
||||
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
encoder_cache: EncoderCache,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.model = model
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.encoder_cache = encoder_cache
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""
|
||||
Clear the multi-modal cache that was used during profiling,
|
||||
but no longer needed during inference.
|
||||
"""
|
||||
# TODO: Implement MM budget for encoder dummy run
|
||||
pass
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
|
||||
This should be called when model weights are updated to ensure
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
|
||||
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
|
||||
self.req_id_to_mm_features[req_id] = mm_features
|
||||
|
||||
def free_encoder_cache(self, mm_hash: str) -> None:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||
@@ -59,7 +37,7 @@ class EncoderRunner:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
mm_features = self.encoder_cache.mm_features[req_id]
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_feature = mm_features[mm_input_id]
|
||||
if mm_feature.data is None:
|
||||
@@ -72,25 +50,17 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
@@ -122,7 +92,7 @@ class EncoderRunner:
|
||||
# OPTIMIZATION: Skip decode requests.
|
||||
continue
|
||||
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
mm_features = self.encoder_cache.mm_features[req_id]
|
||||
for mm_feature in mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
@@ -148,7 +118,7 @@ class EncoderRunner:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
|
||||
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
@@ -170,12 +140,11 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
x = self.model.embed_input_ids(
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
|
||||
@@ -38,15 +38,16 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
@@ -56,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_cudagraph_and_dp_padding,
|
||||
make_num_tokens_across_dp,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
@@ -67,6 +65,7 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
expand_idx_mapping,
|
||||
get_num_sampled_and_rejected,
|
||||
post_update,
|
||||
post_update_pool,
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
@@ -76,8 +75,9 @@ from vllm.v1.worker.gpu.kv_connector import (
|
||||
get_kv_connector,
|
||||
)
|
||||
from vllm.v1.worker.gpu.lora_utils import LoraState
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.gpu.model_states import init_model_state
|
||||
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
|
||||
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||
@@ -120,34 +120,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
|
||||
|
||||
# Multimodal
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config
|
||||
)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner = EncoderRunner(
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.inputs_embeds_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
self.mrope_states = MRopeState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
@@ -169,6 +146,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
|
||||
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
# Multimodal
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config
|
||||
)
|
||||
self.encoder_cache = None
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
self.encoder_cache = EncoderCache()
|
||||
|
||||
self.speculator = None
|
||||
self.num_speculative_steps = 0
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
@@ -212,7 +198,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
self.vllm_config,
|
||||
self.uses_mrope,
|
||||
self.use_aux_hidden_state_outputs,
|
||||
self.device,
|
||||
)
|
||||
@@ -227,13 +212,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
# Pooling models.
|
||||
self.is_pooling_model = self.model_config.runner_type == "pooling"
|
||||
self.pooling_runner: PoolingRunner | None = None
|
||||
|
||||
# For transferring state from execute_model to subsequent sample_tokens call.
|
||||
self.execute_model_state: tuple | None = None
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
|
||||
@staticmethod
|
||||
def get_supported_tasks() -> tuple[str]:
|
||||
return ("generate",)
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
tasks: list[SupportedTask] = []
|
||||
if self.model_config.runner_type == "generate":
|
||||
tasks.append("generate")
|
||||
if self.pooling_runner is not None:
|
||||
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
|
||||
return tuple(tasks)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
@@ -266,7 +262,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if self.speculator is not None:
|
||||
prepare_communication_buffer_for_model(self.speculator)
|
||||
prepare_communication_buffer_for_model(self.speculator.model)
|
||||
|
||||
# Initialize the components that require the model.
|
||||
self.model_state = init_model_state(
|
||||
self.vllm_config, self.model, self.encoder_cache, self.device
|
||||
)
|
||||
if self.is_pooling_model:
|
||||
self.pooling_runner = PoolingRunner(self.model)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
@@ -305,6 +308,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.speculator is not None:
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
self.model_state,
|
||||
self.kv_cache_config,
|
||||
self.attn_groups,
|
||||
self.block_tables,
|
||||
@@ -320,39 +324,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
|
||||
max_query_len=input_batch.num_scheduled_tokens.max().item(),
|
||||
seq_lens=input_batch.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
input_batch.slot_mappings = slot_mappings_by_layer
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
skip_attn: bool = True,
|
||||
uniform_decode: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
# Create a dummy scheduler output.
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
if uniform_decode:
|
||||
# Align tokens to uniform_decode_query_len for cudagraph
|
||||
# compatibility across DP ranks.
|
||||
query_len = self.cudagraph_manager.uniform_decode_query_len
|
||||
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
|
||||
num_tokens = num_reqs * query_len
|
||||
num_tokens_per_request = [query_len] * num_reqs
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
assert sum(num_tokens_per_request) == num_tokens
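# Worked example of the uniform-decode branch above (illustrative numbers):
# num_tokens=100, uniform_decode_query_len=8, max_num_reqs=64
#   num_reqs   = min(cdiv(100, 8), 64) = 13
#   num_tokens = 13 * 8 = 104   # rounded up so every dummy request gets exactly 8 tokens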
|
||||
num_scheduled_tokens = {
|
||||
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
|
||||
@@ -387,7 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, _, input_batch, _ = self.execute_model_state
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
if self.speculator is not None:
|
||||
self.speculator.propose(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings_by_layer,
|
||||
last_hidden_states=hidden_states,
|
||||
aux_hidden_states=aux_hidden_states,
|
||||
num_sampled=torch.ones(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
num_rejected=torch.zeros(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
last_sampled=self.req_states.last_sampled_tokens,
|
||||
next_prefill_tokens=self.req_states.next_prefill_tokens,
|
||||
temperature=self.sampler.sampling_states.temperature.gpu,
|
||||
seeds=self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
dummy_run=True,
|
||||
skip_attn_for_dummy_run=skip_attn,
|
||||
)
|
||||
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -416,39 +442,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
|
||||
assert self.pooling_runner is not None
|
||||
self.pooling_runner.dummy_pooler_run(hidden_states)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens, skip_attn=True
|
||||
)
|
||||
|
||||
# Only run sampler on last PP rank (non-last ranks return None).
|
||||
# Only run sampler/pooler on last PP rank (non-last ranks return None).
|
||||
if self.is_last_pp_rank:
|
||||
assert sample_hidden_states is not None
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
|
||||
if self.speculator is not None:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
slot_mappings=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
if self.pooling_runner is None:
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
else:
|
||||
self._dummy_pooler_run(hidden_states)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_mm_cache()
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.reset_mm_cache()
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_encoder_cache()
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.reset_encoder_cache()
|
||||
|
||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||
# SP is not supported yet.
|
||||
@@ -477,17 +500,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs:
|
||||
inputs_embeds = self.encoder_runner.inputs_embeds
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
model_state=self.model_state,
|
||||
input_buffers=self.input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=self.block_tables,
|
||||
attn_groups=self.attn_groups,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -522,21 +538,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
finished_req_ids = finished_req_ids.union(preempted_req_ids)
|
||||
for req_id in finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.remove_request(req_id)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.remove_request(req_id)
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
|
||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
if self.encoder_cache is not None:
|
||||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||
self.encoder_runner.free_encoder_cache(mm_hash)
|
||||
self.encoder_cache.free_encoder_cache(mm_hash)
|
||||
|
||||
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
assert new_req_data.prompt_token_ids is not None
|
||||
assert new_req_data.prefill_token_ids is not None
|
||||
assert new_req_data.sampling_params is not None
|
||||
req_id = new_req_data.req_id
|
||||
prompt_len = len(new_req_data.prompt_token_ids)
|
||||
self.req_states.add_request(
|
||||
@@ -547,34 +562,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
|
||||
|
||||
# Pre-compute M-RoPE positions for prefill.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.init_prefill_mrope_positions(
|
||||
req_index,
|
||||
self.model, # type: ignore
|
||||
new_req_data.prefill_token_ids,
|
||||
mm_features=new_req_data.mm_features,
|
||||
)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
|
||||
|
||||
self.model_state.add_request(req_index, new_req_data)
|
||||
self.block_tables.append_block_ids(
|
||||
req_index, new_req_data.block_ids, overwrite=True
|
||||
)
|
||||
self.sampler.add_request(
|
||||
req_index, prompt_len, new_req_data.sampling_params
|
||||
)
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
|
||||
|
||||
if new_req_data.sampling_params is not None:
|
||||
self.sampler.add_request(
|
||||
req_index, prompt_len, new_req_data.sampling_params
|
||||
)
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
self.sampler.apply_staged_writes()
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.apply_staged_writes()
|
||||
self.model_state.apply_staged_writes()
|
||||
|
||||
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# Add new blocks for the existing requests.
|
||||
@@ -637,9 +645,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
|
||||
)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
# Get query_start_loc.
|
||||
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
@@ -648,11 +653,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
|
||||
query_start_loc_np[num_reqs + 1 :] = num_tokens
|
||||
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
|
||||
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
max_query_len = num_scheduled_tokens.max().item()
|
||||
|
||||
# Get prefill tokens if any.
|
||||
if self.req_states.any_prefills(idx_mapping_np):
|
||||
@@ -676,6 +678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
dcp_local_seq_lens = None
|
||||
if self.use_dcp:
|
||||
# Prepare dcp local seq_lens.
|
||||
prepare_dcp_local_seq_lens(
|
||||
@@ -686,16 +689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dcp_rank,
|
||||
self.cp_interleave,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
# Prepare M-RoPE positions.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.prepare_mrope_positions(
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
@@ -711,39 +705,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
total_num_logits,
|
||||
)
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.input_buffers.positions[:num_tokens],
|
||||
)
|
||||
# Layer name -> slot mapping.
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=max_query_len,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@@ -758,37 +719,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings_by_layer,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
|
||||
positions=self.input_buffers.positions[:num_tokens_after_padding],
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_mm_embeddings(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
input_batch: InputBatch,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
|
||||
scheduled_encoder_inputs
|
||||
def prepare_attn(
|
||||
self, input_batch: InputBatch
|
||||
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.idx_mapping,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions,
|
||||
)
|
||||
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
|
||||
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
|
||||
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
|
||||
return block_tables, slot_mappings
|
||||
|
||||
def prepare_dummy_attn(
|
||||
self, input_batch: InputBatch
|
||||
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
return mm_embeds, is_mm_embed
|
||||
return block_tables, slot_mappings
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@@ -926,6 +886,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output, num_tokens_after_padding
|
||||
)
|
||||
block_tables, slot_mappings = self.prepare_attn(input_batch)
|
||||
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.lora_state.make_lora_inputs(
|
||||
@@ -934,35 +896,61 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self._set_active_loras(*lora_inputs)
|
||||
|
||||
# Only first PP rank prepares multimodal embeddings.
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
mm_embeds, is_mm_embed = self.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs, input_batch
|
||||
)
|
||||
inputs_embeds = self.encoder_runner.get_inputs_embeds(
|
||||
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
|
||||
)
|
||||
input_batch.inputs_embeds = inputs_embeds[
|
||||
: input_batch.num_tokens_after_padding
|
||||
]
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP or memory profiling.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
num_reqs, num_tokens_after_padding, self.input_buffers
|
||||
)
|
||||
if self.uses_mrope:
|
||||
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
if not skip_attn_for_dummy_run:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
|
||||
else:
|
||||
block_tables = None
|
||||
slot_mappings = None
|
||||
# FIXME(woosuk): Fix warmup for LoRA.
|
||||
|
||||
attn_metadata = None
|
||||
slot_mappings_by_layer = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
assert slot_mappings is not None
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
assert block_tables is not None
|
||||
attn_metadata = self.model_state.prepare_attn(
|
||||
input_batch,
|
||||
block_tables,
|
||||
slot_mappings,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
)
|
||||
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
# Run MM encoder (if needed) and get multimodal embeddings.
|
||||
# Only first PP rank prepares multimodal embeddings.
|
||||
# NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
|
||||
# to obtain inputs_embeds, because the compiled model expects this input.
|
||||
inputs_embeds = self.model_state.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs,
|
||||
input_batch,
|
||||
self.req_states,
|
||||
)
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_batch.input_ids,
|
||||
"positions": input_batch.positions,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
# NOTE: Values returned by `prepare_inputs` will override the default
|
||||
# values above.
|
||||
**self.model_state.prepare_inputs(input_batch, self.req_states),
|
||||
}
|
||||
if not self.is_first_pp_rank:
|
||||
# Update for non-first PP ranks.
|
||||
model_inputs["input_ids"] = None
|
||||
model_inputs["inputs_embeds"] = None
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors
|
||||
|
||||
# Run model.
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
# Use explicit cudagraph replay for FULL mode.
|
||||
@@ -979,41 +967,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states = None
|
||||
else:
|
||||
# For piecewise and eager mode, just call model().
|
||||
positions = input_batch.positions
|
||||
if self.uses_mrope:
|
||||
assert input_batch.mrope_positions is not None
|
||||
positions = input_batch.mrope_positions
|
||||
|
||||
if self.is_first_pp_rank:
|
||||
input_ids = input_batch.input_ids
|
||||
inputs_embeds = input_batch.inputs_embeds
|
||||
assert intermediate_tensors is None
|
||||
else:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=input_batch.slot_mappings,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
model_output = self.model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -1021,33 +990,44 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states = None
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
self.execute_model_state = (
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: return IntermediateTensors for sending.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
self.execute_model_state = (None, None, input_batch, kv_connector_output)
|
||||
return hidden_states
|
||||
|
||||
# Last rank (or no PP): hidden_states is a tensor for sampling.
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
self.execute_model_state = (
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
input_batch,
|
||||
kv_connector_output,
|
||||
) # type: ignore
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None
|
||||
) -> AsyncOutput | ModelRunnerOutput | None:
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None # type: ignore
|
||||
if self.execute_model_state is None:
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: hidden_states is None because this rank produced
|
||||
@@ -1109,6 +1089,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.speculator is not None:
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
num_sampled,
|
||||
@@ -1117,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.next_prefill_tokens,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
|
||||
@@ -1127,3 +1110,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
return self.draft_tokens_handler.get_draft_tokens()
|
||||
|
||||
@torch.inference_mode()
|
||||
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
|
||||
if self.execute_model_state is None:
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
|
||||
input_batch, _, _, _, hidden_states, _, kv_connector_output = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
self.postprocess_pool(input_batch)
|
||||
return None
|
||||
|
||||
assert self.pooling_runner is not None
|
||||
pooler_output, is_valid = self.pooling_runner.pool(
|
||||
hidden_states, input_batch, self.req_states
|
||||
)
|
||||
self.postprocess_pool(input_batch)
|
||||
|
||||
# Build the model runner output.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
|
||||
kv_connector_output=kv_connector_output,
|
||||
)
|
||||
async_output = AsyncPoolingOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
pooler_output=pooler_output,
|
||||
is_valid=is_valid,
|
||||
main_stream=self.main_stream,
|
||||
copy_stream=self.output_copy_stream,
|
||||
copy_event=self.output_copy_event,
|
||||
)
|
||||
if self.use_async_scheduling:
|
||||
return async_output
|
||||
return async_output.get_output()
|
||||
|
||||
def postprocess_pool(self, input_batch: InputBatch) -> None:
|
||||
# Update the number of computed tokens.
|
||||
post_update_pool(
|
||||
input_batch.idx_mapping,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
input_batch.query_start_loc,
|
||||
)
|
||||
|
||||
# Update the number of computed prefill tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
computed_prefill = self.req_states.num_computed_prefill_tokens
|
||||
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
|
||||
np.minimum(
|
||||
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
|
||||
)
|
||||
|
||||
vllm/v1/worker/gpu/model_states/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache


def init_model_state(
    vllm_config: VllmConfig,
    model: nn.Module,
    encoder_cache: EncoderCache | None,
    device: torch.device,
):
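    # Only the default implementation exists for now; it is imported locally
    # at call time.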
    from vllm.v1.worker.gpu.model_states.default import DefaultModelState

    return DefaultModelState(vllm_config, model, encoder_cache, device)
vllm/v1/worker/gpu/model_states/default.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class DefaultModelState(ModelState):
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.supports_mm_inputs = encoder_cache is not None
        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
        self.dtype = self.model_config.dtype

        if self.supports_mm_inputs:
            assert encoder_cache is not None
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for prefill.
            assert new_req_data.prefill_token_ids is not None
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                new_req_data.prefill_token_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
        )
        if mm_kwargs:
            # Execute the multimodal encoder.
            encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
            # Cache the encoder outputs by mm_hash
            self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))

        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            req_states.prefill_len.np[input_batch.idx_mapping_np],
            req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
        )
        inputs_embeds = self.encoder_runner.get_inputs_embeds(
            input_batch.input_ids, mm_embeds, is_mm_embed
        )
        return inputs_embeds[: input_batch.num_tokens_after_padding]

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # Prepare M-RoPE positions.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        mrope_positions = self.mrope_state.mrope_positions[
            :, : input_batch.num_tokens_after_padding
        ]
        return {"positions": mrope_positions}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        model_inputs = {}
        if self.supports_mm_inputs:
            inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
            model_inputs["inputs_embeds"] = inputs_embeds
        if self.uses_mrope:
            mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
            model_inputs["positions"] = mrope_positions
        return model_inputs

    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )
        return attn_metadata
vllm/v1/worker/gpu/model_states/interface.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class ModelState(ABC):
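    """Model-family-specific hooks used by the GPU model runner.

    The runner delegates request bookkeeping, M-RoPE/multimodal embedding
    preparation, and attention-metadata construction to implementations of
    this interface (see DefaultModelState).
    """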
    @abstractmethod
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_staged_writes(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        raise NotImplementedError

    @abstractmethod
    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        raise NotImplementedError

    @abstractmethod
    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        raise NotImplementedError
vllm/v1/worker/gpu/pool/__init__.py (new file, empty)
vllm/v1/worker/gpu/pool/pooling_runner.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState


# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
    def __init__(self, model: nn.Module):
        self.model = cast(VllmModelForPooling, model)

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        if not is_pooling_model(self.model):
            return []
        assert "embed" in self.model.pooler.get_supported_tasks()
        return ["embed"]

    def pool(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # TODO(woosuk): Support different types of pooling tasks.
        last_hidden_states = hidden_states[input_batch.logits_indices]
        # TODO(woosuk): Make normalization optional.
        last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)

        prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
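        # The pooled output is only usable once the request's full prompt has
        # been processed; partially-prefilled requests are marked invalid here.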
        is_valid = input_batch.seq_lens == prompt_len
        return last_hidden_states, is_valid

    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        F.normalize(hidden_states, p=2, dim=-1)
        return
@@ -72,7 +72,7 @@ class BadWordsState:
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
@@ -84,7 +84,7 @@ class BadWordsState:
|
||||
|
||||
apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
self.bad_word_token_ids.gpu,
|
||||
self.bad_word_offsets.gpu,
|
||||
self.num_bad_words.gpu,
|
||||
@@ -114,17 +114,17 @@ def _bad_words_kernel(
|
||||
input_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
):
|
||||
logit_idx = tl.program_id(0)
|
||||
token_idx = tl.program_id(0)
|
||||
bw_idx = tl.program_id(1)
|
||||
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
|
||||
|
||||
if bw_idx >= num_bad_words:
|
||||
return
|
||||
|
||||
pos = tl.load(expanded_local_pos_ptr + logit_idx)
|
||||
cur_req_first_pos = logit_idx - pos
|
||||
pos = tl.load(expanded_local_pos_ptr + token_idx)
|
||||
cur_req_first_pos = token_idx - pos
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
@@ -159,7 +159,7 @@ def _bad_words_kernel(
|
||||
match = match & (expected == actual)
|
||||
|
||||
if match:
|
||||
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
|
||||
tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
@@ -175,8 +175,8 @@ def apply_bad_words(
|
||||
expanded_local_pos: torch.Tensor,
|
||||
max_num_bad_words: int,
|
||||
) -> None:
|
||||
total_num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
|
||||
num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(num_tokens, max_num_bad_words)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
expanded_idx_mapping,
|
||||
|
||||
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
|
||||
def _temperature_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
temperature_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
|
||||
if temperature == 0.0 or temperature == 1.0:
|
||||
# Early return to avoid loading logits.
|
||||
@@ -25,24 +25,24 @@ def _temperature_kernel(
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
||||
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
|
||||
logits = logits.to(tl.float32)
|
||||
logits = logits / temperature
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
_temperature_kernel[(num_reqs, num_blocks)](
|
||||
_temperature_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + block,
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
pos = tl.load(pos_ptr + token_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise in FP32.
|
||||
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
|
||||
|
||||
value, idx = tl.max(logits, axis=0, return_indices=True)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
|
||||
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
idx_mapping: torch.Tensor, # [max_num_reqs]
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
expanded_idx_mapping: torch.Tensor, # [num_tokens]
|
||||
temperature: torch.Tensor, # [max_num_reqs]
|
||||
seed: torch.Tensor, # [max_num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_tokens]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
|
||||
@@ -121,7 +121,7 @@ class LogitBiasState:
|
||||
def apply_logit_bias(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> None:
|
||||
@@ -131,7 +131,7 @@ class LogitBiasState:
|
||||
|
||||
apply_logit_bias(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
pos,
|
||||
self.num_allowed_token_ids.gpu,
|
||||
self.allowed_token_ids.gpu,
|
||||
@@ -149,7 +149,7 @@ def _bias_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
vocab_size,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
# Allowed token IDs.
|
||||
num_allowed_token_ids_ptr,
|
||||
allowed_token_ids_ptr,
|
||||
@@ -169,8 +169,8 @@ def _bias_kernel(
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
LOGITS_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
@@ -186,21 +186,21 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
|
||||
logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
|
||||
)
|
||||
|
||||
# Set logits to -inf for all tokens.
|
||||
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + offset,
|
||||
logits_ptr + token_idx * logits_stride + offset,
|
||||
-float("inf"),
|
||||
mask=offset < vocab_size,
|
||||
)
|
||||
|
||||
# Restore logits for allowed token IDs.
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
|
||||
logits_ptr + token_idx * logits_stride + allowed_token_ids,
|
||||
logits,
|
||||
mask=mask,
|
||||
)
|
||||
@@ -214,13 +214,13 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
|
||||
logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
|
||||
logits += bias
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
|
||||
|
||||
# Apply min tokens.
|
||||
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
pos = tl.load(pos_ptr + token_idx)
|
||||
min_len = tl.load(min_lens_ptr + req_state_idx)
|
||||
if num_stop_token_ids > 0 and pos < min_len:
|
||||
mask = block < num_stop_token_ids
|
||||
@@ -229,7 +229,7 @@ def _bias_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + stop_token_ids,
|
||||
logits_ptr + token_idx * logits_stride + stop_token_ids,
|
||||
-float("inf"),
|
||||
mask=mask,
|
||||
)
|
||||
@@ -237,7 +237,7 @@ def _bias_kernel(
|
||||
|
||||
def apply_logit_bias(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
num_allowed_token_ids: torch.Tensor,
|
||||
allowed_token_ids: torch.Tensor,
|
||||
@@ -248,7 +248,7 @@ def apply_logit_bias(
|
||||
num_stop_token_ids: torch.Tensor,
|
||||
stop_token_ids: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(
|
||||
max(
|
||||
allowed_token_ids.shape[-1],
|
||||
@@ -257,11 +257,11 @@ def apply_logit_bias(
|
||||
)
|
||||
)
|
||||
LOGITS_BLOCK_SIZE = 8192
|
||||
_bias_kernel[(num_reqs,)](
|
||||
_bias_kernel[(num_tokens,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
vocab_size,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
num_allowed_token_ids,
|
||||
allowed_token_ids,
|
||||
allowed_token_ids.stride(0),
|
||||
|
||||
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
|
||||
def _min_p_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
min_p_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
|
||||
if min_p == 0.0:
|
||||
return
|
||||
@@ -25,7 +25,9 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
@@ -35,21 +37,23 @@ def _min_p_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = tl.where(logits < threshold, float("-inf"), logits)
|
||||
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
_min_p_kernel[(num_reqs,)](
|
||||
_min_p_kernel[(num_tokens,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
min_p,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
|
||||
@@ -82,7 +82,7 @@ class PenaltiesState:
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
@@ -94,7 +94,7 @@ class PenaltiesState:
|
||||
|
||||
apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
self.repetition_penalty.gpu,
|
||||
@@ -110,7 +110,7 @@ class PenaltiesState:
|
||||
def _penalties_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
token_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
repetition_penalty_ptr,
|
||||
@@ -125,7 +125,7 @@ def _penalties_kernel(
|
||||
MAX_SPEC_LEN: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
|
||||
@@ -191,7 +191,7 @@ def _penalties_kernel(
|
||||
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
repetition_penalty: torch.Tensor,
|
||||
@@ -207,7 +207,7 @@ def apply_penalties(
|
||||
_penalties_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
token_ids,
|
||||
expanded_local_pos,
|
||||
repetition_penalty,
|
||||
@@ -225,7 +225,7 @@ def apply_penalties(
|
||||
|
||||
@triton.jit
|
||||
def _bincount_kernel(
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
@@ -236,9 +236,9 @@ def _bincount_kernel(
|
||||
output_bin_counts_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
token_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
@@ -276,7 +276,7 @@ def _bincount_kernel(
|
||||
|
||||
|
||||
def bincount(
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
@@ -284,13 +284,13 @@ def bincount(
|
||||
output_bin_counts: torch.Tensor,
|
||||
max_prefill_len: int,
|
||||
) -> None:
|
||||
prompt_bin_mask[idx_mapping] = 0
|
||||
output_bin_counts[idx_mapping] = 0
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
prompt_bin_mask[expanded_idx_mapping] = 0
|
||||
output_bin_counts[expanded_idx_mapping] = 0
|
||||
num_tokens = expanded_idx_mapping.shape[0]
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_reqs, num_blocks)](
|
||||
idx_mapping,
|
||||
_bincount_kernel[(num_tokens, num_blocks)](
|
||||
expanded_idx_mapping,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
|
||||
@@ -56,7 +56,7 @@ class Sampler:
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
cu_num_logits_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
@@ -68,7 +68,7 @@ class Sampler:
|
||||
num_nans = get_num_nans(logits) if self.compute_nans else None
|
||||
sampled, processed_logits = self.sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
pos,
|
||||
input_ids,
|
||||
@@ -101,7 +101,7 @@ class Sampler:
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -111,12 +111,14 @@ class Sampler:
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
|
||||
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
|
||||
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
|
||||
self.logit_bias_state.apply_logit_bias(
|
||||
logits, expanded_idx_mapping, idx_mapping_np, pos
|
||||
)
|
||||
|
||||
# Apply penalties in place.
|
||||
self.penalties_state.apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
@@ -126,27 +128,29 @@ class Sampler:
|
||||
# Apply bad words masking in place.
|
||||
self.bad_words_state.apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
# Apply temperature in place.
|
||||
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
|
||||
self.sampling_states.apply_temperature(
|
||||
logits, expanded_idx_mapping, idx_mapping_np
|
||||
)
|
||||
|
||||
# Apply min_p in place.
|
||||
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
|
||||
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
|
||||
|
||||
# Apply top_k and/or top_p. This might or might not return a new tensor.
|
||||
logits = self.sampling_states.apply_top_k_top_p(
|
||||
logits, idx_mapping, idx_mapping_np
|
||||
logits, expanded_idx_mapping, idx_mapping_np
|
||||
)
|
||||
|
||||
# Sample the next token.
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
self.sampling_states.temperature.gpu,
|
||||
self.sampling_states.seeds.gpu,
|
||||
pos,
|
||||
|
||||
@@ -64,7 +64,7 @@ class SamplingStates:
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
temp_np = self.temperature.np[idx_mapping_np]
|
||||
@@ -72,23 +72,23 @@ class SamplingStates:
|
||||
# No request requires temperature. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_temperature(logits, idx_mapping, self.temperature.gpu)
|
||||
apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
|
||||
# No request uses min_p. Skip the kernel launch.
|
||||
return
|
||||
apply_min_p(logits, idx_mapping, self.min_p.gpu)
|
||||
apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
|
||||
|
||||
def apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
|
||||
@@ -96,8 +96,8 @@ class SamplingStates:
|
||||
if not (do_top_k or do_top_p):
|
||||
return logits
|
||||
|
||||
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
|
||||
top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
@@ -54,11 +55,32 @@ class EagleCudaGraphManager:
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_tokens: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
generate_fn: Callable,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
@@ -76,12 +98,11 @@ class EagleCudaGraphManager:
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=1,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -158,6 +179,7 @@ class EagleCudaGraphManager:
|
||||
def capture(
|
||||
self,
|
||||
generate_fn: Callable,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
@@ -173,6 +195,7 @@ class EagleCudaGraphManager:
|
||||
capture_cudagraph_mode=self.cudagraph_mode,
|
||||
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
|
||||
generate_fn=generate_fn,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
|
||||
@@ -16,7 +16,9 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
|
||||
@@ -44,10 +46,13 @@ class EagleSpeculator:
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
|
||||
self.vocab_size = self.draft_model_config.get_vocab_size()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
# DP configuration
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
@@ -77,10 +82,12 @@ class EagleSpeculator:
|
||||
|
||||
def set_attn(
|
||||
self,
|
||||
model_state: ModelState,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
block_tables: BlockTables,
|
||||
) -> None:
|
||||
self.model_state = model_state
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.attn_groups = attn_groups
|
||||
self.block_tables = block_tables
|
||||
@@ -120,8 +127,8 @@ class EagleSpeculator:
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens_padded: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> None:
|
||||
@@ -162,9 +169,10 @@ class EagleSpeculator:
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
)
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
if attn_metadata is not None:
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
@@ -172,6 +180,7 @@ class EagleSpeculator:
|
||||
logger.info("Capturing model for Eagle speculator...")
|
||||
self.cudagraph_manager.capture(
|
||||
self.generate_draft,
|
||||
self.model_state,
|
||||
self.input_buffers,
|
||||
self.block_tables,
|
||||
self.attn_groups,
|
||||
@@ -182,6 +191,8 @@ class EagleSpeculator:
|
||||
def propose(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
# [num_tokens, hidden_size]
|
||||
last_hidden_states: torch.Tensor,
|
||||
# num_layers x [num_tokens, hidden_size]
|
||||
@@ -198,6 +209,9 @@ class EagleSpeculator:
|
||||
temperature: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
num_tokens_across_dp: torch.Tensor | None = None,
|
||||
dummy_run: bool = False,
|
||||
skip_attn_for_dummy_run: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
|
||||
# number of rejected tokens, we maintain the size of eagle's input_ids and
|
||||
@@ -229,9 +243,9 @@ class EagleSpeculator:
|
||||
# TODO(woosuk): Support CUDA graph for prefill.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_tokens,
|
||||
input_batch.attn_metadata,
|
||||
input_batch.slot_mappings,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
@@ -277,48 +291,64 @@ class EagleSpeculator:
|
||||
self.max_model_len,
|
||||
self.max_num_reqs,
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
|
||||
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_mode, cudagraph_size = (
|
||||
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
|
||||
)
|
||||
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
get_cudagraph_and_dp_padding(
|
||||
num_reqs,
|
||||
cudagraph_size,
|
||||
cudagraph_mode.value,
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
)
|
||||
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Run full CUDA graph.
|
||||
self.cudagraph_manager.run_fullgraph(cudagraph_size)
|
||||
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
# Run eager or piecewise CUDA graph.
|
||||
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
attn_metadata_updated = None
|
||||
slot_mappings_updated = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata_updated = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
slot_mappings_updated = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
attn_metadata_updated,
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
vllm/v1/worker/gpu/warmup.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
    CachedRequestData,
    GrammarOutput,
    NewRequestData,
    SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner


@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
    """Run two execute_model + sample_tokens iterations to JIT compile
    triton kernels.

    The first iteration simulates a prefill with requests of 2 prompt
    tokens each. The second iteration simulates a decode step with all
    requests generating 1 token each.
    """
    prompt_token_ids = [0, 1]
    prompt_len = len(prompt_token_ids)
    num_reqs = min(
        model_runner.scheduler_config.max_num_seqs,
        model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
    )

    num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
    req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]

    # SamplingParams exercising all sampling features.
    if model_runner.is_pooling_model:
        sampling_params = None
        pooling_params = PoolingParams()
    else:
        sampling_params = SamplingParams.for_sampler_warmup()
        pooling_params = None

    # Step 1: Prefill all requests with 2 prompt tokens each.
    new_reqs = [
        NewRequestData.from_request(
            Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
            # Each request uses a distinct block per KV cache group.
            block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
            prefill_token_ids=prompt_token_ids,
        )
        for i in range(num_reqs)
    ]

    prefill_output = SchedulerOutput.make_empty()
    prefill_output.scheduled_new_reqs = new_reqs
    prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
    prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
    prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

    # Disable KV connector for warmup run.
    model_runner.kv_connector.set_disabled(True)
    model_runner.execute_model(prefill_output)

    if not model_runner.is_pooling_model:
        # Warm up sampler and perform a decode step for non-pooling models.

        grammar_output = None
        if model_runner.is_last_pp_rank:
            # Build a GrammarOutput to exercise the structured output bitmask
            # kernel during the prefill step.
            vocab_size = model_runner.model_config.get_vocab_size()
            bitmask_width = (vocab_size + 31) // 32
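            # Each int32 covers 32 vocab entries; -1 sets every bit, so the
            # warmup bitmask allows all tokens.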
            grammar_bitmask = np.full(
                (len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
            )
            grammar_output = GrammarOutput(
                structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
            )

        model_runner.sample_tokens(grammar_output)

        # Step 2: Decode all requests with 1 token each.
        cached_req_data = CachedRequestData.make_empty()
        cached_req_data.req_ids = list(req_ids)
        cached_req_data.new_block_ids = [None] * num_reqs
        cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
        cached_req_data.num_output_tokens = [1] * num_reqs

        decode_output = SchedulerOutput.make_empty()
        decode_output.scheduled_cached_reqs = cached_req_data
        decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
        decode_output.total_num_scheduled_tokens = num_reqs
        decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

        model_runner.execute_model(decode_output)
        model_runner.sample_tokens(None)

    # Clean up - process finish_req_ids.
    cleanup_output = SchedulerOutput.make_empty()
    cleanup_output.finished_req_ids = set(req_ids)
    model_runner.execute_model(cleanup_output)
    model_runner.kv_connector.set_disabled(False)
    torch.cuda.synchronize()
@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None

# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None

def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -211,6 +220,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)

# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
@@ -425,6 +474,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1

if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions

# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@@ -623,6 +679,24 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)

if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[
[i2, i1], ...
]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[
[i2, i1], ...
]

def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.

@@ -745,6 +819,21 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids

if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]

# Decrement last_req_index since it is now empty.
last_req_index -= 1

File diff suppressed because it is too large
@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
@@ -32,14 +31,13 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
get_world_group
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
@@ -49,7 +47,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
@@ -61,6 +58,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager

from ...model_executor.model_loader import TensorizerLoader
from .gpu.warmup import warmup_kernels
from .utils import request_memory

logger = init_logger(__name__)
@@ -123,6 +122,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)

from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor

self.elastic_ep_executor = ElasticEPScalingExecutor(self)

# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

@@ -316,12 +319,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)

with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)

if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True

def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -421,9 +441,10 @@ class Worker(WorkerBase):
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None

tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}

# tp_rank = get_tp_group().rank_in_group
global_rank = get_world_group().rank_in_group
return {global_rank: metadata}

def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@@ -461,8 +482,16 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)

# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()

@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []

if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
@@ -558,12 +587,15 @@ class Worker(WorkerBase):

logger.debug(msg)

# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner)
elif get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
@@ -584,6 +616,8 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)

return self.compilation_config.compilation_time

def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()

@@ -696,6 +730,12 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if (
self.use_v2_model_runner
and self.model_runner.is_pooling_model
and output is None
):
output = self.model_runner.pool()  # type: ignore
if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
):
@@ -744,7 +784,8 @@ class Worker(WorkerBase):

# Create the profiler wrapper only on the first start call
if self.profiler is None:
if self.profiler_config.profiler == "torch":
profiler_type = self.profiler_config.profiler
if profiler_type == "torch":
self.profiler = TorchProfilerWrapper(
self.profiler_config,
worker_name=trace_name,
@@ -754,14 +795,18 @@ class Worker(WorkerBase):
logger.debug(
"Starting torch profiler with trace name: %s", trace_name
)
elif self.profiler_config.profiler == "cuda":
elif profiler_type == "cuda":
self.profiler = CudaProfilerWrapper(self.profiler_config)
logger.debug("Starting CUDA profiler")
self.profiler.start()
else:
# Profiler already initialized. Restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
# Config validation should prevent this code being reached
raise ValueError(
f"Invalid profiler value of {self.profiler_config.profiler}"
)

# If profiler already initialized, restart profiling but keep
# the original trace name from the first initialization.
self.profiler.start()
else:
if self.profiler is None:
logger.warning("Profiler was not started, nothing to stop.")
@@ -787,227 +832,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return

def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group

if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")

def _eplb_after_scale_up(
self,
old_ep_size: int,
new_ep_size: int,
global_expert_loads: list[torch.Tensor] | None,
) -> None:
from vllm.distributed.parallel_state import get_ep_group

if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping,
)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")

def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)

def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request

Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)

parallel_config = self.vllm_config.parallel_config

def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]

def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules

model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts

update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)

if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
)
global_expert_loads = None
else:
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads

def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)

old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)

cleanup_dist_env_and_memory()

if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return

self._reconfigure_parallel_config(reconfig_request)

with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)

global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)

if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)

def save_sharded_state(
self,
path: str,
@@ -1023,12 +847,11 @@ class Worker(WorkerBase):
max_size=max_size,
)

def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
self.model_runner.save_tensorized_model(
def save_tensorized_model(self, tensorizer_config: "TensorizerConfig") -> None:
TensorizerLoader.save_model(
self.get_model(),
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)

def init_weight_transfer_engine(self, init_info: dict) -> None:
@@ -1104,6 +927,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()

def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)


def init_worker_distributed_environment(
vllm_config: VllmConfig,

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any

import torch
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch

@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return mamba_group_ids, mamba_specs[0]


@dataclasses.dataclass
class MambaCopyBuffers:
src_ptrs: CpuGpuBuffer
dst_ptrs: CpuGpuBuffer
sizes: CpuGpuBuffer
offset: int = 0

@classmethod
def create(
cls,
max_num_reqs: int,
kv_cache_config: KVCacheConfig,
copy_funcs: tuple[MambaStateCopyFunc, ...],
make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaCopyBuffers":
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
entries_per_req = sum(
len(kv_cache_config.kv_cache_groups[gid].layer_names)
for gid in mamba_group_ids
) * len(copy_funcs)
n = max_num_reqs * entries_per_req
return cls(
src_ptrs=make_buffer(n, dtype=torch.int64),
dst_ptrs=make_buffer(n, dtype=torch.int64),
sizes=make_buffer(n, dtype=torch.int32),
)


def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
copy_bufs: MambaCopyBuffers,
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
) -> None:
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return

src_ptrs_np = copy_bufs.src_ptrs.np
dst_ptrs_np = copy_bufs.dst_ptrs.np
sizes_np = copy_bufs.sizes.np
offset = copy_bufs.offset

for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
@@ -87,25 +121,23 @@
state, block_ids, src_block_idx, accept_token_bias + 1
)

src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
src_ptrs_np[offset] = copy_spec.start_addr
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
sizes_np[offset] = copy_spec.num_elements * state.element_size()
offset += 1

copy_bufs.offset = offset


def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
n = copy_bufs.offset
if n == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
batch_memcpy(
copy_bufs.src_ptrs.copy_to_gpu(n),
copy_bufs.dst_ptrs.copy_to_gpu(n),
copy_bufs.sizes.copy_to_gpu(n),
)


def preprocess_mamba(
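A sketch for orientation, not part of the commit: the MambaCopyBuffers change above swaps per-step Python lists, which forced a fresh torch.tensor host-to-device copy on every call, for preallocated buffers that are filled through a numpy view and copied to the GPU once per step. The general pattern with plain torch and hypothetical sizes:

    import torch

    use_cuda = torch.cuda.is_available()
    n_max = 1024  # hypothetical capacity
    ptrs_cpu = torch.empty(n_max, dtype=torch.int64, pin_memory=use_cuda)
    ptrs_np = ptrs_cpu.numpy()  # cheap view for per-entry writes
    ptrs_gpu = torch.empty(n_max, dtype=torch.int64, device="cuda" if use_cuda else "cpu")

    offset = 0
    for addr in (0x1000, 0x2000, 0x3000):  # stand-ins for copy_spec.start_addr values
        ptrs_np[offset] = addr
        offset += 1

    # One async copy of only the filled prefix instead of building a new tensor.
    ptrs_gpu[:offset].copy_(ptrs_cpu[:offset], non_blocking=True)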
@@ -117,6 +149,7 @@ def preprocess_mamba(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
Copy the mamba state of previous step to the last
@@ -138,9 +171,7 @@
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None)

src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
@@ -169,9 +200,7 @@
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -182,7 +211,7 @@
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)


def postprocess_mamba(
@@ -193,6 +222,7 @@
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
|
||||
If a blocks is converted from partial block to full block in this step, copy the
|
||||
@@ -203,9 +233,7 @@
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
@@ -225,9 +253,7 @@
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -239,4 +265,4 @@
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)

@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any

import torch

@@ -12,13 +15,208 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadataBuilder,
MultipleOf,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
UniformTypeKVCacheSpecs,
)

logger = init_logger(__name__)


@triton.jit
def _zero_kv_blocks_kernel(
seg_addrs_ptr,
block_ids_ptr,
n_blocks,
N_SEGS: tl.constexpr,
PAGE_SIZE_EL: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Zero KV cache blocks across all segments in a single launch.

Each segment is a contiguous region of one block's data. For backends
where blocks are outermost (block_dim=0) there is one segment per
buffer. For backends where K/V is outermost (block_dim=1) there are
two segments per buffer (one for K, one for V).

seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
allowing segments to live in different CUDA allocations.

Programs are mapped as (block_index, seg_index, chunk_index).
"""
pid = tl.program_id(0)
chunks = PAGE_SIZE_EL // BLOCK_SIZE
work_per_block = N_SEGS * chunks
block_index = pid // work_per_block
if block_index >= n_blocks:
return
remainder = pid % work_per_block
seg_index = remainder // chunks
chunk_index = remainder % chunks
block_id = tl.load(block_ids_ptr + block_index)
seg_addr = tl.load(seg_addrs_ptr + seg_index)
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
offset = (
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
)
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))

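A reading aid, not from the diff: each program id of the kernel above encodes a (block, segment, chunk) triple, so the launch grid used below in zero_block_ids is n_blocks * N_SEGS * (PAGE_SIZE_EL // BLOCK_SIZE). The same index decomposition in pure Python, with made-up sizes:

    n_blocks, n_segs = 4, 2
    page_size_el, block_size = 4096, 1024  # hypothetical element counts
    chunks = page_size_el // block_size    # programs per (block, segment) pair
    grid = n_blocks * n_segs * chunks

    for pid in range(grid):
        work_per_block = n_segs * chunks
        block_index = pid // work_per_block
        remainder = pid % work_per_block
        seg_index, chunk_index = remainder // chunks, remainder % chunks
        assert 0 <= block_index < n_blocks and 0 <= seg_index < n_segs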
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.

Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""

def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None

def init_meta(
self,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
) -> None:
"""One-time precomputation for zero_block_ids.

Builds absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.

Block IDs from the scheduler reference logical blocks whose size
may differ from the kernel block size (virtual block splitting).
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.

Only AttentionSpec layers are processed; Mamba layers are skipped.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None

for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
block_dim = group.backend.get_kv_cache_block_dim(
kernel_bs,
spec.num_kv_heads,
spec.head_size,
cache_dtype_str=cache_dtype,
)

for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
continue
seen_ptrs.add(dp)

el = kv.element_size()
cur_bytes = kv.stride(block_dim) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
if page_size_el is None:
page_size_el = cur_page_el
else:
assert page_size_el == cur_page_el, (
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)

block_stride_bytes = cur_bytes
outer_dims = [
d
for d in range(block_dim)
if kv.stride(d) * el > block_stride_bytes
]
outer_strides = [kv.stride(d) * el for d in outer_dims]
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
seg_addrs.append(dp + off_bytes)

if not seg_addrs or page_size_el is None:
self._meta = None
return

blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
self._id_cap = 8192
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
self._meta = (
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
page_size_el,
blk_size,
len(seg_addrs),
)

def zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if not block_ids or self._meta is None:
return
seg_addrs, page_size_el, blk_size, n_segs = self._meta
n_blocks = len(block_ids)
if n_blocks > self._id_cap:
self._id_cap = n_blocks * 2
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(
self._id_cap, dtype=torch.int64, device=self.device
)
assert self._ids_pinned is not None and self._ids_gpu is not None
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
idx = self._ids_gpu[:n_blocks]
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
_zero_kv_blocks_kernel[grid](
seg_addrs,
idx,
n_blocks,
N_SEGS=n_segs,
PAGE_SIZE_EL=page_size_el,
BLOCK_SIZE=blk_size,
)


@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
@@ -36,7 +234,7 @@
self,
vllm_config,
device,
kernel_block_size: int | None,
kernel_block_size: int | None = None,
num_metadata_builders: int = 1,
):
kv_cache_spec_builder = (
@@ -59,6 +257,119 @@ class AttentionGroup:
return self.metadata_builders[ubatch_id]


def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.

If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the max supported size.

Args:
kv_manager_block_size: Block size of KV cache.
attn_groups: List of attention groups.

Returns:
The selected block size.

Raises:
ValueError: If no valid block size found.
"""

def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""Check if the block size is supported by all backends."""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True

backends = [group.backend for group in attn_groups]

# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly.
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size

# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b is a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
# return kv_manager_block_size in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)

for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")

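Illustration only, not part of the commit: select_common_block_size first tries the KV cache manager's block size and otherwise falls back to the largest integer kernel block size that divides it and is accepted by every backend. The same logic with hypothetical supported sizes:

    kv_manager_block_size = 48
    backend_supported = [[16, 32], [16, 48]]  # hypothetical per-backend int sizes

    def supported_by_all(size: int) -> bool:
        return all(size in sizes for sizes in backend_supported)

    if supported_by_all(kv_manager_block_size):
        chosen = kv_manager_block_size
    else:
        candidates = {s for sizes in backend_supported for s in sizes}
        chosen = next(
            s
            for s in sorted(candidates, reverse=True)
            if kv_manager_block_size % s == 0 and supported_by_all(s)
        )
    assert chosen == 16  # 48 fails on the first backend; 16 divides 48 and suits both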
def prepare_kernel_block_sizes(
kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
) -> list[int]:
"""
Generate kernel_block_sizes that matches each block_size.

For attention backends that support virtual block splitting,
use the supported block sizes from the backend.
For other backends (like Mamba), use the same block size (no splitting).

Args:
kv_cache_config: The KV cache configuration.
attn_groups: Attention groups indexed by KV cache group id.

Returns:
List of kernel block sizes for each cache group.
"""
kernel_block_sizes = []
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# pick an arbitrary one to dispatch.
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
if isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual block splitting.
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = select_common_block_size(
kv_manager_block_size, attn_groups[kv_cache_gid]
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache, no splitting.
kernel_block_sizes.append(kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
)
return kernel_block_sizes


def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
@@ -201,6 +512,55 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]

def bind_kv_cache_scale(
kv_caches_scale: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches_scale: list[torch.Tensor],
num_attn_module: int | None = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.

This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.

Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches_scale) == 0

# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches_scale:
index2name[extract_layer_index(layer_name,
num_attn_module)].append(layer_name)

for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
if current_platform.is_cuda() or current_platform.is_xpu():
pass
else:
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches_scale.append(kv_caches_scale[layer_name])

# Bind kv_caches to forward context
for layer_name, kv_cache_scale in kv_caches_scale.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache_scale = [kv_cache_scale]

def is_residual_scattered_for_sp(

@@ -87,8 +87,12 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise NotImplementedError

def compile_or_warm_up_model(self) -> None:
"""Prepare model for execution through compilation/warmup."""
def compile_or_warm_up_model(self) -> float:
"""Prepare model for execution through compilation/warmup.

Returns:
The accumulated compilation time in seconds.
"""
raise NotImplementedError

def check_health(self) -> None:
@@ -213,13 +217,8 @@ class WorkerWrapperBase:
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
# if self.rpc_rank in rank_mapping:
#     self.rpc_rank = rank_mapping[self.rpc_rank]
old_rank = self.rpc_rank
if old_rank in rank_mapping:
self.rpc_rank = rank_mapping[old_rank]
if self.global_rank == old_rank:
self.global_rank = rank_mapping[old_rank]
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]

def update_environment_variables(
self,

@@ -66,6 +66,23 @@ class WorkspaceManager:
],
)

def unlock(self) -> None:
"""Unlock the workspace to allow growth.

This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
"""
self._locked = False
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)

def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock()


def unlock_workspace() -> None:
"""Unlock the workspace to allow growth.

This is used during elastic EP scaling when the workspace size
needs to grow due to changes in the number of experts.
After scaling operations complete, lock_workspace() should be
called again to prevent unexpected allocations.
"""
current_workspace_manager().unlock()

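A hedged sketch of how these helpers might be used (the import path follows the gpu_worker imports earlier in this diff; the wrapper function and its argument are hypothetical, not from the commit): a scale-up step would typically bracket its reallocation work so the workspace can grow only while scaling is in progress.

    from vllm.v1.worker.workspace import lock_workspace, unlock_workspace

    def resize_for_new_experts(reconfigure):
        unlock_workspace()  # allow workspace buffers to grow
        try:
            reconfigure()  # e.g. rebuild MoE buffers for the new EP size
        finally:
            lock_workspace()  # freeze sizes again to catch stray allocations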
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.