Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
0
vllm/v1/worker/__init__.py
Normal file
0
vllm/v1/worker/__init__.py
Normal file
342
vllm/v1/worker/block_table.py
Normal file
342
vllm/v1/worker/block_table.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
block_size: Block size used for KV cache memory allocation
|
||||
max_num_reqs: Maximum number of concurrent requests supported.
|
||||
max_num_blocks_per_req: Maximum number of blocks per request.
|
||||
max_num_batched_tokens: Maximum number of tokens in a batch.
|
||||
pin_memory: Whether to pin memory for faster GPU transfers.
|
||||
device: Target device for the block table.
|
||||
kernel_block_size: The block_size of underlying attention kernel.
|
||||
Will be the same as `block_size` if `block_size` is supported
|
||||
by the attention kernel.
|
||||
"""
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
if kernel_block_size == block_size:
|
||||
# Standard case: allocation and computation use same block size
|
||||
# No block splitting needed, direct mapping
|
||||
self.block_size = block_size
|
||||
self.blocks_per_kv_block = 1
|
||||
self.use_hybrid_blocks = False
|
||||
else:
|
||||
# Hybrid case: allocation block size differs from kernel block size
|
||||
# Memory blocks are subdivided to match kernel requirements
|
||||
# Example: 32-token memory blocks with 16-token kernel blocks
|
||||
# → Each memory block corresponds to 2 kernel blocks
|
||||
if block_size % kernel_block_size != 0:
|
||||
raise ValueError(
|
||||
f"kernel_block_size {kernel_block_size} must divide "
|
||||
f"kv_manager_block_size size {block_size} evenly"
|
||||
)
|
||||
|
||||
self.block_size = kernel_block_size
|
||||
self.blocks_per_kv_block = block_size // kernel_block_size
|
||||
self.use_hybrid_blocks = True
|
||||
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
|
||||
|
||||
self.block_table = self._make_buffer(
|
||||
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
|
||||
)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping = self._make_buffer(
|
||||
self.max_num_batched_tokens, dtype=torch.int64
|
||||
)
|
||||
|
||||
if self.use_hybrid_blocks:
|
||||
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
|
||||
1, -1
|
||||
)
|
||||
else:
|
||||
self._kernel_block_arange = None
|
||||
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# PCP might not be initialized in testing
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
block_ids: list[int],
|
||||
row_idx: int,
|
||||
) -> None:
|
||||
if not block_ids:
|
||||
return
|
||||
|
||||
if self.use_hybrid_blocks:
|
||||
block_ids = self.map_to_kernel_blocks(
|
||||
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
|
||||
)
|
||||
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
block_table_np = self.block_table.np
|
||||
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
src_tgt, tgt_src = [src, tgt], [tgt, src]
|
||||
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
|
||||
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
|
||||
|
||||
def compute_slot_mapping(
|
||||
self, req_indices: np.ndarray, positions: np.ndarray
|
||||
) -> None:
|
||||
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
|
||||
# where K is the max_num_blocks_per_req and the block size is 2.
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
if total_cp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with an interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * total_cp_world_size
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req
|
||||
+ positions // virtual_block_size
|
||||
)
|
||||
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
mask = (
|
||||
virtual_block_offsets
|
||||
// self.cp_kv_cache_interleave_size
|
||||
% total_cp_world_size
|
||||
== total_cp_rank
|
||||
)
|
||||
# Calculate local block_offsets
|
||||
block_offsets = (
|
||||
virtual_block_offsets
|
||||
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
|
||||
* self.cp_kv_cache_interleave_size
|
||||
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
|
||||
)
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1
|
||||
)
|
||||
else:
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req + positions // self.block_size
|
||||
)
|
||||
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(
|
||||
block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping.np[: req_indices.shape[0]],
|
||||
)
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table.copy_to_gpu(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
self.slot_mapping.copy_to_gpu(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.gpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
@staticmethod
|
||||
def map_to_kernel_blocks(
|
||||
kv_manager_block_ids: np.ndarray,
|
||||
blocks_per_kv_block: int,
|
||||
kernel_block_arange: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Convert kv_manager_block_id IDs to kernel block IDs.
|
||||
|
||||
Example:
|
||||
# kv_manager_block_ids: 32 tokens,
|
||||
# Kernel block size: 16 tokens
|
||||
# blocks_per_kv_block = 2
|
||||
>>> kv_manager_block_ids = np.array([0, 1, 2])
|
||||
>>> Result: [0, 1, 2, 3, 4, 5]
|
||||
|
||||
# Each kv_manager_block_id maps to 2 kernel block id:
|
||||
# kv_manager_block_id 0 → kernel block id [0, 1]
|
||||
# kv_manager_block_id 1 → kernel block id [2, 3]
|
||||
# kv_manager_block_id 2 → kernel block id [4, 5]
|
||||
"""
|
||||
if blocks_per_kv_block == 1:
|
||||
return kv_manager_block_ids
|
||||
|
||||
kernel_block_ids = (
|
||||
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
|
||||
+ kernel_block_arange
|
||||
)
|
||||
|
||||
return kernel_block_ids.reshape(-1)
|
||||
|
||||
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
|
||||
"""Returns the device tensor of the block table."""
|
||||
return self.block_table.gpu[:num_reqs]
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table.cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table.np
|
||||
|
||||
def _make_buffer(
|
||||
self, *size: int | torch.SymInt, dtype: torch.dtype
|
||||
) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
|
||||
)
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
|
||||
"""The BlockTables for each KV cache group."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
block_sizes: list[int],
|
||||
kernel_block_sizes: list[int],
|
||||
max_num_blocks: list[int] | None = None,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> None:
|
||||
if len(kernel_block_sizes) != len(block_sizes):
|
||||
raise ValueError(
|
||||
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
|
||||
f"must match block_sizes length ({len(block_sizes)})"
|
||||
)
|
||||
if max_num_blocks is None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which used for calc max_num_blocks_per_req
|
||||
# must be multiplied by dcp_world_size.
|
||||
total_cp_world_size = get_total_cp_world_size()
|
||||
max_num_blocks = [
|
||||
cdiv(max_model_len, block_size * total_cp_world_size)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
|
||||
if len(max_num_blocks) != len(block_sizes):
|
||||
raise ValueError(
|
||||
f"max_num_blocks length ({len(max_num_blocks)}) "
|
||||
f"must match block_sizes length ({len(block_sizes)})"
|
||||
)
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(
|
||||
block_size,
|
||||
max_num_reqs,
|
||||
max_num_blocks_per_req,
|
||||
max_num_batched_tokens,
|
||||
pin_memory,
|
||||
device,
|
||||
kernel_block_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
|
||||
block_sizes, kernel_block_sizes, max_num_blocks
|
||||
)
|
||||
]
|
||||
|
||||
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.append_row(block_ids[i], row_idx)
|
||||
|
||||
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
block_table.add_row(block_ids[i], row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.move_row(src, tgt)
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.swap_row(src, tgt)
|
||||
|
||||
def compute_slot_mapping(
|
||||
self, req_indices: np.ndarray, positions: np.ndarray
|
||||
) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.compute_slot_mapping(req_indices, positions)
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.commit_block_table(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.commit_slot_mapping(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
for block_table in self.block_tables:
|
||||
block_table.clear()
|
||||
|
||||
def __getitem__(self, idx: int) -> "BlockTable":
|
||||
"""Returns the BlockTable for the i-th KV cache group."""
|
||||
return self.block_tables[idx]
|
||||
57
vllm/v1/worker/cp_utils.py
Normal file
57
vllm/v1/worker/cp_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
else:
|
||||
AttentionLayerBase = object
|
||||
|
||||
|
||||
def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
|
||||
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
layers = get_layers_from_vllm_config(vllm_config, layer_type)
|
||||
for layer in layers.values():
|
||||
layer_impl = getattr(layer, "impl", None)
|
||||
if layer_impl is None:
|
||||
continue
|
||||
if vllm_config.speculative_config is not None and interleave_size > 1:
|
||||
assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, (
|
||||
"MTP with cp_kv_cache_interleave_size > 1 is not "
|
||||
f"supported in {layer_impl.__class__.__name__}."
|
||||
)
|
||||
if dcp_size > 1:
|
||||
assert layer_impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer_impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode."
|
||||
)
|
||||
|
||||
if pcp_size > 1:
|
||||
assert layer_impl.supports_pcp, (
|
||||
"PCP requires attention impls' support, "
|
||||
f"but the impl {layer_impl.__class__.__name__} "
|
||||
"does not support PCP."
|
||||
)
|
||||
|
||||
|
||||
def get_total_cp_world_size():
|
||||
try:
|
||||
pcp_world_size = get_pcp_group().world_size
|
||||
except AssertionError:
|
||||
# PCP might not be initialized in testing
|
||||
pcp_world_size = 1
|
||||
try:
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
dcp_world_size = 1
|
||||
return dcp_world_size * pcp_world_size
|
||||
125
vllm/v1/worker/cpu_model_runner.py
Normal file
125
vllm/v1/worker/cpu_model_runner.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.tracing import instrument
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUModelRunner(GPUModelRunner):
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
assert device == torch.device("cpu")
|
||||
assert self.speculative_config is None, "spec decode is not supported."
|
||||
|
||||
self.use_cuda_graph = False
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
self._postprocess_tensors()
|
||||
|
||||
def _postprocess_tensors(self) -> None:
|
||||
# Note: replace device tensors with cpu tensors
|
||||
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
|
||||
cpu_tensor = getattr(obj, cpu_attr_name, None)
|
||||
device_tensor = getattr(obj, device_attr_name, None)
|
||||
if cpu_tensor is not None and device_tensor is not None:
|
||||
assert isinstance(cpu_tensor, torch.Tensor)
|
||||
assert isinstance(device_tensor, torch.Tensor)
|
||||
setattr(obj, device_attr_name, cpu_tensor)
|
||||
|
||||
for v in vars(self).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
for k, v in vars(self.input_batch).items():
|
||||
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
|
||||
replace_tensor(self.input_batch, k, k[:-11])
|
||||
|
||||
for block_table in self.input_batch.block_table.block_tables:
|
||||
for v in vars(block_table).values():
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
@instrument(span_name="Loading (CPU)")
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@instrument(span_name="Warmup (CPU)")
|
||||
def warming_up_model(self) -> None:
|
||||
logger.info("Warming up model for the compilation...")
|
||||
# Only generate graph for the generic shape
|
||||
with _set_global_compilation_settings(self.vllm_config):
|
||||
self._dummy_run(
|
||||
min(
|
||||
max(16, self.max_num_reqs),
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Warming up done.")
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
pass
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
pass
|
||||
|
||||
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
|
||||
# Note: For CPU backend, dp padding is not required for now.
|
||||
return 0, None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
class _EventPlaceholder:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.record = lambda: None
|
||||
self.synchronize = lambda: None
|
||||
|
||||
class _StreamPlaceholder:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
cuda_event = torch.Event
|
||||
cuda_stream = torch.cuda.Stream
|
||||
try:
|
||||
torch.Event = _EventPlaceholder
|
||||
torch.cuda.Stream = _StreamPlaceholder
|
||||
yield
|
||||
finally:
|
||||
torch.Event = cuda_event
|
||||
torch.cuda.Stream = cuda_stream
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_global_compilation_settings(config: VllmConfig):
|
||||
import torch._inductor.config as torch_inductor_config
|
||||
|
||||
inductor_config = config.compilation_config.inductor_compile_config
|
||||
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
|
||||
freezing_value = torch_inductor_config.freezing
|
||||
try:
|
||||
if inductor_config.get("max_autotune", False):
|
||||
torch_inductor_config.freezing = True
|
||||
yield
|
||||
finally:
|
||||
torch_inductor_config.freezing = freezing_value
|
||||
221
vllm/v1/worker/cpu_worker.py
Normal file
221
vllm/v1/worker/cpu_worker.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
import platform
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
|
||||
from vllm.profiler.wrapper import TorchProfilerWrapper
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUWorker(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
|
||||
self.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
# Torch profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU"],
|
||||
)
|
||||
|
||||
def init_device(self):
|
||||
# Setup OpenMP threads affinity.
|
||||
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
|
||||
# Under numa binding some cores reserved for kv transfer in nixl_connector.py
|
||||
if omp_cpuids == "auto" and platform.system() == "Linux":
|
||||
cpu_arch = current_platform.get_cpu_architecture()
|
||||
if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X):
|
||||
# For S390X/POWERPC SMT-8/4/2
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
|
||||
)
|
||||
elif cpu_arch == CpuArchEnum.X86:
|
||||
# For x86 SMT-2, use 1 CPU per core
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||
lambda cpus: cpus[-1:]
|
||||
)
|
||||
elif cpu_arch == CpuArchEnum.ARM:
|
||||
# For AArch64, no SMT
|
||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(lambda cpus: cpus)
|
||||
else:
|
||||
self.local_omp_cpuid = "nobind"
|
||||
elif omp_cpuids == "nobind":
|
||||
self.local_omp_cpuid = "nobind"
|
||||
else:
|
||||
local_dp_rank = self.parallel_config.data_parallel_rank_local
|
||||
omp_cpuids_list = omp_cpuids.split("|")
|
||||
if local_dp_rank is not None:
|
||||
world_size = self.parallel_config.world_size
|
||||
omp_cpuids_list = omp_cpuids_list[
|
||||
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
|
||||
]
|
||||
self.local_omp_cpuid = omp_cpuids_list[self.rank]
|
||||
|
||||
if self.local_omp_cpuid != "nobind":
|
||||
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend,
|
||||
)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: CPUModelRunner = CPUModelRunner(
|
||||
self.vllm_config, torch.device("cpu")
|
||||
)
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
logger.warning("sleep mode is not supported on CPU, ignore it.")
|
||||
pass
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
logger.warning("sleep mode is not supported on CPU, ignore it.")
|
||||
pass
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
return self.cache_config.cpu_kvcache_space_bytes or 0
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
self.model_runner.warming_up_model()
|
||||
|
||||
def _get_autobind_cpu_ids(
|
||||
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
|
||||
) -> str:
|
||||
"""
|
||||
Return CPU ids to bind based on NUMA nodes.
|
||||
Currently for rank N, only CPU ids on the N-th node in available NUMA
|
||||
node list will be selected.
|
||||
Args:
|
||||
cpu_selector: a callable object to select CPUs from a CPU list
|
||||
of a physical core. The input is a LogicalCPUInfo list, sorted by
|
||||
the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be
|
||||
returned.
|
||||
"""
|
||||
# simulate multiple numa nodes, for testing
|
||||
sim_multi_numa_nodes = os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", "0") != "0"
|
||||
|
||||
allowed_numa_nodes, logical_cpu_list = (
|
||||
CpuPlatform.get_allowed_cpu_core_node_list()
|
||||
)
|
||||
assert (
|
||||
len(allowed_numa_nodes) >= self.parallel_config.world_size
|
||||
or sim_multi_numa_nodes
|
||||
), (
|
||||
f"Not enough allowed NUMA nodes to bind threads of "
|
||||
f"{self.parallel_config.world_size} CPUWorkers. "
|
||||
f"Allowed NUMA nodes are {allowed_numa_nodes}. "
|
||||
"Please try to bind threads manually."
|
||||
)
|
||||
|
||||
if not sim_multi_numa_nodes:
|
||||
# Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`
|
||||
selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore
|
||||
logical_cpu_list = [
|
||||
x for x in logical_cpu_list if x.numa_node == selected_numa_node
|
||||
]
|
||||
else:
|
||||
# This is a bit tricky because the internal DP size
|
||||
# is always 1 for non-MoE models
|
||||
world_size_across_dp = (
|
||||
self.parallel_config.world_size
|
||||
* self.parallel_config._api_process_count
|
||||
)
|
||||
assert len(logical_cpu_list) >= world_size_across_dp
|
||||
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
|
||||
sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
|
||||
assert self.parallel_config.data_parallel_rank_local is not None
|
||||
start_idx = (
|
||||
self.local_rank
|
||||
+ self.parallel_config.world_size
|
||||
* self.parallel_config.data_parallel_rank_local
|
||||
) * sim_cpu_num_per_node
|
||||
logical_cpu_list = logical_cpu_list[
|
||||
start_idx : (start_idx + sim_cpu_num_per_node)
|
||||
]
|
||||
|
||||
# Select CPUs from each physical core via cpu_selector
|
||||
core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
|
||||
for cpu_info in logical_cpu_list:
|
||||
if cpu_info.physical_core not in core_to_cpus:
|
||||
core_to_cpus[cpu_info.physical_core] = []
|
||||
core_to_cpus[cpu_info.physical_core].append(cpu_info)
|
||||
logical_cpu_list = []
|
||||
for cpu_list in core_to_cpus.values():
|
||||
cpu_list = sorted(cpu_list, key=lambda x: x.id)
|
||||
logical_cpu_list.extend(cpu_selector(cpu_list))
|
||||
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id)
|
||||
|
||||
# Reserve CPUs for other processes
|
||||
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
|
||||
if reserve_cpu_num is None:
|
||||
need_reserve = (
|
||||
self.parallel_config.world_size > 1
|
||||
or self.parallel_config.data_parallel_size_local > 1
|
||||
)
|
||||
reserve_cpu_num = 1 if need_reserve else 0
|
||||
assert len(logical_cpu_list) > reserve_cpu_num, (
|
||||
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
|
||||
f"should less than {len(logical_cpu_list)}."
|
||||
)
|
||||
if reserve_cpu_num != 0:
|
||||
logical_cpu_list = logical_cpu_list[:-reserve_cpu_num]
|
||||
|
||||
logger.info(
|
||||
"auto thread-binding list (id, physical core): %s",
|
||||
[(x.id, x.physical_core) for x in logical_cpu_list],
|
||||
)
|
||||
return ",".join([str(x.id) for x in logical_cpu_list])
|
||||
|
||||
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
240
vllm/v1/worker/dp_utils.py
Normal file
240
vllm/v1/worker/dp_utils.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.ubatch_utils import (
|
||||
check_ubatch_thresholds,
|
||||
is_last_ubatch_empty,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_device_and_group(parallel_config: ParallelConfig):
|
||||
# Use the actual device assigned to the DP group, not just the device type
|
||||
device = get_dp_group().device
|
||||
group = get_dp_group().device_group
|
||||
|
||||
# Transferring this tensor from GPU to CPU will introduce a GPU sync
|
||||
# point that could adversely affect performance of vllm with asynch
|
||||
# scheduling. This environment variable exists to quickly disable
|
||||
# this optimization if we run into this case.
|
||||
if parallel_config.disable_nccl_for_dp_synchronization:
|
||||
logger.info_once(
|
||||
"Using CPU all reduce to synchronize DP padding between ranks."
|
||||
)
|
||||
device = "cpu"
|
||||
group = get_dp_group().cpu_group
|
||||
return device, group
|
||||
|
||||
|
||||
def _run_ar(
|
||||
should_ubatch: bool,
|
||||
should_dp_pad: bool,
|
||||
orig_num_tokens_per_ubatch: int,
|
||||
padded_num_tokens_per_ubatch: int,
|
||||
cudagraph_mode: int,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> torch.Tensor:
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
device, group = _get_device_and_group(parallel_config)
|
||||
tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32)
|
||||
tensor[0][dp_rank] = orig_num_tokens_per_ubatch
|
||||
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
|
||||
tensor[2][dp_rank] = 1 if should_ubatch else 0
|
||||
tensor[3][dp_rank] = 1 if should_dp_pad else 0
|
||||
tensor[4][dp_rank] = cudagraph_mode
|
||||
dist.all_reduce(tensor, group=group)
|
||||
return tensor
|
||||
|
||||
|
||||
def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
|
||||
orig_num_tokens_tensor = tensor[0, :]
|
||||
padded_num_tokens_tensor = tensor[1, :]
|
||||
|
||||
# First determine if we are going to be ubatching.
|
||||
should_ubatch: bool = bool(torch.all(tensor[2] == 1).item())
|
||||
if not should_ubatch:
|
||||
return False
|
||||
# If the DP ranks are planning to ubatch, make sure that
|
||||
# there are no "empty" second ubatches
|
||||
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
|
||||
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
|
||||
if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
|
||||
logger.debug(
|
||||
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
|
||||
)
|
||||
should_ubatch = False
|
||||
return should_ubatch
|
||||
|
||||
|
||||
def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
|
||||
num_tokens_across_dp = tensor[1, :]
|
||||
if should_dp_pad:
|
||||
# If DP padding is enabled, ensure that each rank is processing the same number
|
||||
# of tokens
|
||||
max_num_tokens = int(num_tokens_across_dp.max().item())
|
||||
return torch.tensor(
|
||||
[max_num_tokens] * len(num_tokens_across_dp),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
else:
|
||||
return num_tokens_across_dp.cpu()
|
||||
|
||||
|
||||
def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Synchronize cudagraph_mode across DP ranks by taking the minimum.
|
||||
If any rank has NONE (0), all ranks use NONE.
|
||||
This ensures all ranks send consistent values (all padded or all unpadded).
|
||||
"""
|
||||
return int(tensor[4, :].min().item())
|
||||
|
||||
|
||||
def _synchronize_dp_ranks(
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
should_attempt_dp_padding: bool,
|
||||
cudagraph_mode: int,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> tuple[bool, torch.Tensor | None, int]:
|
||||
"""
|
||||
1. Decides if each DP rank is going to microbatch. Either all ranks
|
||||
run with microbatching or none of them do.
|
||||
|
||||
2. Determines the total number of tokens that each rank will run.
|
||||
When running microbatched or if should_attempt_dp_padding is True, all
|
||||
ranks will be padded out so that the run with the same number of tokens
|
||||
|
||||
3. Synchronizes cudagraph_mode across ranks by taking the minimum.
|
||||
|
||||
Returns: tuple[
|
||||
should_ubatch: Are all DP ranks going to microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including any DP padding.
|
||||
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
|
||||
]
|
||||
|
||||
"""
|
||||
assert num_tokens_padded >= num_tokens_unpadded
|
||||
|
||||
# Coordinate between the DP ranks via an All Reduce
|
||||
# to determine the total number of tokens that each rank
|
||||
# will run and if we are using ubatching or not.
|
||||
tensor = _run_ar(
|
||||
should_ubatch=should_attempt_ubatching,
|
||||
should_dp_pad=should_attempt_dp_padding,
|
||||
orig_num_tokens_per_ubatch=num_tokens_unpadded,
|
||||
padded_num_tokens_per_ubatch=num_tokens_padded,
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
|
||||
should_dp_pad = bool(torch.all(tensor[3] == 1).item())
|
||||
|
||||
# DP ranks should all have the same value for should_attempt_dp_padding.
|
||||
assert should_attempt_dp_padding == should_dp_pad
|
||||
|
||||
# Check conditions for microbatching
|
||||
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
|
||||
|
||||
if should_ubatch and not should_dp_pad:
|
||||
logger.debug_once(
|
||||
"Microbatching has been triggered and requires DP padding. "
|
||||
"Enabling DP padding even though it has been explicitly "
|
||||
"disabled.",
|
||||
scope="global",
|
||||
)
|
||||
should_dp_pad = True
|
||||
|
||||
# Pad all DP ranks up to the maximum token count across ranks if
|
||||
# should_dp_pad is True
|
||||
num_tokens_after_padding = _post_process_dp_padding(
|
||||
tensor,
|
||||
should_dp_pad,
|
||||
)
|
||||
|
||||
# Synchronize cudagraph_mode across ranks (take min)
|
||||
synced_cudagraph_mode = _post_process_cudagraph_mode(tensor)
|
||||
|
||||
return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode
|
||||
|
||||
|
||||
def coordinate_batch_across_dp(
|
||||
num_tokens_unpadded: int,
|
||||
allow_microbatching: bool,
|
||||
allow_dp_padding: bool,
|
||||
parallel_config: ParallelConfig,
|
||||
num_tokens_padded: int | None = None,
|
||||
uniform_decode: bool | None = None,
|
||||
num_scheduled_tokens_per_request: np.ndarray | None = None,
|
||||
cudagraph_mode: int = 0,
|
||||
) -> tuple[bool, torch.Tensor | None, int]:
|
||||
"""
|
||||
Coordinates amongst all DP ranks to determine if and how the full batch
|
||||
should be split into microbatches.
|
||||
|
||||
Args:
|
||||
num_tokens_unpadded: Number of tokens without accounting for padding
|
||||
allow_microbatching: If microbatching should be attempted
|
||||
allow_dp_padding: If all DP ranks should be padded up to the same value
|
||||
parallel_config: The parallel config
|
||||
num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs,
|
||||
TP, etc)
|
||||
uniform_decode: Only used if allow_microbatching is True. True if the batch
|
||||
only contains single token decodes
|
||||
num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The
|
||||
number of tokens per request.
|
||||
cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL)
|
||||
|
||||
Returns: tuple[
|
||||
ubatch_slices: if this is set then all DP ranks have agreed to
|
||||
microbatch
|
||||
num_tokens_after_padding: A tensor containing the total number of
|
||||
tokens per-microbatch for each DP rank including padding. Will be
|
||||
padded up to the max value across all DP ranks when allow_dp_padding
|
||||
is True.
|
||||
synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks)
|
||||
]
|
||||
|
||||
"""
|
||||
if parallel_config.data_parallel_size == 1:
|
||||
# Early exit.
|
||||
return False, None, cudagraph_mode
|
||||
|
||||
# If the caller has explicitly enabled microbatching.
|
||||
should_attempt_ubatching = False
|
||||
if allow_microbatching:
|
||||
# Check preconditions for microbatching
|
||||
assert uniform_decode is not None
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
if num_tokens_padded is None:
|
||||
num_tokens_padded = num_tokens_unpadded
|
||||
|
||||
(should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = (
|
||||
_synchronize_dp_ranks(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
allow_dp_padding,
|
||||
cudagraph_mode,
|
||||
parallel_config,
|
||||
)
|
||||
)
|
||||
|
||||
return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode)
|
||||
86
vllm/v1/worker/ec_connector_model_runner_mixin.py
Normal file
86
vllm/v1/worker/ec_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Define EC connector functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager, contextmanager, nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import ECConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a EC connector functionality mixin for ModelRunner (GPU, TPU)
|
||||
class ECConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
def maybe_save_ec_to_connector(
|
||||
encoder_cache: dict[str, torch.Tensor],
|
||||
mm_hash: str,
|
||||
):
|
||||
if not has_ec_transfer():
|
||||
logger.debug("Not have ec transfer please check")
|
||||
return
|
||||
connector = get_ec_transfer()
|
||||
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
|
||||
|
||||
@staticmethod
|
||||
def get_finished_ec_transfers(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
if has_ec_transfer():
|
||||
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def maybe_get_ec_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
encoder_cache: dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
) -> AbstractContextManager[ECConnectorOutput | None]:
|
||||
return (
|
||||
ECConnectorModelRunnerMixin._get_ec_connector_output(
|
||||
scheduler_output, encoder_cache, **kwargs
|
||||
)
|
||||
if has_ec_transfer()
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
# This context manager must be used within an active forward context.
|
||||
# It encapsulates the entire EC connector lifecycle within execute_model
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _get_ec_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
encoder_cache: dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
) -> Generator[ECConnectorOutput, None, None]:
|
||||
output = ECConnectorOutput()
|
||||
|
||||
ec_connector = get_ec_transfer()
|
||||
assert isinstance(ec_connector, ECConnectorBase)
|
||||
assert scheduler_output.ec_connector_metadata is not None
|
||||
ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
|
||||
|
||||
# Load caches for consumer or both roles
|
||||
if ec_connector.is_consumer:
|
||||
ec_connector.start_load_caches(encoder_cache, **kwargs)
|
||||
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
output.finished_sending, output.finished_recving = (
|
||||
ec_connector.get_finished(scheduler_output.finished_req_ids)
|
||||
)
|
||||
|
||||
ec_connector.clear_connector_metadata()
|
||||
4
vllm/v1/worker/gpu/README.md
Normal file
4
vllm/v1/worker/gpu/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# [Experimental] Model Runner V2
|
||||
|
||||
This directory contains the new model runner which is under active development.
|
||||
Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.
|
||||
0
vllm/v1/worker/gpu/__init__.py
Normal file
0
vllm/v1/worker/gpu/__init__.py
Normal file
86
vllm/v1/worker/gpu/async_utils.py
Normal file
86
vllm/v1/worker/gpu/async_utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
|
||||
|
||||
class AsyncOutput(AsyncModelRunnerOutput):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
sampler_output: SamplerOutput,
|
||||
num_sampled_tokens: torch.Tensor,
|
||||
main_stream: torch.cuda.Stream,
|
||||
copy_stream: torch.cuda.Stream,
|
||||
copy_event: torch.cuda.Event,
|
||||
):
|
||||
# NOTE(woosuk): We must retain references to the GPU tensors,
|
||||
# as the copy operations are performed on a different CUDA stream than
|
||||
# the one where the tensors were created.
|
||||
self.model_runner_output = model_runner_output
|
||||
self.sampler_output = sampler_output
|
||||
self.num_sampled_tokens = num_sampled_tokens
|
||||
self.copy_event = copy_event
|
||||
|
||||
with stream(copy_stream, main_stream):
|
||||
copy_stream.wait_stream(main_stream)
|
||||
|
||||
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
|
||||
self.logprobs_tensors: LogprobsTensors | None = None
|
||||
if sampler_output.logprobs_tensors is not None:
|
||||
self.logprobs_tensors = (
|
||||
sampler_output.logprobs_tensors.to_cpu_nonblocking()
|
||||
)
|
||||
self.num_nans: np.ndarray | None = None
|
||||
if sampler_output.num_nans is not None:
|
||||
self.num_nans = async_copy_to_np(sampler_output.num_nans)
|
||||
self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
|
||||
self.prompt_logprobs_dict = {
|
||||
k: v.to_cpu_nonblocking() if v is not None else None
|
||||
for k, v in self.model_runner_output.prompt_logprobs_dict.items()
|
||||
}
|
||||
self.copy_event.record(copy_stream)
|
||||
|
||||
def get_output(self) -> ModelRunnerOutput:
|
||||
self.copy_event.synchronize()
|
||||
|
||||
# NOTE(woosuk): The following code is to ensure compatibility with
|
||||
# the existing model runner.
|
||||
# Going forward, we should keep the data structures as NumPy arrays
|
||||
# rather than Python lists.
|
||||
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
|
||||
num_sampled_tokens: list[int] = self.num_sampled_tokens_np.tolist()
|
||||
for token_ids, num_tokens in zip(sampled_token_ids, num_sampled_tokens):
|
||||
del token_ids[num_tokens:]
|
||||
self.model_runner_output.sampled_token_ids = sampled_token_ids
|
||||
|
||||
if self.num_nans is not None:
|
||||
self.model_runner_output.num_nans_in_logits = dict(
|
||||
zip(self.model_runner_output.req_ids, self.num_nans.tolist())
|
||||
)
|
||||
|
||||
if self.logprobs_tensors is not None:
|
||||
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
|
||||
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
|
||||
return self.model_runner_output
|
||||
|
||||
|
||||
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
|
||||
return x.to("cpu", non_blocking=True).numpy()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def stream(to_stream: torch.cuda.Stream, from_stream: torch.cuda.Stream):
|
||||
"""Lightweight version of torch.cuda.stream() context manager which
|
||||
avoids current_stream and device lookups.
|
||||
"""
|
||||
try:
|
||||
torch.cuda.set_stream(to_stream)
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.set_stream(from_stream)
|
||||
215
vllm/v1/worker/gpu/attn_utils.py
Normal file
215
vllm/v1/worker/gpu/attn_utils.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
|
||||
|
||||
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
# Skip modules that don't need KV cache (eg encoder-only attention)
|
||||
if spec := attn_module.get_kv_cache_spec(vllm_config):
|
||||
kv_cache_spec[layer_name] = spec
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
|
||||
):
|
||||
attn_backends: dict[str, type[AttentionBackend]] = {}
|
||||
attn_groups: list[list[AttentionGroup]] = []
|
||||
attn_backend_workspace: torch.Tensor | None = None
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
kv_cache_config.kv_cache_groups
|
||||
):
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
|
||||
|
||||
group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
|
||||
group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
|
||||
|
||||
for layer_name in layer_names:
|
||||
attn_backend = attn_layers[layer_name].get_attn_backend()
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
|
||||
|
||||
key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
|
||||
if key not in group_map:
|
||||
group_map[key] = AttentionGroup(
|
||||
attn_backend,
|
||||
[layer_name],
|
||||
layer_kv_cache_spec,
|
||||
kv_cache_group_id,
|
||||
)
|
||||
group_order.append(key)
|
||||
else:
|
||||
group_map[key].layer_names.append(layer_name)
|
||||
|
||||
groups = [group_map[key] for key in group_order]
|
||||
for group in groups:
|
||||
group.create_metadata_builders(
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
kernel_block_size=None,
|
||||
num_metadata_builders=1,
|
||||
)
|
||||
builder = group.get_metadata_builder(0)
|
||||
if attn_backend_workspace is None:
|
||||
if hasattr(builder, "_get_workspace_buffer"):
|
||||
attn_backend_workspace = builder._get_workspace_buffer()
|
||||
else:
|
||||
if hasattr(builder, "set_workspace_buffer"):
|
||||
builder.set_workspace_buffer(attn_backend_workspace)
|
||||
attn_groups.append(groups)
|
||||
return attn_backends, attn_groups
|
||||
|
||||
|
||||
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
raw_tensor = raw_tensor.view(dtype)
|
||||
raw_tensor = raw_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def init_kv_cache(
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
forward_context: dict[str, Any],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def build_slot_mappings_by_layer(
|
||||
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
|
||||
) -> dict[str, torch.Tensor]:
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
slot_mappings_by_layer[layer_name] = slot_mapping
|
||||
return slot_mappings_by_layer
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
max_query_len: int,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
dcp_local_seq_lens: torch.Tensor | None = None,
|
||||
) -> dict[str, Any]:
|
||||
seq_lens = seq_lens[:num_reqs]
|
||||
if dcp_local_seq_lens is not None:
|
||||
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
for i in range(num_kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
for attn_group in attn_groups[i]:
|
||||
attn_metadata_builder = attn_group.get_metadata_builder(0)
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
253
vllm/v1/worker/gpu/block_table.py
Normal file
253
vllm/v1/worker/gpu/block_table.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
|
||||
class BlockTables:
|
||||
def __init__(
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
max_num_reqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
cp_size: int = 1,
|
||||
cp_rank: int = 0,
|
||||
cp_interleave: int = 1,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
|
||||
self.cp_size = cp_size
|
||||
self.cp_rank = cp_rank
|
||||
self.cp_interleave = cp_interleave
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[StagedWriteTensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
# When using DCP, each request's KV cache is sharded among different ranks.
|
||||
# As a result, one block on the current rank covers `block_size * cp_size`
|
||||
# tokens in the full, global (unsharded) sequence.
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
|
||||
block_table = StagedWriteTensor(
|
||||
(self.max_num_reqs, max_num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_tables.append(block_table)
|
||||
self.block_table_ptrs = self._make_ptr_tensor(
|
||||
[b.gpu for b in self.block_tables]
|
||||
)
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.gpu.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.block_sizes_tensor = torch.tensor(
|
||||
self.block_sizes, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.num_blocks = UvaBackedTensor(
|
||||
(self.num_kv_cache_groups, self.max_num_reqs),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Block tables used for model's forward pass.
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(b.gpu) for b in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
|
||||
|
||||
self.slot_mappings = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
||||
return torch.tensor(
|
||||
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
|
||||
)
|
||||
|
||||
def append_block_ids(
|
||||
self,
|
||||
req_index: int,
|
||||
new_block_ids: tuple[list[int], ...],
|
||||
overwrite: bool,
|
||||
) -> None:
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
start = self.num_blocks.np[i, req_index] if not overwrite else 0
|
||||
block_ids = new_block_ids[i]
|
||||
self.block_tables[i].stage_write(req_index, start, block_ids)
|
||||
self.num_blocks.np[i, req_index] = start + len(block_ids)
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
# TODO(woosuk): This can be inefficient since it launches one kernel per
|
||||
# block table. Implement a kernel to handle all block tables at once.
|
||||
for block_table in self.block_tables:
|
||||
block_table.apply_write()
|
||||
self.num_blocks.copy_to_uva()
|
||||
|
||||
def gather_block_tables(
|
||||
self, idx_mapping: torch.Tensor
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
idx_mapping,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks.gpu,
|
||||
self.num_blocks.gpu.stride(0),
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
self.cp_rank,
|
||||
CP_SIZE=self.cp_size,
|
||||
CP_INTERLEAVE=self.cp_interleave,
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
TRITON_BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
|
||||
self.slot_mappings.fill_(PAD_SLOT_ID)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gather_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
|
||||
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
idx_mapping, # [num_reqs]
|
||||
query_start_loc, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
block_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
cp_rank,
|
||||
CP_SIZE: tl.constexpr,
|
||||
CP_INTERLEAVE: tl.constexpr,
|
||||
PAD_ID: tl.constexpr,
|
||||
TRITON_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if batch_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
block_size = tl.load(block_sizes + group_id)
|
||||
|
||||
req_state_idx = tl.load(idx_mapping + batch_idx)
|
||||
start_idx = tl.load(query_start_loc + batch_idx)
|
||||
end_idx = tl.load(query_start_loc + batch_idx + 1)
|
||||
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
|
||||
block_indices = positions // (block_size * CP_SIZE)
|
||||
block_offsets = positions % (block_size * CP_SIZE)
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_state_idx * block_table_stride + block_indices
|
||||
)
|
||||
|
||||
if CP_SIZE == 1:
|
||||
# Common case: Context parallelism is not used.
|
||||
slot_ids = block_numbers * block_size + block_offsets
|
||||
else:
|
||||
# Context parallelism is used.
|
||||
is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
|
||||
rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
|
||||
remainder = block_offsets % CP_INTERLEAVE
|
||||
local_offsets = rounds * CP_INTERLEAVE + remainder
|
||||
slot_ids = block_numbers * block_size + local_offsets
|
||||
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
|
||||
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_ptr(ptr_to_ptr, elem_dtype):
|
||||
ptr = tl.load(ptr_to_ptr)
|
||||
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
|
||||
return tl.multiple_of(ptr, 16)
|
||||
220
vllm/v1/worker/gpu/buffer_utils.py
Normal file
220
vllm/v1/worker/gpu/buffer_utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.platform_utils import is_uva_available
|
||||
from vllm.utils.torch_utils import (
|
||||
async_tensor_h2d,
|
||||
get_accelerator_view_from_cpu_tensor,
|
||||
)
|
||||
|
||||
|
||||
def async_copy_to_gpu(
|
||||
x: torch.Tensor | np.ndarray,
|
||||
out: torch.Tensor | None = None,
|
||||
device: torch.device | None = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x)
|
||||
assert x.is_cpu
|
||||
|
||||
if out is None:
|
||||
assert device is not None
|
||||
out = torch.empty_like(x, device=device)
|
||||
|
||||
# CPU-to-CPU copy
|
||||
tmp = x.pin_memory()
|
||||
assert tmp is not x
|
||||
|
||||
# CPU-to-GPU copy
|
||||
return out.copy_(tmp, non_blocking=True)
|
||||
|
||||
|
||||
class UvaBuffer:
|
||||
def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
|
||||
if not is_uva_available():
|
||||
raise RuntimeError("UVA is not available")
|
||||
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
|
||||
self.np = self.cpu.numpy()
|
||||
self.uva = get_accelerator_view_from_cpu_tensor(self.cpu)
|
||||
|
||||
|
||||
class UvaBufferPool:
|
||||
def __init__(
|
||||
self,
|
||||
size: int | Sequence[int],
|
||||
dtype: torch.dtype,
|
||||
max_concurrency: int = 2,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.max_concurrency = max_concurrency
|
||||
|
||||
# UVA buffers for concurrency
|
||||
self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
|
||||
# Current buffer index
|
||||
self._curr = 0
|
||||
|
||||
def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
|
||||
# Round robin to the next buffer.
|
||||
self._curr = (self._curr + 1) % self.max_concurrency
|
||||
buf = self._uva_bufs[self._curr]
|
||||
# CPU-to-CPU copy
|
||||
dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
|
||||
n = len(x)
|
||||
dst[:n] = x
|
||||
return buf.uva[:n]
|
||||
|
||||
def copy_to_gpu(
|
||||
self,
|
||||
x: torch.Tensor | np.ndarray,
|
||||
out: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
uva = self.copy_to_uva(x)
|
||||
# CPU-to-GPU copy
|
||||
return uva.clone() if out is None else out.copy_(uva, non_blocking=True)
|
||||
|
||||
|
||||
class UvaBackedTensor:
|
||||
def __init__(
|
||||
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.max_concurrency = max_concurrency
|
||||
|
||||
# Source of truth
|
||||
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
|
||||
self.np = self.cpu.numpy()
|
||||
|
||||
# Buffers for concurrency
|
||||
self.pool = UvaBufferPool(size, dtype, max_concurrency)
|
||||
self.gpu = self.pool.copy_to_uva(self.np)
|
||||
|
||||
def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
|
||||
# CPU-to-CPU copy
|
||||
self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
|
||||
return self.gpu
|
||||
|
||||
|
||||
class StagedWriteTensor:
|
||||
def __init__(
|
||||
self,
|
||||
size: int | Sequence[int],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
max_concurrency: int = 2,
|
||||
uva_instead_of_gpu: bool = False,
|
||||
):
|
||||
supported_dtypes = [torch.int32, torch.int64, torch.float32]
|
||||
if dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of {supported_dtypes}"
|
||||
)
|
||||
self.num_rows = size if isinstance(size, int) else size[0]
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.max_concurrency = max_concurrency
|
||||
|
||||
if not uva_instead_of_gpu:
|
||||
# Create a GPU tensor (default)
|
||||
self.gpu = torch.zeros(size, dtype=dtype, device=device)
|
||||
else:
|
||||
# For a large but not-frequently-accessed tensor, we can use UVA instead of
|
||||
# GPU to save GPU memory
|
||||
self._uva_buf = UvaBuffer(size, dtype)
|
||||
self.gpu = self._uva_buf.uva
|
||||
|
||||
self._staged_write_indices: list[int] = []
|
||||
self._staged_write_starts: list[int] = []
|
||||
self._staged_write_contents: list[int | float] = []
|
||||
self._staged_write_cu_lens: list[int] = []
|
||||
|
||||
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
|
||||
|
||||
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
|
||||
def stage_write(
|
||||
self, index: int, start: int, x: Iterable[int] | Iterable[float]
|
||||
) -> None:
|
||||
assert index >= 0
|
||||
assert start >= 0
|
||||
if not x:
|
||||
return
|
||||
self._staged_write_indices.append(index)
|
||||
self._staged_write_starts.append(start)
|
||||
self._staged_write_contents.extend(x)
|
||||
self._staged_write_cu_lens.append(len(self._staged_write_contents))
|
||||
|
||||
def stage_write_elem(self, index: int, x: int) -> None:
|
||||
assert index >= 0
|
||||
self._staged_write_indices.append(index)
|
||||
self._staged_write_starts.append(0)
|
||||
self._staged_write_contents.append(x)
|
||||
self._staged_write_cu_lens.append(len(self._staged_write_contents))
|
||||
|
||||
def apply_write(self) -> None:
|
||||
n = len(self._staged_write_indices)
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
|
||||
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
|
||||
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
|
||||
|
||||
# Special handling for write_contents
|
||||
write_contents = async_tensor_h2d(
|
||||
self._staged_write_contents, self.dtype, self.device, pin_memory=True
|
||||
)
|
||||
|
||||
# Write diffs to the GPU buffer
|
||||
_apply_write_kernel[(n,)](
|
||||
self.gpu,
|
||||
self.gpu.stride(0),
|
||||
indices_uva,
|
||||
starts_uva,
|
||||
write_contents,
|
||||
cu_lens_uva,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
# Clear the staged writes
|
||||
self.clear_staged_writes()
|
||||
|
||||
def clear_staged_writes(self) -> None:
|
||||
self._staged_write_indices.clear()
|
||||
self._staged_write_starts.clear()
|
||||
self._staged_write_contents.clear()
|
||||
self._staged_write_cu_lens.clear()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _apply_write_kernel(
|
||||
output_ptr,
|
||||
output_stride,
|
||||
write_indices_ptr,
|
||||
write_starts_ptr,
|
||||
write_contents_ptr,
|
||||
write_cu_lens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
row_idx = tl.load(write_indices_ptr + pid)
|
||||
start_idx = tl.load(write_starts_ptr + pid)
|
||||
|
||||
cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
|
||||
cu_end = tl.load(write_cu_lens_ptr + pid)
|
||||
content_len = cu_end - cu_start
|
||||
|
||||
for i in range(0, content_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < content_len
|
||||
content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
|
||||
tl.store(
|
||||
output_ptr + row_idx * output_stride + start_idx + block, content, mask=mask
|
||||
)
|
||||
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def prepare_dcp_local_seq_lens(
|
||||
dcp_local_seq_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_reqs: int,
|
||||
dcp_size: int,
|
||||
dcp_rank: int,
|
||||
cp_interleave: int,
|
||||
) -> None:
|
||||
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
|
||||
if dcp_size == 1:
|
||||
return
|
||||
|
||||
max_num_reqs = dcp_local_seq_lens.shape[0]
|
||||
BLOCK_SIZE = 128
|
||||
num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
|
||||
_dcp_local_seq_lens_kernel[(num_blocks,)](
|
||||
dcp_local_seq_lens,
|
||||
seq_lens,
|
||||
dcp_size,
|
||||
dcp_rank,
|
||||
cp_interleave,
|
||||
num_reqs,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dcp_local_seq_lens_kernel(
|
||||
out_ptr,
|
||||
seq_lens_ptr,
|
||||
dcp_size,
|
||||
dcp_rank,
|
||||
cp_interleave,
|
||||
num_reqs,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
|
||||
|
||||
# Distribute KV cache among different ranks, in a round-robin manner.
|
||||
rounds = seq_lens // (dcp_size * cp_interleave)
|
||||
remainder = seq_lens % (dcp_size * cp_interleave)
|
||||
|
||||
remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
|
||||
remainder = tl.minimum(remainder, cp_interleave)
|
||||
local_seq_lens = rounds * cp_interleave + remainder
|
||||
|
||||
# For [num_reqs, max_num_reqs), pad with 0
|
||||
local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
|
||||
tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
|
||||
462
vllm/v1/worker/gpu/cudagraph_utils.py
Normal file
462
vllm/v1/worker/gpu/cudagraph_utils.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
|
||||
self.uniform_decode_query_len = 1
|
||||
spec_config = vllm_config.speculative_config
|
||||
if spec_config is not None:
|
||||
self.uniform_decode_query_len += spec_config.num_speculative_tokens
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
|
||||
use_uniform_decode_cudagraph = (
|
||||
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and self.cudagraph_mode.separate_routine()
|
||||
)
|
||||
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
self.uniform_decode_query_len,
|
||||
use_uniform_decode_cudagraph,
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.cudagraph_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(
|
||||
self, num_tokens: int, uniform_decode: bool = False
|
||||
) -> int | None:
|
||||
if uniform_decode and self.uniform_decode_cudagraph_sizes:
|
||||
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
) -> None:
|
||||
# select and check capture function
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
# prepare inputs
|
||||
if uniform_decode:
|
||||
num_reqs = min(
|
||||
cdiv(num_tokens, self.uniform_decode_query_len),
|
||||
self.max_num_reqs,
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
self.uniform_decode_query_len if uniform_decode else 0
|
||||
),
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Allocate output buffers if not already done.
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
|
||||
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
|
||||
|
||||
capture_fn(
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
assert attn_metadata is not None
|
||||
# Capture the graph.
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Copy outputs to the output buffers.
|
||||
assert self.hidden_states is not None
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
for i, aux_hidden in enumerate(aux_hidden_states):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux_hidden
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
# create batch descriptor for piecewise cudagraph dispatch key
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
|
||||
|
||||
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
|
||||
with set_forward_context(
|
||||
attn_metadata=None, # piecewise no need attn_metadata
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
common_kwargs = dict(
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
# Phase 1: Capture for mixed prefill-decode batches if needed.
|
||||
mixed_mode = self.cudagraph_mode.mixed_mode()
|
||||
if mixed_mode != CUDAGraphMode.NONE:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.cudagraph_sizes,
|
||||
capture_cudagraph_mode=mixed_mode,
|
||||
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
|
||||
uniform_decode=False,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
|
||||
# This is only needed if we use a separate routine for decode batches
|
||||
# and the decode_mode is FULL.
|
||||
if self.uniform_decode_cudagraph_sizes:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
|
||||
capture_cudagraph_mode=CUDAGraphMode.FULL,
|
||||
desc="Capturing CUDA graphs (decode, FULL)",
|
||||
uniform_decode=True,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_reqs: int, num_tokens: int, max_query_len: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
num_tokens == max_query_len * num_reqs
|
||||
)
|
||||
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif is_uniform_decode:
|
||||
cudagraph_mode = self.cudagraph_mode.decode_mode()
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode.mixed_mode()
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def run_fullgraph(
|
||||
self, num_tokens: int
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
|
||||
# Sync offloader before replay - needed when transitioning from
|
||||
# eager/piecewise to full cudagraph (e.g., prefill → decode).
|
||||
# The previous eager iteration's start_prefetch may have queued
|
||||
# H2D copies on copy_stream that the graph's captured events
|
||||
# cannot see. Without this, replay could overwrite static buffers
|
||||
# while those copies are still in flight.
|
||||
get_offloader().sync_prev_onload()
|
||||
self.graphs[num_tokens].replay()
|
||||
assert self.hidden_states is not None
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
if not self.use_aux_hidden_state_outputs:
|
||||
return hidden_states
|
||||
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
|
||||
|
||||
|
||||
def get_cudagraph_sizes(
|
||||
capture_sizes: list[int] | None,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
uniform_decode_query_len: int = 1,
|
||||
uniform_decode_cudagraph: bool = False,
|
||||
) -> tuple[dict[int, int], dict[int, int]]:
|
||||
# Support both FULL and PIECEWISE cudagraph modes
|
||||
if cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return {}, {}
|
||||
if not capture_sizes:
|
||||
return {}, {}
|
||||
|
||||
capture_sizes = sorted(capture_sizes)
|
||||
if not capture_sizes:
|
||||
return {}, {}
|
||||
|
||||
cudagraph_sizes: dict[int, int] = {}
|
||||
for i in range(1, capture_sizes[-1] + 1):
|
||||
for x in capture_sizes:
|
||||
if i <= x:
|
||||
cudagraph_sizes[i] = x
|
||||
break
|
||||
|
||||
uniform_decode_cudagraph_sizes: dict[int, int] = {}
|
||||
if uniform_decode_cudagraph:
|
||||
max_num_tokens = max_num_reqs * uniform_decode_query_len
|
||||
uniform_decode_cudagraph_sizes = {
|
||||
k: v
|
||||
for k, v in cudagraph_sizes.items()
|
||||
if v <= max_num_tokens and v >= uniform_decode_query_len
|
||||
}
|
||||
return cudagraph_sizes, uniform_decode_cudagraph_sizes
|
||||
|
||||
|
||||
def capture_graphs(
|
||||
cudagraph_sizes: dict[int, int],
|
||||
device: torch.device,
|
||||
capture_fn: Callable,
|
||||
capture_cudagraph_mode: CUDAGraphMode,
|
||||
desc: str = "Capturing CUDA graphs",
|
||||
**capture_kwargs,
|
||||
) -> None:
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
|
||||
|
||||
with graph_capture(device=device):
|
||||
for size in sizes_to_capture:
|
||||
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
|
||||
|
||||
|
||||
def prepare_inputs_to_capture(
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
|
||||
if uniform_decode_query_len > 0:
|
||||
num_tokens_per_req = uniform_decode_query_len
|
||||
else:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, kv_cache_config
|
||||
)
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=num_tokens_per_req,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
return attn_metadata, slot_mappings_by_layer
|
||||
77
vllm/v1/worker/gpu/dp_utils.py
Normal file
77
vllm/v1/worker/gpu/dp_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
|
||||
if dp_size == 1:
|
||||
return None
|
||||
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert dp_size > 1
|
||||
# Use CPU group to avoid CPU-GPU synchronization.
|
||||
group = get_dp_group().cpu_group
|
||||
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
|
||||
tensor[0][dp_rank] = num_tokens
|
||||
tensor[1][dp_rank] = cudagraph_size
|
||||
tensor[2][dp_rank] = cudagraph_runtime_mode
|
||||
dist.all_reduce(tensor, group=group)
|
||||
return tensor[0], tensor[1], tensor[2]
|
||||
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int | None,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[int, torch.Tensor | None, int]:
|
||||
if dp_size == 1:
|
||||
if cudagraph_size is not None:
|
||||
return cudagraph_size, None, cudagraph_runtime_mode
|
||||
else:
|
||||
return num_tokens, None, cudagraph_runtime_mode
|
||||
|
||||
# Convert None to -1 for sync (indicates no cudagraph available)
|
||||
if num_tokens == 0:
|
||||
cudagraph_size = 0
|
||||
elif cudagraph_size is None:
|
||||
cudagraph_size = -1
|
||||
|
||||
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
|
||||
get_batch_metadata_across_dp(
|
||||
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
|
||||
)
|
||||
)
|
||||
if torch.all(num_tokens_across_dp == 0).item():
|
||||
# All ranks have zero tokens to run.
|
||||
return 0, None, 0
|
||||
|
||||
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
|
||||
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
|
||||
# Check if all ranks have valid cudagraph_size.
|
||||
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
|
||||
|
||||
if synced_cudagraph_mode != 0 and all_have_cudagraph:
|
||||
# All ranks use cudagraph. Pad to max cudagraph_size.
|
||||
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
|
||||
num_tokens_across_dp[:] = max_cudagraph_size
|
||||
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
|
||||
else:
|
||||
# Fall back to eager mode (no cudagraph).
|
||||
# Either some rank doesn't have cudagraph size or mode is NONE.
|
||||
synced_cudagraph_mode = 0
|
||||
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
|
||||
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
|
||||
548
vllm/v1/worker/gpu/input_batch.py
Normal file
548
vllm/v1/worker/gpu/input_batch.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
|
||||
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
|
||||
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
|
||||
self.query_start_loc = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
# DCP: per-request local seq_lens buffer
|
||||
self.dcp_local_seq_lens = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
# Identical to idx_mapping except for spec decoding.
|
||||
expanded_idx_mapping: torch.Tensor
|
||||
# [total_num_logits] position within request for each logit
|
||||
expanded_local_pos: torch.Tensor
|
||||
|
||||
# [num_reqs]
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
num_draft_tokens: int
|
||||
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor | None
|
||||
# [num_tokens_after_padding, hidden_size]
|
||||
inputs_embeds: torch.Tensor | None
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
# layer_name -> slot_mapping
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
|
||||
# [total_num_logits]
|
||||
logits_indices: torch.Tensor
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor
|
||||
cu_num_logits_np: np.ndarray
|
||||
|
||||
# Whether any requests in batch use structured output.
|
||||
has_structured_output_reqs: bool
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
expanded_idx_mapping = idx_mapping
|
||||
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
|
||||
# seq_len equals to query_len
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
|
||||
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
|
||||
# Pad for full CUDA graph mode.
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
seq_lens = input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
|
||||
input_buffers.query_start_loc[:1] = 0
|
||||
torch.cumsum(
|
||||
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
|
||||
)
|
||||
# Pad for full CUDA graph mode.
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||
positions = input_buffers.positions[:num_tokens].zero_()
|
||||
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
expanded_local_pos=expanded_local_pos,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
num_draft_tokens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=None,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=None, # type: ignore
|
||||
slot_mappings=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
has_structured_output_reqs=False,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_prefill_inputs_kernel(
|
||||
input_ids_ptr,
|
||||
next_prefill_tokens_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prefill_lens_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
if num_computed >= prefill_len:
|
||||
# Not prefill.
|
||||
return
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
tokens = tl.load(request_ptr + num_computed + block, mask=mask)
|
||||
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
|
||||
|
||||
next_pos = num_computed + query_len
|
||||
if next_pos < prefill_len:
|
||||
next_token = tl.load(request_ptr + next_pos)
|
||||
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
|
||||
|
||||
|
||||
def prepare_prefill_inputs(
|
||||
input_ids: torch.Tensor,
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prepare_prefill_inputs_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
next_prefill_tokens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prefill_len,
|
||||
num_computed_tokens,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_pos_seq_lens_kernel(
|
||||
pos_ptr,
|
||||
seq_lens_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_id = tl.program_id(0)
|
||||
num_reqs = tl.num_programs(0) - 1
|
||||
if req_id == num_reqs:
|
||||
# Pad unused seq_lens as 0 for full CUDA graphs.
|
||||
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < max_num_reqs
|
||||
tl.store(seq_lens_ptr + block, 0, mask=mask)
|
||||
return
|
||||
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_id)
|
||||
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
|
||||
start = tl.load(query_start_loc_ptr + req_id)
|
||||
end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
query_len = end - start
|
||||
|
||||
seq_len = num_computed_tokens + query_len
|
||||
tl.store(seq_lens_ptr + req_id, seq_len)
|
||||
|
||||
for i in tl.range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
pos = num_computed_tokens + block
|
||||
tl.store(pos_ptr + start + block, pos, mask=mask)
|
||||
|
||||
|
||||
def prepare_pos_seq_lens(
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
# NOTE(woosuk): We do +1 because the last thread block is used
|
||||
# to pad unused seq_lens as 0 for full CUDA graphs.
|
||||
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
|
||||
pos,
|
||||
seq_lens,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
num_computed_tokens,
|
||||
seq_lens.shape[0],
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _combine_sampled_and_draft_tokens_kernel(
|
||||
input_ids_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_sampled_tokens_ptr,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
prefill_len_ptr,
|
||||
draft_tokens_ptr,
|
||||
draft_tokens_stride,
|
||||
cu_num_logits_ptr,
|
||||
logits_indices_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
# Get the number of logits and draft tokens.
|
||||
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
|
||||
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
|
||||
num_logits = cu_num_logits_end - cu_num_logits_start
|
||||
num_draft_tokens = num_logits - 1
|
||||
|
||||
# Compute the logits indices.
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
logits_start = query_end - num_logits
|
||||
tl.store(
|
||||
logits_indices_ptr + cu_num_logits_start + block,
|
||||
logits_start + block,
|
||||
mask=block < num_logits,
|
||||
)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if seq_len <= prefill_len:
|
||||
# Handling prefill tokens. No sampled or draft tokens.
|
||||
return
|
||||
|
||||
# Write the last sampled token ID to input_ids.
|
||||
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
|
||||
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
|
||||
|
||||
# Write the draft tokens (if any) to input_ids.
|
||||
if num_draft_tokens > 0:
|
||||
mask = block < num_draft_tokens
|
||||
draft_tokens = tl.load(
|
||||
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_ids_ptr + query_end - num_draft_tokens + block,
|
||||
draft_tokens,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def combine_sampled_and_draft_tokens(
|
||||
input_ids: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
last_sampled_tokens: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
draft_tokens: torch.Tensor,
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_logits: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
num_speculative_steps = draft_tokens.shape[-1]
|
||||
|
||||
logits_indices = torch.empty(
|
||||
num_logits,
|
||||
dtype=torch.int64,
|
||||
device=input_ids.device,
|
||||
)
|
||||
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
last_sampled_tokens,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
prefill_len,
|
||||
draft_tokens,
|
||||
draft_tokens.stride(0),
|
||||
cu_num_logits,
|
||||
logits_indices,
|
||||
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
|
||||
# in addition to all draft tokens.
|
||||
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
|
||||
)
|
||||
return logits_indices
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _get_num_sampled_and_rejected_kernel(
|
||||
num_sampled_ptr,
|
||||
num_rejected_ptr,
|
||||
seq_lens_ptr,
|
||||
cu_num_logits_ptr,
|
||||
idx_mapping_ptr,
|
||||
prefill_len_ptr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
is_chunked_prefilling = seq_len < prefill_len
|
||||
|
||||
num_sampled = tl.load(num_sampled_ptr + batch_idx)
|
||||
num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled)
|
||||
tl.store(num_sampled_ptr + batch_idx, num_sampled)
|
||||
|
||||
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
|
||||
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
|
||||
num_logits = logits_end - logits_start
|
||||
|
||||
num_rejected = num_logits - num_sampled
|
||||
num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected)
|
||||
tl.store(num_rejected_ptr + batch_idx, num_rejected)
|
||||
|
||||
|
||||
def get_num_sampled_and_rejected(
|
||||
num_sampled: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
cu_num_logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_rejected = torch.empty_like(num_sampled)
|
||||
_get_num_sampled_and_rejected_kernel[(num_reqs,)](
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
seq_lens,
|
||||
cu_num_logits,
|
||||
idx_mapping,
|
||||
prefill_len,
|
||||
)
|
||||
return num_sampled, num_rejected
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _post_update_kernel(
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
last_sampled_tokens_ptr,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
sampled_tokens_ptr,
|
||||
sampled_tokens_stride,
|
||||
num_sampled_ptr,
|
||||
num_rejected_ptr,
|
||||
query_start_loc_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
total_len_ptr,
|
||||
):
|
||||
req_id = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_id)
|
||||
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
num_sampled = tl.load(num_sampled_ptr + req_id)
|
||||
if num_sampled > 0:
|
||||
token_id = tl.load(
|
||||
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
|
||||
)
|
||||
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
|
||||
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
|
||||
|
||||
for i in range(num_sampled):
|
||||
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
|
||||
token_ptr = (
|
||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
|
||||
)
|
||||
count = tl.load(token_ptr)
|
||||
count += 1
|
||||
tl.store(token_ptr, count)
|
||||
tl.store(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
|
||||
token_id,
|
||||
)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + req_id)
|
||||
query_end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
query_len = query_end - query_start
|
||||
num_rejected = tl.load(num_rejected_ptr + req_id)
|
||||
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
num_computed += query_len - num_rejected
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
|
||||
|
||||
|
||||
def post_update(
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
last_sampled_tokens: torch.Tensor,
|
||||
# [max_num_reqs, vocab_size]
|
||||
output_bin_counts: torch.Tensor,
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled_tokens: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
# [max_num_reqs, max_model_len]
|
||||
all_token_ids: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
total_len: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_post_update_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
last_sampled_tokens,
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
sampled_tokens,
|
||||
sampled_tokens.stride(0),
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
query_start_loc,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
total_len,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _expand_idx_mapping_kernel(
|
||||
idx_mapping_ptr,
|
||||
expanded_idx_mapping_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
|
||||
tl.store(expanded_local_pos_ptr + start_idx + block, block, mask=mask)
|
||||
|
||||
|
||||
def expand_idx_mapping(
|
||||
idx_mapping: torch.Tensor,
|
||||
total_num_logits: int,
|
||||
cu_num_logits: torch.Tensor,
|
||||
max_expand_len: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
|
||||
expanded_local_pos = torch.empty(
|
||||
total_num_logits, dtype=torch.int32, device=idx_mapping.device
|
||||
)
|
||||
_expand_idx_mapping_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
expanded_idx_mapping,
|
||||
expanded_local_pos,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
|
||||
)
|
||||
return expanded_idx_mapping, expanded_local_pos
|
||||
134
vllm/v1/worker/gpu/kv_connector.py
Normal file
134
vllm/v1/worker/gpu/kv_connector.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
kv_transfer_state,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.forward_context import (
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
set_forward_context,
|
||||
)
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class KVConnector:
|
||||
"""KVConnector interface used by GPUModelRunner."""
|
||||
|
||||
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
pass
|
||||
|
||||
def post_forward(
|
||||
self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
|
||||
) -> KVConnectorOutput | None:
|
||||
return None
|
||||
|
||||
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
def set_disabled(self, disabled: bool) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ActiveKVConnector(KVConnector):
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.kv_connector = get_kv_transfer_group()
|
||||
# Register kv caches with KV Connector if applicable.
|
||||
# TODO: support cross_layers_kv_cache
|
||||
# (see https://github.com/vllm-project/vllm/pull/27743)
|
||||
self.kv_connector.register_kv_caches(kv_caches_dict)
|
||||
self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
self._disabled = False
|
||||
|
||||
def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
if scheduler_output.preempted_req_ids:
|
||||
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
|
||||
|
||||
# TODO: sort out KV Connectors' use of forward_context
|
||||
if is_forward_context_available():
|
||||
self.kv_connector.start_load_kv(get_forward_context())
|
||||
else:
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
self.kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
def post_forward(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True,
|
||||
clear_metadata: bool = True,
|
||||
) -> KVConnectorOutput | None:
|
||||
if self._disabled:
|
||||
return None
|
||||
|
||||
output = KVConnectorOutput()
|
||||
if wait_for_save:
|
||||
self.kv_connector.wait_for_save()
|
||||
output.finished_sending, output.finished_recving = (
|
||||
self.kv_connector.get_finished(scheduler_output.finished_req_ids)
|
||||
)
|
||||
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
|
||||
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
|
||||
if clear_metadata:
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
return output
|
||||
|
||||
def clear_metadata(self) -> None:
|
||||
"""Clear the connector metadata. Call this after draft model runs."""
|
||||
if not self._disabled:
|
||||
self.kv_connector.clear_connector_metadata()
|
||||
|
||||
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
if self._disabled:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
self.pre_forward(scheduler_output)
|
||||
kv_connector_output = self.post_forward(scheduler_output, wait_for_save=False)
|
||||
if kv_connector_output is None or kv_connector_output.is_empty():
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
def set_disabled(self, disabled: bool) -> None:
|
||||
# Ensure that layer-wise connector hooks aren't called when disabled.
|
||||
kv_transfer_state._KV_CONNECTOR_AGENT = None if disabled else self.kv_connector
|
||||
self._disabled = disabled
|
||||
|
||||
|
||||
NO_OP_KV_CONNECTOR = KVConnector()
|
||||
|
||||
|
||||
def get_kv_connector(
|
||||
vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
|
||||
) -> KVConnector:
|
||||
if not has_kv_transfer_group():
|
||||
# No-op connector.
|
||||
return NO_OP_KV_CONNECTOR
|
||||
|
||||
return ActiveKVConnector(vllm_config, kv_caches_dict)
|
||||
44
vllm/v1/worker/gpu/lora_utils.py
Normal file
44
vllm/v1/worker/gpu/lora_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
NO_LORA_ID = 0
|
||||
|
||||
|
||||
class LoraState:
|
||||
def __init__(self, max_num_reqs: int):
|
||||
self.lora_ids = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.lora_ids.fill(NO_LORA_ID)
|
||||
# req_id -> lora_request
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
|
||||
def add_request(
|
||||
self, req_id: str, req_index: int, lora_request: LoRARequest | None
|
||||
) -> None:
|
||||
if lora_request is not None:
|
||||
self.lora_requests[req_id] = lora_request
|
||||
self.lora_ids[req_index] = lora_request.lora_int_id
|
||||
else:
|
||||
self.lora_ids[req_index] = NO_LORA_ID
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.lora_requests.pop(req_id, None)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
idx_mapping: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
lora_ids = self.lora_ids[idx_mapping]
|
||||
prompt_lora_mapping = tuple(lora_ids)
|
||||
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
|
||||
|
||||
active_lora_requests: set[LoRARequest] = set()
|
||||
for req_id in req_ids:
|
||||
lora_request = self.lora_requests.get(req_id)
|
||||
if lora_request is not None:
|
||||
active_lora_requests.add(lora_request)
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
0
vllm/v1/worker/gpu/metrics/__init__.py
Normal file
0
vllm/v1/worker/gpu/metrics/__init__.py
Normal file
42
vllm/v1/worker/gpu/metrics/logits.py
Normal file
42
vllm/v1/worker/gpu/metrics/logits.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _num_nans_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
num_nans_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
num_nans = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
is_nan = libdevice.isnan(logits).to(tl.int1)
|
||||
num_nans += tl.sum(is_nan).to(tl.int32)
|
||||
tl.store(num_nans_ptr + req_idx, num_nans)
|
||||
|
||||
|
||||
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
|
||||
_num_nans_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
num_nans,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return num_nans
|
||||
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
183
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
183
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
|
||||
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""
|
||||
Clear the multi-modal cache that was used during profiling,
|
||||
but no longer needed during inference.
|
||||
"""
|
||||
# TODO: Implement MM budget for encoder dummy run
|
||||
pass
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
|
||||
This should be called when model weights are updated to ensure
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
|
||||
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
|
||||
self.req_id_to_mm_features[req_id] = mm_features
|
||||
|
||||
def free_encoder_cache(self, mm_hash: str) -> None:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_feature = mm_features[mm_input_id]
|
||||
if mm_feature.data is None:
|
||||
continue
|
||||
mm_hashes.append(mm_feature.identifier)
|
||||
mm_kwargs.append((mm_feature.modality, mm_feature.data))
|
||||
|
||||
return mm_hashes, mm_kwargs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
total_num_scheduled_tokens: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
query_start_loc: np.ndarray,
|
||||
prefill_lens: np.ndarray,
|
||||
computed_prefill_lens: np.ndarray,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
|
||||
all_decode = not any(is_prefilling)
|
||||
if all_decode:
|
||||
# All decode requests, so no need to gather any embeddings.
|
||||
return [], torch.zeros(
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
query_start = computed_prefill_lens.tolist()
|
||||
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
is_mm_embed = torch.zeros(
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
|
||||
)
|
||||
for i, req_id in enumerate(req_ids):
|
||||
if not is_prefilling[i]:
|
||||
# OPTIMIZATION: Skip decode requests.
|
||||
continue
|
||||
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_feature in mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
if start_pos >= query_end[i]:
|
||||
# The encoder output is not needed in this step.
|
||||
break
|
||||
if start_pos + num_encoder_tokens <= query_start[i]:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
start_idx = max(query_start[i] - start_pos, 0)
|
||||
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
|
||||
# Copy the is_mm_embed tensor to the GPU.
|
||||
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
|
||||
return mm_embeds, is_mm_embed
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
self.inputs_embeds[: x.shape[0]] = x
|
||||
return self.inputs_embeds
|
||||
136
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
136
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMRoPE
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
|
||||
class MRopeState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# wasting a lot of CPU memory.
|
||||
self.prefill_mrope_positions = StagedWriteTensor(
|
||||
(max_num_reqs * 3, max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
)
|
||||
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with torch compile.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
def init_prefill_mrope_positions(
|
||||
self,
|
||||
req_idx: int,
|
||||
mrope_model: SupportsMRoPE,
|
||||
prefill_token_ids: list[int],
|
||||
mm_features: list,
|
||||
) -> None:
|
||||
prefill_mrope_positions, prefill_mrope_delta = (
|
||||
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
|
||||
)
|
||||
for i in range(3):
|
||||
pos = prefill_mrope_positions[i].tolist()
|
||||
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
|
||||
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prefill_mrope_positions.apply_write()
|
||||
self.prefill_mrope_delta.copy_to_uva()
|
||||
|
||||
def prepare_mrope_positions(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
prefill_lens: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prepare_mrope_positions_kernel[(num_reqs,)](
|
||||
self.mrope_positions,
|
||||
self.mrope_positions.stride(0),
|
||||
self.prefill_mrope_positions.gpu,
|
||||
3 * self.max_model_len,
|
||||
self.max_model_len,
|
||||
self.prefill_mrope_delta.gpu,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
prefill_lens,
|
||||
num_computed_tokens,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_mrope_positions_kernel(
|
||||
mrope_positions_ptr,
|
||||
mrope_positions_stride,
|
||||
prefill_mrope_positions_ptr,
|
||||
prefill_mrope_positions_stride0,
|
||||
prefill_mrope_positions_stride1,
|
||||
prefill_mrope_delta_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
prefill_lens_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
is_prefill = num_computed < prefill_len
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
orig_pos = num_computed + block
|
||||
|
||||
for j in tl.static_range(3):
|
||||
if is_prefill:
|
||||
# Read from pre-computed M-RoPE positions.
|
||||
pos = tl.load(
|
||||
prefill_mrope_positions_ptr
|
||||
+ req_state_idx * prefill_mrope_positions_stride0
|
||||
+ j * prefill_mrope_positions_stride1
|
||||
+ orig_pos,
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
# Apply M-RoPE delta.
|
||||
pos = orig_pos + mrope_delta
|
||||
tl.store(
|
||||
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
|
||||
pos,
|
||||
mask=mask,
|
||||
)
|
||||
1129
vllm/v1/worker/gpu/model_runner.py
Normal file
1129
vllm/v1/worker/gpu/model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
41
vllm/v1/worker/gpu/pp_utils.py
Normal file
41
vllm/v1/worker/gpu/pp_utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pipeline Parallelism utils for V2 Model Runner."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
|
||||
|
||||
def pp_broadcast(
|
||||
sampled_token_ids: torch.Tensor,
|
||||
num_sampled: torch.Tensor,
|
||||
num_rejected: torch.Tensor,
|
||||
) -> None:
|
||||
pp = get_pp_group()
|
||||
assert pp.is_last_rank
|
||||
|
||||
assert sampled_token_ids.dtype == torch.int64
|
||||
torch.distributed.broadcast(
|
||||
sampled_token_ids.contiguous(), src=pp.last_rank, group=pp.device_group
|
||||
)
|
||||
|
||||
combined = torch.stack((num_sampled, num_rejected), dim=0)
|
||||
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
|
||||
|
||||
|
||||
def pp_receive(
|
||||
num_reqs: int, max_sample_len: int = 1
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pp = get_pp_group()
|
||||
assert not pp.is_last_rank
|
||||
|
||||
sampled_tokens = torch.empty(
|
||||
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
|
||||
)
|
||||
torch.distributed.broadcast(sampled_tokens, src=pp.last_rank, group=pp.device_group)
|
||||
|
||||
combined = torch.empty(2, num_reqs, dtype=torch.int32, device=pp.device)
|
||||
torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group)
|
||||
num_sampled, num_rejected = combined.unbind(dim=0)
|
||||
return sampled_tokens, num_sampled, num_rejected
|
||||
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
194
vllm/v1/worker/gpu/sample/bad_words.py
Normal file
194
vllm/v1/worker/gpu/sample/bad_words.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
|
||||
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
|
||||
|
||||
|
||||
class BadWordsState:
|
||||
def __init__(self, req_states: RequestState):
|
||||
self.req_states = req_states
|
||||
self.max_num_reqs = req_states.max_num_reqs
|
||||
self.device = req_states.device
|
||||
|
||||
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
|
||||
self.bad_word_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
|
||||
self.bad_word_offsets = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# number of bad words per request
|
||||
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||
bad_words_token_ids = sampling_params.bad_words_token_ids
|
||||
if not bad_words_token_ids:
|
||||
self.num_bad_words.np[req_idx] = 0
|
||||
return
|
||||
|
||||
num_bad_words = len(bad_words_token_ids)
|
||||
if num_bad_words > MAX_NUM_BAD_WORDS:
|
||||
raise ValueError(
|
||||
f"Too many bad words: {num_bad_words}. "
|
||||
f"The max number is {MAX_NUM_BAD_WORDS}."
|
||||
)
|
||||
|
||||
# Flatten bad words and compute offsets
|
||||
flattened_tokens: list[int] = []
|
||||
offsets: list[int] = [0]
|
||||
for bad_word in bad_words_token_ids:
|
||||
flattened_tokens.extend(bad_word)
|
||||
offsets.append(len(flattened_tokens))
|
||||
|
||||
if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
|
||||
raise ValueError(
|
||||
f"Too many total bad word tokens: {len(flattened_tokens)}. "
|
||||
f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
|
||||
)
|
||||
|
||||
# Stage writes
|
||||
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
|
||||
self.bad_word_offsets.stage_write(req_idx, 0, offsets)
|
||||
self.num_bad_words.np[req_idx] = num_bad_words
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.num_bad_words.copy_to_uva()
|
||||
self.bad_word_token_ids.apply_write()
|
||||
self.bad_word_offsets.apply_write()
|
||||
|
||||
def apply_bad_words(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
) -> None:
|
||||
max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
|
||||
if max_num_bad_words == 0:
|
||||
# No request uses bad words. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.bad_word_token_ids.gpu,
|
||||
self.bad_word_offsets.gpu,
|
||||
self.num_bad_words.gpu,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prompt_len.gpu,
|
||||
self.req_states.total_len.gpu,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
max_num_bad_words,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bad_words_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
expanded_idx_mapping_ptr,
|
||||
bad_word_token_ids_ptr,
|
||||
bad_word_token_ids_stride,
|
||||
bad_word_offsets_ptr,
|
||||
bad_word_offsets_stride,
|
||||
num_bad_words_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
total_len_ptr,
|
||||
input_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
):
|
||||
logit_idx = tl.program_id(0)
|
||||
bw_idx = tl.program_id(1)
|
||||
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
|
||||
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
|
||||
|
||||
if bw_idx >= num_bad_words:
|
||||
return
|
||||
|
||||
pos = tl.load(expanded_local_pos_ptr + logit_idx)
|
||||
cur_req_first_pos = logit_idx - pos
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
total_len = tl.load(total_len_ptr + req_state_idx)
|
||||
output_len = total_len - prompt_len
|
||||
effective_len = output_len + pos
|
||||
|
||||
bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
|
||||
bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
|
||||
output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len
|
||||
|
||||
start = tl.load(bd_offsets_base + bw_idx)
|
||||
end = tl.load(bd_offsets_base + bw_idx + 1)
|
||||
bad_word_len = end - start
|
||||
prefix_len = bad_word_len - 1
|
||||
|
||||
if prefix_len > effective_len:
|
||||
return
|
||||
|
||||
last_token = tl.load(bd_tokens_base + end - 1)
|
||||
match = 1
|
||||
for i in range(prefix_len):
|
||||
expected = tl.load(bd_tokens_base + start + i)
|
||||
actual_pos = effective_len - prefix_len + i
|
||||
|
||||
from_spec_input = actual_pos >= output_len
|
||||
if from_spec_input:
|
||||
spec_offset = actual_pos - output_len
|
||||
actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
|
||||
else:
|
||||
actual = tl.load(output_base + actual_pos)
|
||||
|
||||
match = match & (expected == actual)
|
||||
|
||||
if match:
|
||||
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
expanded_idx_mapping: torch.Tensor,
|
||||
bad_word_token_ids: torch.Tensor,
|
||||
bad_word_offsets: torch.Tensor,
|
||||
num_bad_words: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
total_len: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
max_num_bad_words: int,
|
||||
) -> None:
|
||||
total_num_tokens = logits.shape[0]
|
||||
_bad_words_kernel[(total_num_tokens, max_num_bad_words)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
expanded_idx_mapping,
|
||||
bad_word_token_ids,
|
||||
bad_word_token_ids.stride(0),
|
||||
bad_word_offsets,
|
||||
bad_word_offsets.stride(0),
|
||||
num_bad_words,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
total_len,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
149
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
149
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _temperature_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
temperature_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
|
||||
if temperature == 0.0 or temperature == 1.0:
|
||||
# Early return to avoid loading logits.
|
||||
return
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
||||
logits = logits.to(tl.float32)
|
||||
logits = logits / temperature
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_temperature(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
_temperature_kernel[(num_reqs, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
local_argmax_stride,
|
||||
local_max_ptr,
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise in FP32.
|
||||
u = tl.rand(gumbel_seed, block)
|
||||
u = tl.maximum(u, 1e-7)
|
||||
gumbel_noise = -tl.log(-tl.log(u))
|
||||
|
||||
# Apply temperature.
|
||||
if APPLY_TEMPERATURE:
|
||||
# NOTE(woosuk): Match the behavior of _temperature_kernel.
|
||||
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||
logits = logits / temp
|
||||
|
||||
# Apply gumbel noise.
|
||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||
|
||||
value, idx = tl.max(logits, axis=0, return_indices=True)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
idx_mapping: torch.Tensor, # [max_num_reqs]
|
||||
temperature: torch.Tensor, # [max_num_reqs]
|
||||
seed: torch.Tensor, # [max_num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
APPLY_TEMPERATURE=apply_temperature,
|
||||
)
|
||||
# NOTE(woosuk): Use int64 for later indexing.
|
||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||
return sampled
|
||||
280
vllm/v1/worker/gpu/sample/logit_bias.py
Normal file
280
vllm/v1/worker/gpu/sample/logit_bias.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
MAX_NUM_ALLOWED_TOKEN_IDS = 1024
|
||||
MAX_NUM_LOGIT_BIAS_TOKENS = 1024
|
||||
MAX_NUM_STOP_TOKEN_IDS = 128
|
||||
|
||||
|
||||
class LogitBiasState:
|
||||
def __init__(self, max_num_reqs: int, device: torch.device):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
|
||||
# Allowed token IDs.
|
||||
self.num_allowed_token_ids = UvaBackedTensor(
|
||||
self.max_num_reqs, dtype=torch.int32
|
||||
)
|
||||
self.allowed_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_ALLOWED_TOKEN_IDS),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
# Logit bias.
|
||||
self.num_logit_bias = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
self.logit_bias_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.logit_bias = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_LOGIT_BIAS_TOKENS),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
# Min tokens.
|
||||
self.min_lens = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
self.num_stop_token_ids = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
self.stop_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, MAX_NUM_STOP_TOKEN_IDS),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Using any of the above.
|
||||
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
def add_request(
|
||||
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||
) -> None:
|
||||
# Using any logit bias.
|
||||
use_logit_bias = False
|
||||
|
||||
# Allowed token IDs.
|
||||
allowed_token_ids = sampling_params.allowed_token_ids
|
||||
if allowed_token_ids:
|
||||
num_allowed_token_ids = len(allowed_token_ids)
|
||||
if num_allowed_token_ids > MAX_NUM_ALLOWED_TOKEN_IDS:
|
||||
raise ValueError(
|
||||
f"Too many allowed token IDs: {num_allowed_token_ids}. "
|
||||
f"The max size is {MAX_NUM_ALLOWED_TOKEN_IDS}."
|
||||
)
|
||||
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
|
||||
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_allowed_token_ids.np[req_idx] = 0
|
||||
|
||||
# Logit bias.
|
||||
logit_bias = sampling_params.logit_bias
|
||||
if logit_bias:
|
||||
num_logit_bias = len(logit_bias)
|
||||
if num_logit_bias > MAX_NUM_LOGIT_BIAS_TOKENS:
|
||||
raise ValueError(
|
||||
f"Too many logit bias tokens: {num_logit_bias}. "
|
||||
f"The max size is {MAX_NUM_LOGIT_BIAS_TOKENS}."
|
||||
)
|
||||
self.num_logit_bias.np[req_idx] = num_logit_bias
|
||||
self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
|
||||
self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_logit_bias.np[req_idx] = 0
|
||||
|
||||
# Min tokens.
|
||||
min_tokens = sampling_params.min_tokens
|
||||
min_len = prompt_len + min_tokens
|
||||
self.min_lens.np[req_idx] = min_len
|
||||
stop_token_ids = sampling_params.all_stop_token_ids
|
||||
if min_tokens > 0 and stop_token_ids:
|
||||
num_stop_token_ids = len(stop_token_ids)
|
||||
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
|
||||
raise ValueError(
|
||||
f"Too many stop tokens: {num_stop_token_ids}. "
|
||||
f"The max size is {MAX_NUM_STOP_TOKEN_IDS}."
|
||||
)
|
||||
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
|
||||
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_stop_token_ids.np[req_idx] = 0
|
||||
|
||||
self.use_logit_bias[req_idx] = use_logit_bias
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.num_allowed_token_ids.copy_to_uva()
|
||||
self.allowed_token_ids.apply_write()
|
||||
|
||||
self.num_logit_bias.copy_to_uva()
|
||||
self.logit_bias_token_ids.apply_write()
|
||||
self.logit_bias.apply_write()
|
||||
|
||||
self.min_lens.copy_to_uva()
|
||||
self.num_stop_token_ids.copy_to_uva()
|
||||
self.stop_token_ids.apply_write()
|
||||
|
||||
def apply_logit_bias(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> None:
|
||||
if not np.any(self.use_logit_bias[idx_mapping_np]):
|
||||
# No request uses logit bias. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_logit_bias(
|
||||
logits,
|
||||
idx_mapping,
|
||||
pos,
|
||||
self.num_allowed_token_ids.gpu,
|
||||
self.allowed_token_ids.gpu,
|
||||
self.num_logit_bias.gpu,
|
||||
self.logit_bias_token_ids.gpu,
|
||||
self.logit_bias.gpu,
|
||||
self.min_lens.gpu,
|
||||
self.num_stop_token_ids.gpu,
|
||||
self.stop_token_ids.gpu,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bias_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
vocab_size,
|
||||
idx_mapping_ptr,
|
||||
# Allowed token IDs.
|
||||
num_allowed_token_ids_ptr,
|
||||
allowed_token_ids_ptr,
|
||||
allowed_token_ids_stride,
|
||||
# Logit bias.
|
||||
num_logit_bias_ptr,
|
||||
bias_token_ids_ptr,
|
||||
bias_token_ids_stride,
|
||||
bias_ptr,
|
||||
bias_stride,
|
||||
# Min tokens.
|
||||
pos_ptr,
|
||||
min_lens_ptr,
|
||||
num_stop_token_ids_ptr,
|
||||
stop_token_ids_ptr,
|
||||
stop_token_ids_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
LOGITS_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
# Allowed token IDs.
|
||||
num_allowed_token_ids = tl.load(num_allowed_token_ids_ptr + req_state_idx)
|
||||
if num_allowed_token_ids > 0:
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_allowed_token_ids
|
||||
|
||||
# Save logits for allowed token IDs.
|
||||
allowed_token_ids = tl.load(
|
||||
allowed_token_ids_ptr + req_state_idx * allowed_token_ids_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
logits = tl.load(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
|
||||
)
|
||||
|
||||
# Set logits to -inf for all tokens.
|
||||
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + offset,
|
||||
-float("inf"),
|
||||
mask=offset < vocab_size,
|
||||
)
|
||||
|
||||
# Restore logits for allowed token IDs.
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + allowed_token_ids,
|
||||
logits,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Logit bias.
|
||||
num_logit_bias = tl.load(num_logit_bias_ptr + req_state_idx)
|
||||
if num_logit_bias > 0:
|
||||
mask = block < num_logit_bias
|
||||
token_ids = tl.load(
|
||||
bias_token_ids_ptr + req_state_idx * bias_token_ids_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
|
||||
logits += bias
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
|
||||
|
||||
# Apply min tokens.
|
||||
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + batch_idx)
|
||||
min_len = tl.load(min_lens_ptr + req_state_idx)
|
||||
if num_stop_token_ids > 0 and pos < min_len:
|
||||
mask = block < num_stop_token_ids
|
||||
stop_token_ids = tl.load(
|
||||
stop_token_ids_ptr + req_state_idx * stop_token_ids_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
logits_ptr + batch_idx * logits_stride + stop_token_ids,
|
||||
-float("inf"),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def apply_logit_bias(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
pos: torch.Tensor,
|
||||
num_allowed_token_ids: torch.Tensor,
|
||||
allowed_token_ids: torch.Tensor,
|
||||
num_logit_bias: torch.Tensor,
|
||||
logit_bias_token_ids: torch.Tensor,
|
||||
logit_bias: torch.Tensor,
|
||||
min_lens: torch.Tensor,
|
||||
num_stop_token_ids: torch.Tensor,
|
||||
stop_token_ids: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = triton.next_power_of_2(
|
||||
max(
|
||||
allowed_token_ids.shape[-1],
|
||||
logit_bias_token_ids.shape[-1],
|
||||
stop_token_ids.shape[-1],
|
||||
)
|
||||
)
|
||||
LOGITS_BLOCK_SIZE = 8192
|
||||
_bias_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
vocab_size,
|
||||
idx_mapping,
|
||||
num_allowed_token_ids,
|
||||
allowed_token_ids,
|
||||
allowed_token_ids.stride(0),
|
||||
num_logit_bias,
|
||||
logit_bias_token_ids,
|
||||
logit_bias_token_ids.stride(0),
|
||||
logit_bias,
|
||||
logit_bias.stride(0),
|
||||
pos,
|
||||
min_lens,
|
||||
num_stop_token_ids,
|
||||
stop_token_ids,
|
||||
stop_token_ids.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
|
||||
)
|
||||
126
vllm/v1/worker/gpu/sample/logprob.py
Normal file
126
vllm/v1/worker/gpu/sample/logprob.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_log_softmax_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
topk_ids_ptr,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PADDED_TOPK: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
|
||||
se = 0.0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
|
||||
logits = logits.to(tl.float32)
|
||||
e = tl.exp(logits - max_val)
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
||||
|
||||
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
logits = logits.to(tl.float32)
|
||||
o = logits - max_val - lse
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
token_ids_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
token_id = tl.load(token_ids_ptr + req_idx)
|
||||
x = tl.load(row_ptr + token_id)
|
||||
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
n += tl.sum((logits >= x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_token_logprobs(
|
||||
logits: torch.Tensor, token_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
batch_size, vocab_size = logits.shape
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
num_logprobs = token_ids.shape[1]
|
||||
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
|
||||
_topk_log_softmax_kernel[(batch_size,)](
|
||||
logprobs,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
token_ids,
|
||||
num_logprobs,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
cu_num_logits: list[int] | None = None,
|
||||
) -> LogprobsTensors:
|
||||
assert num_logprobs >= 0
|
||||
batch_size, vocab_size = logits.shape
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
if num_logprobs > 0:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
|
||||
_ranks_kernel[(batch_size,)](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192, # type: ignore
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=token_ranks,
|
||||
cu_num_generated_tokens=cu_num_logits,
|
||||
)
|
||||
56
vllm/v1/worker/gpu/sample/min_p.py
Normal file
56
vllm/v1/worker/gpu/sample/min_p.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _min_p_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
min_p_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
|
||||
if min_p == 0.0:
|
||||
return
|
||||
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
)
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
|
||||
threshold = max_val + tl.log(min_p)
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
|
||||
)
|
||||
logits = tl.where(logits < threshold, float("-inf"), logits)
|
||||
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
_min_p_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
min_p,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
14
vllm/v1/worker/gpu/sample/output.py
Normal file
14
vllm/v1/worker/gpu/sample/output.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplerOutput:
|
||||
sampled_token_ids: torch.Tensor
|
||||
logprobs_tensors: LogprobsTensors | None
|
||||
num_nans: torch.Tensor | None
|
||||
311
vllm/v1/worker/gpu/sample/penalties.py
Normal file
311
vllm/v1/worker/gpu/sample/penalties.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import async_tensor_h2d
|
||||
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
class PenaltiesState:
|
||||
def __init__(self, req_states: RequestState):
|
||||
self.req_states = req_states
|
||||
|
||||
max_num_reqs = req_states.max_num_reqs
|
||||
self.vocab_size = req_states.vocab_size
|
||||
self.device = req_states.device
|
||||
|
||||
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
# Initialize repetition penalty manually because 0 is an invalid value for it.
|
||||
self.repetition_penalty.np.fill(1.0)
|
||||
self.repetition_penalty.copy_to_uva()
|
||||
|
||||
# Statistics for penalties.
|
||||
self.prompt_bin_mask = torch.zeros(
|
||||
max_num_reqs,
|
||||
cdiv(self.vocab_size, 32),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
|
||||
# GBs of GPU memory. Optimize the memory usage.
|
||||
self.output_bin_counts = torch.zeros(
|
||||
max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self._new_penalties_reqs: list[int] = []
|
||||
|
||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
|
||||
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
|
||||
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
|
||||
|
||||
do_penalty = use_penalty(sampling_params)
|
||||
self.use_penalty[req_idx] = do_penalty
|
||||
if do_penalty:
|
||||
self._new_penalties_reqs.append(req_idx)
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
if self._new_penalties_reqs:
|
||||
idx_mapping = async_tensor_h2d(
|
||||
self._new_penalties_reqs,
|
||||
dtype=torch.int32,
|
||||
target_device=self.device,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
|
||||
max_prefill_len = int(prefill_lens.max())
|
||||
bincount(
|
||||
idx_mapping,
|
||||
self.req_states.all_token_ids.gpu,
|
||||
self.req_states.prompt_len.gpu,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.prompt_bin_mask,
|
||||
self.output_bin_counts,
|
||||
max_prefill_len,
|
||||
)
|
||||
self._new_penalties_reqs.clear()
|
||||
|
||||
self.repetition_penalty.copy_to_uva()
|
||||
self.frequency_penalty.copy_to_uva()
|
||||
self.presence_penalty.copy_to_uva()
|
||||
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
num_speculative_tokens: int,
|
||||
) -> None:
|
||||
if not np.any(self.use_penalty[idx_mapping_np]):
|
||||
# No request uses penalties. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
self.repetition_penalty.gpu,
|
||||
self.frequency_penalty.gpu,
|
||||
self.presence_penalty.gpu,
|
||||
self.prompt_bin_mask,
|
||||
self.output_bin_counts,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _penalties_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
idx_mapping_ptr,
|
||||
token_ids_ptr,
|
||||
expanded_local_pos_ptr,
|
||||
repetition_penalty_ptr,
|
||||
frequency_penalty_ptr,
|
||||
presence_penalty_ptr,
|
||||
prompt_bin_mask_ptr,
|
||||
prompt_bin_mask_stride,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
MAX_SPEC_LEN: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + token_idx)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
|
||||
|
||||
use_rep_penalty = rep_penalty != 1.0
|
||||
use_freq_penalty = freq_penalty != 0.0
|
||||
use_pres_penalty = pres_penalty != 0.0
|
||||
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
|
||||
if not use_penalty:
|
||||
# Early return to avoid loading logits.
|
||||
return
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
base_output_counts = tl.load(
|
||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
|
||||
mask=mask,
|
||||
other=0,
|
||||
)
|
||||
|
||||
# Compute cumulative draft_counts from previous positions in this request
|
||||
pos = tl.load(expanded_local_pos_ptr + token_idx)
|
||||
start_idx = token_idx - pos
|
||||
draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
|
||||
for prev_pos in tl.static_range(MAX_SPEC_LEN):
|
||||
if prev_pos < pos:
|
||||
prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
|
||||
token_match = block == prev_token
|
||||
draft_counts = draft_counts + token_match.to(tl.int32)
|
||||
|
||||
# Total counts = base output counts + cumulative draft counts
|
||||
output_bin_counts = base_output_counts + draft_counts
|
||||
output_bin_mask = output_bin_counts > 0
|
||||
|
||||
# Apply repetition penalties.
|
||||
if use_rep_penalty:
|
||||
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
|
||||
packed_mask = tl.load(
|
||||
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
|
||||
mask=packed_block < tl.cdiv(vocab_size, 32),
|
||||
other=0,
|
||||
)
|
||||
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
|
||||
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
|
||||
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
|
||||
|
||||
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
|
||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||
logits *= tl.where(logits > 0, 1.0 / scale, scale)
|
||||
|
||||
# Apply frequency penalties.
|
||||
logits -= freq_penalty * output_bin_counts
|
||||
# Apply presence penalties.
|
||||
logits -= pres_penalty * output_bin_mask
|
||||
# Store back to logits.
|
||||
tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
repetition_penalty: torch.Tensor,
|
||||
frequency_penalty: torch.Tensor,
|
||||
presence_penalty: torch.Tensor,
|
||||
prompt_bin_mask: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor,
|
||||
num_speculative_tokens: int,
|
||||
) -> None:
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
_penalties_kernel[(num_tokens, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
idx_mapping,
|
||||
token_ids,
|
||||
expanded_local_pos,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
prompt_bin_mask,
|
||||
prompt_bin_mask.stride(0),
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
MAX_SPEC_LEN=num_speculative_tokens,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bincount_kernel(
|
||||
idx_mapping_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
prompt_len_ptr,
|
||||
prefill_len_ptr,
|
||||
prompt_bin_mask_ptr,
|
||||
prompt_bin_mask_stride,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
return
|
||||
|
||||
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
if block_idx * BLOCK_SIZE < prompt_len:
|
||||
mask = block < prompt_len
|
||||
prompt_tokens = tl.load(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
|
||||
)
|
||||
idx = prompt_tokens // 32
|
||||
bit_idx = prompt_tokens % 32
|
||||
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
|
||||
tl.atomic_or(
|
||||
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
|
||||
bit,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||
mask = block < prefill_len
|
||||
mask &= block >= prompt_len
|
||||
output_tokens = tl.load(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
|
||||
)
|
||||
tl.atomic_add(
|
||||
output_bin_counts_ptr
|
||||
+ req_state_idx * output_bin_counts_stride
|
||||
+ output_tokens,
|
||||
1,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def bincount(
|
||||
idx_mapping: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
prompt_len: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
prompt_bin_mask: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor,
|
||||
max_prefill_len: int,
|
||||
) -> None:
|
||||
prompt_bin_mask[idx_mapping] = 0
|
||||
output_bin_counts[idx_mapping] = 0
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_reqs, num_blocks)](
|
||||
idx_mapping,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
prompt_len,
|
||||
prefill_len,
|
||||
prompt_bin_mask,
|
||||
prompt_bin_mask.stride(0),
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
def use_penalty(sampling_params: SamplingParams) -> bool:
|
||||
return (
|
||||
sampling_params.repetition_penalty != 1.0
|
||||
or sampling_params.frequency_penalty != 0.0
|
||||
or sampling_params.presence_penalty != 0.0
|
||||
)
|
||||
208
vllm/v1/worker/gpu/sample/prompt_logprob.py
Normal file
208
vllm/v1/worker/gpu/sample/prompt_logprob.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
|
||||
|
||||
class PromptLogprobsWorker:
|
||||
def __init__(self, max_num_reqs: int):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
|
||||
self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
# req_idx -> list of in-progress LogprobsTensors
|
||||
self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
|
||||
|
||||
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
|
||||
# For now, only support prompt logprobs for the prompt tokens (not top-k).
|
||||
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
||||
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
|
||||
if uses_prompt_logprobs:
|
||||
self.in_progress_prompt_logprobs[req_id] = []
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.in_progress_prompt_logprobs.pop(req_id, None)
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
self,
|
||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
# [max_num_reqs, max_model_len]
|
||||
all_token_ids: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
prompt_lens: np.ndarray,
|
||||
# [max_num_reqs]
|
||||
prefill_lens: np.ndarray,
|
||||
# [max_num_reqs]
|
||||
num_computed_prefill_tokens: np.ndarray,
|
||||
) -> dict[str, LogprobsTensors]:
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
|
||||
if not np.any(needs_prompt_logprobs):
|
||||
# Common case: No request asks for prompt logprobs.
|
||||
return {}
|
||||
|
||||
prompt_lens = prompt_lens[idx_mapping_np]
|
||||
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
|
||||
# needed for prompt logprobs.
|
||||
computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
|
||||
includes_prompt = computed_prefill < prompt_lens - 1
|
||||
# NOTE(woosuk): If the request was resumed after preemption, its prompt
|
||||
# logprobs must have been computed before preemption. Skip.
|
||||
resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
|
||||
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
|
||||
if not np.any(needs_prompt_logprobs):
|
||||
return {}
|
||||
|
||||
# Get the prompt logprobs token_ids.
|
||||
prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
|
||||
input_batch.num_tokens,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.idx_mapping,
|
||||
num_computed_tokens,
|
||||
all_token_ids,
|
||||
)
|
||||
# Compute the prompt logprobs.
|
||||
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
|
||||
prompt_logprobs_token_ids,
|
||||
hidden_states[: input_batch.num_tokens],
|
||||
logits_fn,
|
||||
)
|
||||
|
||||
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
|
||||
is_prompt_chunked = pos_after_step < prompt_lens
|
||||
|
||||
query_start_loc_np = input_batch.query_start_loc_np
|
||||
prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
|
||||
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
continue
|
||||
|
||||
start_idx = query_start_loc_np[i]
|
||||
end_idx = query_start_loc_np[i + 1]
|
||||
assert start_idx < end_idx, (
|
||||
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
|
||||
)
|
||||
if not is_prompt_chunked[i]:
|
||||
end_idx -= 1
|
||||
logprobs = LogprobsTensors(
|
||||
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
|
||||
logprobs=prompt_logprobs[start_idx:end_idx],
|
||||
selected_token_ranks=prompt_ranks[start_idx:end_idx],
|
||||
)
|
||||
|
||||
prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
|
||||
if is_prompt_chunked[i]:
|
||||
# Prompt is chunked. Do not return the logprobs yet.
|
||||
prompt_logprobs_list.append(logprobs)
|
||||
continue
|
||||
|
||||
if prompt_logprobs_list:
|
||||
# Merge the in-progress logprobs.
|
||||
prompt_logprobs_list.append(logprobs)
|
||||
logprobs = LogprobsTensors(
|
||||
logprob_token_ids=torch.cat(
|
||||
[x.logprob_token_ids for x in prompt_logprobs_list]
|
||||
),
|
||||
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
|
||||
selected_token_ranks=torch.cat(
|
||||
[x.selected_token_ranks for x in prompt_logprobs_list]
|
||||
),
|
||||
)
|
||||
prompt_logprobs_list.clear()
|
||||
|
||||
prompt_logprobs_dict[req_id] = logprobs
|
||||
return prompt_logprobs_dict
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prompt_logprobs_token_ids_kernel(
|
||||
prompt_logprobs_token_ids_ptr,
|
||||
query_start_loc_ptr,
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
all_token_ids_ptr,
|
||||
all_token_ids_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
# NOTE(woosuk): We should shift the pos by one
|
||||
# because the logprob is computed for the next token.
|
||||
target_pos = num_computed_tokens + 1 + block
|
||||
token_ids = tl.load(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_logprobs_token_ids(
|
||||
num_tokens: int,
|
||||
query_start_loc: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
all_token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prompt_logprobs_token_ids_kernel[(num_reqs,)](
|
||||
token_ids,
|
||||
query_start_loc,
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
all_token_ids,
|
||||
all_token_ids.stride(0),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return token_ids
|
||||
|
||||
|
||||
def compute_prompt_logprobs_with_chunking(
|
||||
prompt_token_ids: torch.Tensor,
|
||||
prompt_hidden_states: torch.Tensor,
|
||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Since materializing the full prompt logits can take too much memory,
|
||||
# we compute it in chunks.
|
||||
CHUNK_SIZE = 1024
|
||||
logprobs = []
|
||||
ranks = []
|
||||
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
||||
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
||||
end_idx = start_idx + CHUNK_SIZE
|
||||
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
||||
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
||||
prompt_logprobs = compute_topk_logprobs(
|
||||
prompt_logits,
|
||||
0, # num_logprobs
|
||||
prompt_token_ids[start_idx:end_idx],
|
||||
)
|
||||
logprobs.append(prompt_logprobs.logprobs)
|
||||
ranks.append(prompt_logprobs.selected_token_ranks)
|
||||
|
||||
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
||||
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
||||
return logprobs, ranks
|
||||
155
vllm/v1/worker/gpu/sample/sampler.py
Normal file
155
vllm/v1/worker/gpu/sample/sampler.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
||||
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
|
||||
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
req_states: RequestState,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
num_speculative_tokens: int = 1,
|
||||
):
|
||||
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
|
||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||
self.logprobs_mode = logprobs_mode
|
||||
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
|
||||
|
||||
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
|
||||
self.penalties_state = PenaltiesState(req_states)
|
||||
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
||||
self.bad_words_state = BadWordsState(req_states)
|
||||
self.num_speculative_tokens = num_speculative_tokens
|
||||
|
||||
def add_request(
|
||||
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||
) -> None:
|
||||
self.sampling_states.add_request(req_idx, sampling_params)
|
||||
self.penalties_state.add_request(req_idx, sampling_params)
|
||||
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
|
||||
self.bad_words_state.add_request(req_idx, sampling_params)
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.sampling_states.apply_staged_writes()
|
||||
self.penalties_state.apply_staged_writes()
|
||||
self.logit_bias_state.apply_staged_writes()
|
||||
self.bad_words_state.apply_staged_writes()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
cu_num_logits_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
) -> SamplerOutput:
|
||||
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
|
||||
# that num_nans is computed before applying penalties and temperature.
|
||||
num_nans = get_num_nans(logits) if self.compute_nans else None
|
||||
sampled, processed_logits = self.sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
idx_mapping_np,
|
||||
pos,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
||||
if max_num_logprobs != NO_LOGPROBS:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
logits = processed_logits
|
||||
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
|
||||
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
|
||||
logprobs_tensors = compute_topk_logprobs(
|
||||
logits, max_num_logprobs, sampled, cu_num_logits
|
||||
)
|
||||
else:
|
||||
logprobs_tensors = None
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.view(-1, 1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
num_nans=num_nans,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
expanded_local_pos: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Copy logits to a new FP32 tensor.
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
|
||||
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
|
||||
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
|
||||
|
||||
# Apply penalties in place.
|
||||
self.penalties_state.apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
self.num_speculative_tokens,
|
||||
)
|
||||
|
||||
# Apply bad words masking in place.
|
||||
self.bad_words_state.apply_bad_words(
|
||||
logits,
|
||||
idx_mapping,
|
||||
idx_mapping_np,
|
||||
input_ids,
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
# Apply temperature in place.
|
||||
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
|
||||
|
||||
# Apply min_p in place.
|
||||
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
|
||||
|
||||
# Apply top_k and/or top_p. This might or might not return a new tensor.
|
||||
logits = self.sampling_states.apply_top_k_top_p(
|
||||
logits, idx_mapping, idx_mapping_np
|
||||
)
|
||||
|
||||
# Sample the next token.
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.sampling_states.temperature.gpu,
|
||||
self.sampling_states.seeds.gpu,
|
||||
pos,
|
||||
apply_temperature=False,
|
||||
)
|
||||
return sampled, logits
|
||||
104
vllm/v1/worker/gpu/sample/states.py
Normal file
104
vllm/v1/worker/gpu/sample/states.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
|
||||
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
|
||||
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
|
||||
|
||||
NO_LOGPROBS = -1
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
|
||||
|
||||
class SamplingStates:
|
||||
def __init__(self, max_num_reqs: int, vocab_size: int):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
|
||||
|
||||
# Initialize top_k and top_p manually because 0 is an invalid value for them.
|
||||
self.top_k.np.fill(self.vocab_size)
|
||||
self.top_k.copy_to_uva()
|
||||
self.top_p.np.fill(1.0)
|
||||
self.top_p.copy_to_uva()
|
||||
|
||||
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
||||
# -1 means no logprobs are requested.
|
||||
self.num_logprobs.fill(NO_LOGPROBS)
|
||||
|
||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
top_k = sampling_params.top_k
|
||||
if top_k <= 0 or top_k > self.vocab_size:
|
||||
top_k = self.vocab_size
|
||||
self.top_k.np[req_idx] = top_k
|
||||
self.min_p.np[req_idx] = sampling_params.min_p
|
||||
|
||||
seed = sampling_params.seed
|
||||
if seed is None:
|
||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||
self.seeds.np[req_idx] = seed
|
||||
|
||||
num_logprobs = sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = NO_LOGPROBS
|
||||
self.num_logprobs[req_idx] = num_logprobs
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.temperature.copy_to_uva()
|
||||
self.top_p.copy_to_uva()
|
||||
self.top_k.copy_to_uva()
|
||||
self.min_p.copy_to_uva()
|
||||
self.seeds.copy_to_uva()
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
temp_np = self.temperature.np[idx_mapping_np]
|
||||
if np.all((temp_np == 0.0) | (temp_np == 1.0)):
|
||||
# No request requires temperature. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_temperature(logits, idx_mapping, self.temperature.gpu)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
|
||||
# No request uses min_p. Skip the kernel launch.
|
||||
return
|
||||
apply_min_p(logits, idx_mapping, self.min_p.gpu)
|
||||
|
||||
def apply_top_k_top_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
|
||||
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
|
||||
if not (do_top_k or do_top_p):
|
||||
return logits
|
||||
|
||||
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
|
||||
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
|
||||
return int(np.max(self.num_logprobs[idx_mapping_np]))
|
||||
15
vllm/v1/worker/gpu/spec_decode/__init__.py
Normal file
15
vllm/v1/worker/gpu/spec_decode/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
def init_speculator(vllm_config: VllmConfig, device: torch.device):
|
||||
speculative_config = vllm_config.speculative_config
|
||||
assert speculative_config is not None
|
||||
if speculative_config.use_eagle():
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
|
||||
|
||||
return EagleSpeculator(vllm_config, device)
|
||||
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
|
||||
0
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
Normal file
0
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
Normal file
191
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
Normal file
191
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
capture_graphs,
|
||||
get_cudagraph_sizes,
|
||||
prepare_inputs_to_capture,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class EagleCudaGraphManager:
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
|
||||
|
||||
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
|
||||
_, self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
uniform_decode_query_len=1,
|
||||
uniform_decode_cudagraph=True,
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=1,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
|
||||
# Capture the graph.
|
||||
capture_fn(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
generate_fn=generate_fn,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
if self.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return
|
||||
|
||||
capture_graphs(
|
||||
self.cudagraph_sizes,
|
||||
self.device,
|
||||
self.capture_graph,
|
||||
capture_cudagraph_mode=self.cudagraph_mode,
|
||||
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
|
||||
generate_fn=generate_fn,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
def run_fullgraph(self, num_tokens: int) -> None:
|
||||
assert num_tokens in self.graphs
|
||||
# Sync offloader before replay - needed when transitioning from
|
||||
# eager/piecewise to full cudagraph (e.g., prefill → decode).
|
||||
# The previous eager iteration's start_prefetch may have queued
|
||||
# H2D copies on copy_stream that the graph's captured events
|
||||
# cannot see. Without this, replay could overwrite static buffers
|
||||
# while those copies are still in flight.
|
||||
get_offloader().sync_prev_onload()
|
||||
self.graphs[num_tokens].replay()
|
||||
46
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
Normal file
46
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import cast
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def set_eagle3_aux_hidden_state_layers(
|
||||
model: nn.Module,
|
||||
spec_config: SpeculativeConfig,
|
||||
) -> None:
|
||||
if not supports_eagle3(model):
|
||||
raise RuntimeError("Model does not support EAGLE3 interface")
|
||||
# mypy may infer the class-level overload for supports_eagle3.
|
||||
# Narrow explicitly to the runtime protocol instance.
|
||||
if isinstance(model, type):
|
||||
raise RuntimeError("Expected model instance for EAGLE3 configuration")
|
||||
eagle3_model = cast(SupportsEagle3, model)
|
||||
|
||||
aux_layers = get_eagle3_aux_layers_from_config(spec_config)
|
||||
if aux_layers:
|
||||
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
|
||||
else:
|
||||
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
|
||||
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
|
||||
eagle3_model.set_aux_hidden_state_layers(aux_layers)
|
||||
|
||||
|
||||
def get_eagle3_aux_layers_from_config(
|
||||
spec_config: SpeculativeConfig,
|
||||
) -> tuple[int, ...] | None:
|
||||
if not (spec_config and spec_config.draft_model_config):
|
||||
return None
|
||||
hf_config = spec_config.draft_model_config.hf_config
|
||||
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
|
||||
return None
|
||||
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
|
||||
if layer_ids and isinstance(layer_ids, (list, tuple)):
|
||||
return tuple(layer_ids)
|
||||
return None
|
||||
583
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Normal file
583
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Normal file
@@ -0,0 +1,583 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EagleSpeculator:
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
assert self.speculative_config is not None
|
||||
self.method = self.speculative_config.method
|
||||
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
|
||||
self.vocab_size = self.draft_model_config.get_vocab_size()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
device=device,
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
|
||||
)
|
||||
self.idx_mapping = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
self.temperature = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.float32, device=device
|
||||
)
|
||||
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
|
||||
self.draft_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
self.model = load_eagle_model(target_model, self.vllm_config)
|
||||
|
||||
def set_attn(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
block_tables: BlockTables,
|
||||
) -> None:
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.attn_groups = attn_groups
|
||||
self.block_tables = block_tables
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_model(
|
||||
self,
|
||||
num_tokens: int,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_buffers.input_ids[:num_tokens],
|
||||
positions=self.input_buffers.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
return last_hidden_states, hidden_states
|
||||
|
||||
def generate_draft(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens_padded: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
idx_mapping = self.idx_mapping[:num_reqs]
|
||||
for step in range(1, self.num_speculative_steps):
|
||||
# Run the eagle model.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
cudagraph_runtime_mode,
|
||||
)
|
||||
last_hidden_states = last_hidden_states[:num_reqs]
|
||||
hidden_states = hidden_states[:num_reqs]
|
||||
logits = self.model.compute_logits(last_hidden_states)
|
||||
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.temperature,
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
)
|
||||
self.draft_tokens[:num_reqs, step] = draft_tokens
|
||||
|
||||
if step < self.num_speculative_steps - 1:
|
||||
# Update the inputs for the next step.
|
||||
update_eagle_inputs(
|
||||
draft_tokens,
|
||||
hidden_states,
|
||||
self.input_buffers,
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
)
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
return
|
||||
logger.info("Capturing model for Eagle speculator...")
|
||||
self.cudagraph_manager.capture(
|
||||
self.generate_draft,
|
||||
self.input_buffers,
|
||||
self.block_tables,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def propose(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
# [num_tokens, hidden_size]
|
||||
last_hidden_states: torch.Tensor,
|
||||
# num_layers x [num_tokens, hidden_size]
|
||||
aux_hidden_states: list[torch.Tensor] | None,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
last_sampled: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
temperature: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
|
||||
# number of rejected tokens, we maintain the size of eagle's input_ids and
|
||||
# hidden_states the same as the target model's. This means, we pad each
|
||||
# request's query length to include any rejected positions. By doing so,
|
||||
# we can also reuse the attention metadata (e.g., query_start_loc,
|
||||
# seq_lens) of the target model.
|
||||
if aux_hidden_states:
|
||||
assert self.method == "eagle3"
|
||||
hidden_states = self.model.combine_hidden_states(
|
||||
torch.cat(aux_hidden_states, dim=-1)
|
||||
)
|
||||
else:
|
||||
hidden_states = last_hidden_states
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
|
||||
# Get the input ids and last token indices for the speculator.
|
||||
last_token_indices = prepare_eagle_inputs(
|
||||
self.input_buffers,
|
||||
input_batch,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
last_sampled,
|
||||
next_prefill_tokens,
|
||||
)
|
||||
|
||||
# Prefill: Run the eagle speculator with eager mode.
|
||||
# TODO(woosuk): Support CUDA graph for prefill.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_tokens,
|
||||
input_batch.attn_metadata,
|
||||
input_batch.slot_mappings,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||
# and ignore the other sampling parameters such as top_k and top_p,
|
||||
# for simplicity and performance.
|
||||
# While this may slightly degrade the acceptance rate, it does not
|
||||
# affect the output distribution after rejection sampling.
|
||||
idx_mapping = self.idx_mapping[:num_reqs]
|
||||
idx_mapping.copy_(input_batch.idx_mapping)
|
||||
self.temperature.copy_(temperature)
|
||||
self.seeds.copy_(seeds)
|
||||
# Gather the values and copy them to the pre-allocated buffers.
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits,
|
||||
idx_mapping,
|
||||
self.temperature,
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
)
|
||||
if self.num_speculative_steps == 1:
|
||||
# Early exit.
|
||||
return draft_tokens.view(-1, 1)
|
||||
|
||||
# Save the draft tokens for the first step.
|
||||
self.draft_tokens[:num_reqs, 0] = draft_tokens
|
||||
# Prepare the inputs for the decode steps.
|
||||
prepare_eagle_decode(
|
||||
draft_tokens,
|
||||
hidden_states,
|
||||
last_token_indices,
|
||||
input_batch.seq_lens,
|
||||
num_rejected,
|
||||
self.input_buffers,
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
self.max_num_reqs,
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
|
||||
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Run full CUDA graph.
|
||||
self.cudagraph_manager.run_fullgraph(cudagraph_size)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
# Run eager or piecewise CUDA graph.
|
||||
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_eagle_inputs_kernel(
|
||||
last_token_indices_ptr,
|
||||
eagle_input_ids_ptr,
|
||||
eagle_positions_ptr,
|
||||
target_input_ids_ptr,
|
||||
target_positions_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_sampled_ptr,
|
||||
next_prefill_tokens_ptr,
|
||||
num_sampled_ptr,
|
||||
num_rejected_ptr,
|
||||
query_start_loc_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
# Get the true query length and next token after accounting for rejected tokens.
|
||||
num_rejected = tl.load(num_rejected_ptr + batch_idx)
|
||||
query_len -= num_rejected
|
||||
|
||||
num_sampled = tl.load(num_sampled_ptr + batch_idx)
|
||||
if num_sampled > 0:
|
||||
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
|
||||
else:
|
||||
# Chunked prefilling.
|
||||
# Get the next prefill token.
|
||||
next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
|
||||
|
||||
# Shift target_input_ids by one.
|
||||
for i in range(1, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
input_ids = tl.load(target_input_ids_ptr + query_start + block, mask=mask)
|
||||
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
|
||||
|
||||
last_token_index = query_start + query_len - 1
|
||||
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
|
||||
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
|
||||
|
||||
# Copy positions.
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
|
||||
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
|
||||
|
||||
|
||||
def prepare_eagle_inputs(
|
||||
input_buffers: InputBuffers,
|
||||
input_batch: InputBatch,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_rejected: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
last_sampled: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
next_prefill_tokens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = input_batch.num_reqs
|
||||
last_token_indices = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int64,
|
||||
device=num_sampled.device,
|
||||
)
|
||||
_prepare_eagle_inputs_kernel[(num_reqs,)](
|
||||
last_token_indices,
|
||||
input_buffers.input_ids,
|
||||
input_buffers.positions,
|
||||
input_batch.input_ids,
|
||||
input_batch.positions,
|
||||
input_batch.idx_mapping,
|
||||
last_sampled,
|
||||
next_prefill_tokens,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
input_batch.query_start_loc,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return last_token_indices
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_eagle_docode_kernel(
|
||||
draft_tokens_ptr,
|
||||
output_hidden_states_ptr,
|
||||
output_hidden_states_stride,
|
||||
last_token_indices_ptr,
|
||||
target_seq_lens_ptr,
|
||||
num_rejected_ptr,
|
||||
input_ids_ptr,
|
||||
positions_ptr,
|
||||
input_hidden_states_ptr,
|
||||
input_hidden_states_stride,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
hidden_size,
|
||||
max_model_len,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
num_reqs = tl.num_programs(0) - 1
|
||||
if req_idx == num_reqs:
|
||||
# Compute query_start_loc. Pad it with the last query_start_loc
|
||||
# for CUDA graphs.
|
||||
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
q = tl.where(block < num_reqs, block, num_reqs)
|
||||
mask = block < max_num_reqs + 1
|
||||
tl.store(query_start_loc_ptr + block, q, mask=mask)
|
||||
# Pad seq_lens for CUDA graphs.
|
||||
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < max_num_reqs
|
||||
tl.store(seq_lens_ptr + block, 0, mask=mask)
|
||||
return
|
||||
|
||||
# draft token -> input id.
|
||||
draft_token = tl.load(draft_tokens_ptr + req_idx)
|
||||
tl.store(input_ids_ptr + req_idx, draft_token)
|
||||
|
||||
# output hidden states -> input hidden states.
|
||||
src_idx = tl.load(last_token_indices_ptr + req_idx)
|
||||
for i in range(0, hidden_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < hidden_size
|
||||
output_hidden_states = tl.load(
|
||||
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
|
||||
output_hidden_states,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Compute position and seq_lens.
|
||||
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
|
||||
# if they reach the max model length.
|
||||
position = tl.load(positions_ptr + req_idx)
|
||||
position = tl.minimum(position + 1, max_model_len - 1)
|
||||
tl.store(positions_ptr + req_idx, position)
|
||||
|
||||
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
|
||||
num_rejected = tl.load(num_rejected_ptr + req_idx)
|
||||
seq_len = target_seq_len - num_rejected
|
||||
seq_len = tl.minimum(seq_len + 1, max_model_len)
|
||||
tl.store(seq_lens_ptr + req_idx, seq_len)
|
||||
|
||||
|
||||
def prepare_eagle_decode(
|
||||
draft_tokens: torch.Tensor,
|
||||
output_hidden_states: torch.Tensor,
|
||||
last_token_indices: torch.Tensor,
|
||||
target_seq_lens: torch.Tensor,
|
||||
num_rejected: torch.Tensor,
|
||||
input_buffers: InputBuffers,
|
||||
input_hidden_states: torch.Tensor,
|
||||
max_model_len: int,
|
||||
max_num_reqs: int,
|
||||
):
|
||||
num_reqs = draft_tokens.shape[0]
|
||||
hidden_size = output_hidden_states.shape[-1]
|
||||
_prepare_eagle_docode_kernel[(num_reqs + 1,)](
|
||||
draft_tokens,
|
||||
output_hidden_states,
|
||||
output_hidden_states.stride(0),
|
||||
last_token_indices,
|
||||
target_seq_lens,
|
||||
num_rejected,
|
||||
input_buffers.input_ids,
|
||||
input_buffers.positions,
|
||||
input_hidden_states,
|
||||
input_hidden_states.stride(0),
|
||||
input_buffers.query_start_loc,
|
||||
input_buffers.seq_lens,
|
||||
hidden_size,
|
||||
max_model_len,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _update_eagle_inputs_kernel(
|
||||
input_ids_ptr,
|
||||
positions_ptr,
|
||||
input_hidden_states_ptr,
|
||||
input_hidden_states_stride,
|
||||
seq_lens_ptr,
|
||||
max_model_len,
|
||||
draft_tokens_ptr,
|
||||
output_hidden_states_ptr,
|
||||
output_hidden_states_stride,
|
||||
hidden_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
|
||||
# Draft token -> Input ID.
|
||||
draft_token = tl.load(draft_tokens_ptr + req_idx)
|
||||
tl.store(input_ids_ptr + req_idx, draft_token)
|
||||
|
||||
# Output hidden states -> Input hidden states.
|
||||
for i in range(0, hidden_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < hidden_size
|
||||
output_hidden_states = tl.load(
|
||||
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
|
||||
output_hidden_states,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Increment position and seq_lens.
|
||||
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
|
||||
# if they reach the max model length.
|
||||
position = tl.load(positions_ptr + req_idx)
|
||||
position = tl.minimum(position + 1, max_model_len - 1)
|
||||
tl.store(positions_ptr + req_idx, position)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + req_idx)
|
||||
seq_len = tl.minimum(seq_len + 1, max_model_len)
|
||||
tl.store(seq_lens_ptr + req_idx, seq_len)
|
||||
|
||||
|
||||
def update_eagle_inputs(
|
||||
draft_tokens: torch.Tensor,
|
||||
output_hidden_states: torch.Tensor,
|
||||
input_buffers: InputBuffers,
|
||||
hidden_states: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
num_reqs, hidden_size = output_hidden_states.shape
|
||||
_update_eagle_inputs_kernel[(num_reqs,)](
|
||||
input_buffers.input_ids,
|
||||
input_buffers.positions,
|
||||
hidden_states,
|
||||
hidden_states.stride(0),
|
||||
input_buffers.seq_lens,
|
||||
max_model_len,
|
||||
draft_tokens,
|
||||
output_hidden_states,
|
||||
output_hidden_states.stride(0),
|
||||
hidden_size,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
52
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
Normal file
52
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
|
||||
|
||||
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
assert speculative_config is not None
|
||||
draft_model_config = speculative_config.draft_model_config
|
||||
with set_model_tag("eagle_head"):
|
||||
eagle_model = get_model(
|
||||
vllm_config=vllm_config, model_config=draft_model_config
|
||||
)
|
||||
|
||||
# Share target embeddings when the draft checkpoint does not include
|
||||
# its own vocab embedding table.
|
||||
share_embeddings = True
|
||||
if hasattr(eagle_model, "has_own_embed_tokens"):
|
||||
share_embeddings = not eagle_model.has_own_embed_tokens
|
||||
if share_embeddings:
|
||||
target_language_model = (
|
||||
target_model.get_language_model()
|
||||
if hasattr(target_model, "get_language_model")
|
||||
else target_model
|
||||
)
|
||||
inner_model = getattr(target_language_model, "model", None)
|
||||
target_embed_tokens = None
|
||||
if inner_model is not None:
|
||||
if hasattr(inner_model, "embed_tokens"):
|
||||
target_embed_tokens = inner_model.embed_tokens
|
||||
elif hasattr(inner_model, "embedding"):
|
||||
target_embed_tokens = inner_model.embedding
|
||||
if target_embed_tokens is not None and hasattr(eagle_model, "model"):
|
||||
if hasattr(eagle_model.model, "embed_tokens"):
|
||||
del eagle_model.model.embed_tokens
|
||||
eagle_model.model.embed_tokens = target_embed_tokens
|
||||
|
||||
# Only share target lm_head when the draft model does not own one.
|
||||
share_lm_head = True
|
||||
if hasattr(eagle_model, "has_own_lm_head"):
|
||||
share_lm_head = not eagle_model.has_own_lm_head
|
||||
if share_lm_head and hasattr(target_model, "lm_head"):
|
||||
if hasattr(eagle_model, "lm_head"):
|
||||
del eagle_model.lm_head
|
||||
eagle_model.lm_head = target_model.lm_head
|
||||
|
||||
return eagle_model
|
||||
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rejection_sample_kernel(
|
||||
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
|
||||
sampled_stride,
|
||||
num_sampled_ptr, # [num_reqs]
|
||||
target_sampled_ptr, # [num_draft_tokens + num_reqs]
|
||||
input_ids_ptr, # [num_draft_tokens + num_reqs]
|
||||
cu_num_logits_ptr, # [num_reqs + 1]
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
num_sampled = 0
|
||||
rejected = False
|
||||
for i in range(num_tokens - 1):
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
|
||||
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
|
||||
num_sampled += 1
|
||||
if target_sampled != draft_sampled:
|
||||
rejected = True
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
|
||||
tl.store(
|
||||
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
|
||||
)
|
||||
num_sampled += 1
|
||||
tl.store(num_sampled_ptr + req_idx, num_sampled)
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_draft_tokens + num_reqs]
|
||||
target_sampled: torch.Tensor,
|
||||
# [num_draft_tokens + num_reqs]
|
||||
input_ids: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=target_sampled.dtype,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
num_sampled = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
num_sampled,
|
||||
target_sampled,
|
||||
input_ids,
|
||||
cu_num_logits,
|
||||
num_warps=1,
|
||||
)
|
||||
return sampled, num_sampled
|
||||
47
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
47
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
|
||||
|
||||
class DraftTokensHandler:
|
||||
def __init__(self, device: torch.device | None = None):
|
||||
self.device = device
|
||||
self.copy_stream = torch.cuda.Stream(device)
|
||||
self.copy_event = torch.cuda.Event()
|
||||
|
||||
self.req_ids: list[str] = []
|
||||
self.draft_tokens_np: np.ndarray | None = None
|
||||
self.num_draft_tokens: int = 0
|
||||
|
||||
def set_draft_tokens(
|
||||
self, input_batch: InputBatch, draft_tokens: torch.Tensor
|
||||
) -> None:
|
||||
self.req_ids = input_batch.req_ids
|
||||
self.num_draft_tokens = draft_tokens.shape[1]
|
||||
if not input_batch.has_structured_output_reqs:
|
||||
# No draft token validation needs to be performed by
|
||||
# the scheduler for this batch.
|
||||
self.draft_tokens_np = None
|
||||
return
|
||||
|
||||
# For spec decoding + structured outputs, we must transfer the
|
||||
# draft tokens back to the scheduler for grammar validation.
|
||||
current_stream = torch.cuda.current_stream(self.device)
|
||||
self.copy_stream.wait_stream(current_stream)
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
self.draft_tokens_np = async_copy_to_np(draft_tokens)
|
||||
self.copy_event.record()
|
||||
|
||||
def get_draft_tokens(self) -> DraftTokenIds | None:
|
||||
if self.draft_tokens_np is not None:
|
||||
self.copy_event.synchronize()
|
||||
draft_token_ids = self.draft_tokens_np.tolist()
|
||||
else:
|
||||
# This case only happens when async scheduling is disabled.
|
||||
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
|
||||
return DraftTokenIds(self.req_ids, draft_token_ids)
|
||||
123
vllm/v1/worker/gpu/states.py
Normal file
123
vllm/v1/worker/gpu/states.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
|
||||
class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
num_speculative_steps: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.num_speculative_steps = num_speculative_steps
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_reqs))
|
||||
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# depending on the configured max_num_reqs and max_model_len.
|
||||
# To save GPU memory, we use UVA instead of GPU for this tensor.
|
||||
self.all_token_ids = StagedWriteTensor(
|
||||
(self.max_num_reqs, self.max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
)
|
||||
# NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
|
||||
# - prompt_len: Number of tokens in the user-provided prompt.
|
||||
# - prefill_len: Number of tokens passed into the model runner.
|
||||
# This can include the prompt and additional partial output tokens,
|
||||
# so prefill_len >= prompt_len.
|
||||
# Usually, prefill_len equals prompt_len, but in cases such as resumption after
|
||||
# preemption, prefill_len may be greater. Differentiating between these values
|
||||
# is crucial, as certain features such as prompt logprobs or frequency penalties
|
||||
# must treat prompt and output tokens separately.
|
||||
self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
|
||||
# total_len = prompt_len + output_len. It grows as the request progresses.
|
||||
self.total_len = StagedWriteTensor(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Number of computed tokens.
|
||||
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens = StagedWriteTensor(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Last sampled tokens.
|
||||
self.last_sampled_tokens = torch.zeros(
|
||||
self.max_num_reqs, 1, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
# Draft tokens.
|
||||
self.draft_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.next_prefill_tokens = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
prompt_len: int,
|
||||
all_token_ids: list[int],
|
||||
num_computed_tokens: int,
|
||||
) -> None:
|
||||
assert len(self.free_indices) > 0, "No free indices"
|
||||
req_idx = self.free_indices.pop()
|
||||
self.req_id_to_index[req_id] = req_idx
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
self.prompt_len.np[req_idx] = prompt_len
|
||||
prefill_len = len(all_token_ids)
|
||||
assert prefill_len >= prompt_len, (
|
||||
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
|
||||
)
|
||||
self.prefill_len.np[req_idx] = prefill_len
|
||||
self.total_len.stage_write_elem(req_idx, prefill_len)
|
||||
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
|
||||
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
|
||||
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prompt_len.copy_to_uva()
|
||||
self.prefill_len.copy_to_uva()
|
||||
self.total_len.apply_write()
|
||||
self.all_token_ids.apply_write()
|
||||
self.num_computed_tokens.apply_write()
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
req_idx = self.req_id_to_index.pop(req_id, None)
|
||||
if req_idx is None:
|
||||
# Request not found.
|
||||
return
|
||||
self.index_to_req_id.pop(req_idx, None)
|
||||
self.free_indices.append(req_idx)
|
||||
|
||||
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
|
||||
return np.any(
|
||||
self.num_computed_prefill_tokens[idx_mapping_np]
|
||||
< self.prefill_len.np[idx_mapping_np]
|
||||
)
|
||||
115
vllm/v1/worker/gpu/structured_outputs.py
Normal file
115
vllm/v1/worker/gpu/structured_outputs.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
|
||||
|
||||
class StructuredOutputsWorker:
|
||||
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
|
||||
self.logits_indices = torch.zeros(
|
||||
max_num_logits, dtype=torch.int32, device=device
|
||||
)
|
||||
self.grammar_bitmask = torch.zeros(
|
||||
(max_num_logits, cdiv(vocab_size, 32)), dtype=torch.int32, device=device
|
||||
)
|
||||
self.device = device
|
||||
self.copy_stream = torch.cuda.Stream()
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
grammar_req_ids: list[str],
|
||||
grammar_bitmask: np.ndarray,
|
||||
) -> None:
|
||||
if not grammar_req_ids:
|
||||
return
|
||||
|
||||
# Asynchronously copy the bitmask to GPU.
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
bitmask = async_copy_to_gpu(
|
||||
grammar_bitmask, out=self.grammar_bitmask[: grammar_bitmask.shape[0]]
|
||||
)
|
||||
|
||||
# Construct bitmask -> logits mapping
|
||||
mapping: list[int] = []
|
||||
req_ids = input_batch.req_ids
|
||||
cu_num_logits = input_batch.cu_num_logits_np.tolist()
|
||||
req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
|
||||
for grammar_req_id in grammar_req_ids:
|
||||
req_idx = req_id_to_idx[grammar_req_id]
|
||||
logits_start_idx = cu_num_logits[req_idx]
|
||||
logits_end_idx = cu_num_logits[req_idx + 1]
|
||||
mapping.extend(range(logits_start_idx, logits_end_idx))
|
||||
|
||||
# Asynchronously copy the mapping to GPU.
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
logits_indices = torch.tensor(
|
||||
mapping, dtype=torch.int32, device="cpu", pin_memory=True
|
||||
)
|
||||
logits_indices = self.logits_indices[: len(mapping)].copy_(
|
||||
logits_indices, non_blocking=True
|
||||
)
|
||||
|
||||
# Ensure all async copies are complete before launching the kernel.
|
||||
current_stream = torch.cuda.current_stream()
|
||||
current_stream.wait_stream(self.copy_stream)
|
||||
|
||||
num_masks = bitmask.shape[0]
|
||||
assert num_masks == len(mapping)
|
||||
vocab_size = logits.shape[-1]
|
||||
BLOCK_SIZE = 8192
|
||||
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
|
||||
_apply_grammar_bitmask_kernel[grid](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
logits_indices,
|
||||
bitmask,
|
||||
bitmask.stride(0),
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# Ensure the copy stream waits for the device tensors to finish being used
|
||||
# before it re-uses or deallocates them
|
||||
self.copy_stream.wait_stream(current_stream)
|
||||
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
|
||||
@triton.jit
|
||||
def _apply_grammar_bitmask_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
logits_indices_ptr,
|
||||
bitmask_ptr,
|
||||
bitmask_stride,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
bitmask_idx = tl.program_id(0)
|
||||
logits_idx = tl.load(logits_indices_ptr + bitmask_idx)
|
||||
|
||||
# Load the bitmask.
|
||||
block_id = tl.program_id(1)
|
||||
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
|
||||
packed_bitmask = tl.load(
|
||||
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
|
||||
mask=bitmask_offset < bitmask_stride,
|
||||
)
|
||||
# Unpack the bitmask.
|
||||
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
|
||||
bitmask = bitmask.reshape(BLOCK_SIZE)
|
||||
|
||||
# Apply the bitmask to the logits.
|
||||
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(
|
||||
logits_ptr + logits_idx * logits_stride + block_offset,
|
||||
-float("inf"),
|
||||
mask=bitmask & (block_offset < vocab_size),
|
||||
)
|
||||
1030
vllm/v1/worker/gpu_input_batch.py
Normal file
1030
vllm/v1/worker/gpu_input_batch.py
Normal file
File diff suppressed because it is too large
Load Diff
6351
vllm/v1/worker/gpu_model_runner.py
Normal file
6351
vllm/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
494
vllm/v1/worker/gpu_ubatch_wrapper.py
Normal file
494
vllm/v1/worker/gpu_ubatch_wrapper.py
Normal file
@@ -0,0 +1,494 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import (
|
||||
DPMetadata,
|
||||
create_forward_context,
|
||||
get_forward_context,
|
||||
override_forward_context,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchMetadata:
|
||||
context: UBatchContext
|
||||
input_ids: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
inputs_embeds: torch.Tensor | None
|
||||
intermediate_tensors: IntermediateTensors | None
|
||||
num_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CUDAGraphMetaData:
|
||||
cudagraph: torch.cuda.CUDAGraph
|
||||
ubatch_metadata: UbatchMetadata
|
||||
outputs: Any | None = None
|
||||
|
||||
|
||||
class SMControlContextManager:
|
||||
def __init__(
|
||||
self,
|
||||
comm_sms: int,
|
||||
set_comm_sms: Callable[[int], None],
|
||||
set_compute_sms: Callable[[int], None],
|
||||
):
|
||||
"""
|
||||
Context manager for controlling SM (Streaming Multiprocessor)
|
||||
allocation. Upon entering the context, it sets the number of SMs
|
||||
allocated for communication and computation to comm_sms and
|
||||
total_sms - comm_sms respectively. Upon exiting, it restores the
|
||||
allocation to use all available SMs (i.e. total_sms).
|
||||
|
||||
Args:
|
||||
comm_sms (int): The number of SMs to allocate for communication.
|
||||
(The remainder will be used for computation.)
|
||||
set_comm_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for communication.
|
||||
set_compute_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for computation.
|
||||
"""
|
||||
|
||||
assert current_platform.is_cuda(), (
|
||||
"SM control is currently only supported on CUDA"
|
||||
)
|
||||
|
||||
total_sms = num_compute_units(torch.cuda.current_device())
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
self.compute_sms = total_sms - comm_sms
|
||||
self.comm_sms = comm_sms
|
||||
self.set_comm_sms = set_comm_sms
|
||||
self.set_compute_sms = set_compute_sms
|
||||
|
||||
def __enter__(self):
|
||||
self.set_comm_sms(self.comm_sms)
|
||||
self.set_compute_sms(self.compute_sms)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.set_comm_sms(self.total_sms)
|
||||
self.set_compute_sms(self.total_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
device: torch.cuda.device,
|
||||
):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.comm_stream = torch.cuda.Stream(device=device)
|
||||
# Ubatch threads plus the main thread
|
||||
self.ready_barrier = threading.Barrier(
|
||||
self.vllm_config.parallel_config.num_ubatches + 1
|
||||
)
|
||||
|
||||
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
|
||||
|
||||
self.cudagraph_wrapper = None
|
||||
self.graph_pool = None
|
||||
if runtime_mode is not CUDAGraphMode.NONE:
|
||||
self.cudagraph_wrapper = CUDAGraphWrapper(
|
||||
runnable, vllm_config, runtime_mode=runtime_mode
|
||||
)
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def _create_sm_control_context(vllm_config: VllmConfig):
|
||||
comm_sms: int = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
ep_group = get_ep_group()
|
||||
device_communicator = ep_group.device_communicator
|
||||
all2all_manager = None
|
||||
if device_communicator is not None:
|
||||
all2all_manager = device_communicator.all2all_manager
|
||||
|
||||
if all2all_manager is not None:
|
||||
max_sms_used = all2all_manager.max_sms_used()
|
||||
if max_sms_used is not None:
|
||||
comm_sms = min(comm_sms, max_sms_used)
|
||||
|
||||
if comm_sms > 0 and all2all_manager is not None:
|
||||
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
|
||||
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(
|
||||
comm_sms=comm_sms,
|
||||
set_comm_sms=set_comm_sms,
|
||||
set_compute_sms=set_compute_sms,
|
||||
)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(
|
||||
f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}"
|
||||
)
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
"""
|
||||
Capture a cudagraph for a microbatched run.
|
||||
|
||||
The logic here is somewhat complicated because we need to make sure that
|
||||
each of the ubatch threads initialize the cuda context before we start
|
||||
the graph capture.
|
||||
|
||||
The flow is as follows:
|
||||
1. The main thread starts up each ubatch thread. Each thread will
|
||||
initialize its cuda context (torch.cuda.current_blas_handle())
|
||||
before going to sleep upon entering the ubatch_context.
|
||||
|
||||
2. The main thread starts the graph capture and wakes up the first
|
||||
ubatch thread.
|
||||
|
||||
3. Each ubatch thread runs the model to completion and returns the
|
||||
completed output tensors back to the main thread.
|
||||
|
||||
4. The main thread stores the captured cudagraph along with its metadata
|
||||
and returns
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def _capture_ubatch_thread(results, ubatch_metadata):
|
||||
torch.cuda.set_device(self.device)
|
||||
ubatch_context = ubatch_metadata.context
|
||||
with torch.cuda.stream(ubatch_context.compute_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
with torch.cuda.stream(ubatch_context.comm_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
with ubatch_context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
compute_stream = ubatch_metadata[0].context.compute_stream
|
||||
num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(
|
||||
target=_capture_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
metadata,
|
||||
),
|
||||
)
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
|
||||
# Capture the cudagraph
|
||||
cudagraph_metadata = CUDAGraphMetaData(
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
ubatch_metadata=ubatch_metadata,
|
||||
)
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
with torch.cuda.graph(
|
||||
cudagraph_metadata.cudagraph,
|
||||
stream=compute_stream,
|
||||
pool=self.graph_pool,
|
||||
):
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
cudagraph_metadata.outputs = result
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
self.cudagraphs[num_tokens] = cudagraph_metadata
|
||||
return cudagraph_metadata.outputs
|
||||
|
||||
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, model, ubatch_metadata):
|
||||
with ubatch_metadata.context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
|
||||
# Ubatch threads will manually manage the forward context, so we
|
||||
# override it to None here so we can have it restored correctly
|
||||
# after both threads have finished
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(
|
||||
target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
model,
|
||||
metadata,
|
||||
),
|
||||
)
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
return result
|
||||
|
||||
def _make_ubatch_metadata(
|
||||
self,
|
||||
ubatch_slices,
|
||||
attn_metadata,
|
||||
slot_mapping,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
compute_stream,
|
||||
dp_metadata,
|
||||
batch_descriptor,
|
||||
cudagraph_runtime_mode,
|
||||
) -> list[UbatchMetadata]:
|
||||
# Create one forward context per ubatch
|
||||
forward_contexts = []
|
||||
# slot_mapping can be None, an empty dict (from create_forward_context
|
||||
# converting None to {}), or a list of dicts (one per ubatch)
|
||||
has_slot_mapping = slot_mapping and isinstance(slot_mapping, list)
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
forward_contexts.append(
|
||||
create_forward_context(
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
self.vllm_config,
|
||||
dp_metadata=dp_metadata[i],
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=slot_mapping[i] if has_slot_mapping else None,
|
||||
)
|
||||
)
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(
|
||||
num_micro_batches=len(ubatch_slices),
|
||||
comm_stream=self.comm_stream,
|
||||
compute_stream=compute_stream,
|
||||
forward_contexts=forward_contexts,
|
||||
ready_barrier=self.ready_barrier,
|
||||
)
|
||||
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
(
|
||||
sliced_input_ids,
|
||||
sliced_positions,
|
||||
sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors,
|
||||
) = self._slice_model_inputs(
|
||||
ubatch_slice.token_slice,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
)
|
||||
ubatch_metadata.append(
|
||||
UbatchMetadata(
|
||||
context=ubatch_ctxs[i],
|
||||
input_ids=sliced_input_ids,
|
||||
positions=sliced_positions,
|
||||
inputs_embeds=sliced_inputs_embeds,
|
||||
intermediate_tensors=sliced_intermediate_tensors,
|
||||
num_tokens=ubatch_slice.token_slice.stop
|
||||
- ubatch_slice.token_slice.start,
|
||||
)
|
||||
)
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _slice_model_inputs(
|
||||
self,
|
||||
tokens_slice: slice,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
):
|
||||
sliced_input_ids = input_ids[tokens_slice]
|
||||
# if we are using mrope. Mrope adds an additional dimension to the
|
||||
# positions tensor
|
||||
if positions.ndim == 2:
|
||||
sliced_positions = positions[:, tokens_slice]
|
||||
else:
|
||||
sliced_positions = positions[tokens_slice]
|
||||
sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None
|
||||
sliced_intermediate_tensors = (
|
||||
intermediate_tensors[tokens_slice] if intermediate_tensors else None
|
||||
)
|
||||
|
||||
return (
|
||||
sliced_input_ids,
|
||||
sliced_positions,
|
||||
sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors,
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
ubatch_slices = forward_context.ubatch_slices
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
# If there's no ubatching, just run the runnable object
|
||||
if ubatch_slices is None:
|
||||
# This is to account for the case where ubatching was aborted.
|
||||
# When we capture full graphs we only capture one graph per shape,
|
||||
# meaning that if we have a ubatched cudagraph for the current
|
||||
# num_tokens, we don't have a non-ubatched one. Without this
|
||||
# check, the cudagraph wrapper will try to capture a cudagraph
|
||||
# for this shape during a normal run.
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor.num_tokens in self.cudagraphs:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
|
||||
return self.runnable(*args, **kwargs)
|
||||
else:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
inputs_embeds = kwargs["inputs_embeds"]
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
# We shouldn't be here unless we are running with multiple DP ranks
|
||||
assert dp_metadata is not None
|
||||
ubatch_dp_metadata = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
ubatch_num_tokens_across_dp = torch.tensor(
|
||||
[ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
|
||||
)
|
||||
ubatch_dp_metadata.append(
|
||||
DPMetadata.make(
|
||||
self.vllm_config.parallel_config,
|
||||
ubatch_slice.num_tokens,
|
||||
ubatch_num_tokens_across_dp,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
num_tokens not in self.cudagraphs
|
||||
and cudagraph_runtime_mode is CUDAGraphMode.FULL
|
||||
):
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=ubatch_dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
elif (
|
||||
num_tokens in self.cudagraphs
|
||||
and cudagraph_runtime_mode is CUDAGraphMode.FULL
|
||||
):
|
||||
cudagraph_metadata = self.cudagraphs[num_tokens]
|
||||
# Sync offloader before replay - ensures any external dependencies
|
||||
# from pre-capture prefetches are satisfied.
|
||||
get_offloader().sync_prev_onload()
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
return cudagraph_metadata.outputs
|
||||
else:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=ubatch_dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
1138
vllm/v1/worker/gpu_worker.py
Normal file
1138
vllm/v1/worker/gpu_worker.py
Normal file
File diff suppressed because it is too large
Load Diff
283
vllm/v1/worker/kv_connector_model_runner_mixin.py
Normal file
283
vllm/v1/worker/kv_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Define KV connector functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager, contextmanager, nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
|
||||
class KVConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
def ensure_kv_transfer_shutdown() -> None:
|
||||
# has_kv_transfer_group can be None during interpreter shutdown.
|
||||
if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
@staticmethod
|
||||
def kv_connector_no_forward(
|
||||
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
|
||||
) -> ModelRunnerOutput:
|
||||
# KV send/recv even if no work to do.
|
||||
with (
|
||||
set_forward_context(None, vllm_config),
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, wait_for_save=False
|
||||
) as kv_connector_output,
|
||||
):
|
||||
pass
|
||||
|
||||
if kv_connector_output.is_empty():
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def maybe_get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
clear_metadata: bool = True,
|
||||
) -> AbstractContextManager[KVConnectorOutput | None]:
|
||||
return (
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, clear_metadata=clear_metadata
|
||||
)
|
||||
if has_kv_transfer_group()
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
# This context manager must be used within an active forward context.
|
||||
# It encapsulates the entire KV connector lifecycle within execute_model
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True,
|
||||
clear_metadata: bool = True,
|
||||
) -> Generator[KVConnectorOutput, None, None]:
|
||||
output = KVConnectorOutput()
|
||||
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
if wait_for_save:
|
||||
kv_connector.wait_for_save()
|
||||
|
||||
output.finished_sending, output.finished_recving = (
|
||||
kv_connector.get_finished(scheduler_output.finished_req_ids)
|
||||
)
|
||||
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
|
||||
|
||||
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||
|
||||
if clear_metadata:
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
def clear_kv_connector_metadata() -> None:
|
||||
"""Clear the KV connector metadata. Call after draft model runs."""
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
def use_uniform_kv_cache(
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
cache_dtype: CacheDType,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a uniform KV layout should be used.
|
||||
A uniform layout means all layers KV caches will share the same
|
||||
underlying tensor, where for a given block number, the respective
|
||||
KV data for all layers will be contiguous.
|
||||
This will allow efficient KV transfer of per-block KV data for all
|
||||
layers at once.
|
||||
Note this layout will only be applied given 3 conditions:
|
||||
1. The KV Cache config contains just a single group where all layers
|
||||
have the same page size.
|
||||
2. A KV connector is configured, and the KV connector instance prefers
|
||||
to use this layout (prefer_cross_layer_blocks() returns True)
|
||||
2. The flash attention backend supports this layout
|
||||
(get_kv_cache_stride_order(True) includes a placement for a
|
||||
num_layers dimension)
|
||||
|
||||
Note that the actual placement of the num_layers dimensions
|
||||
in the unified layers tensors will be determined by the attention
|
||||
backend.
|
||||
Thus, the layers KV data may still not be contiguous per block
|
||||
if the attention backend does not support it.
|
||||
|
||||
Args:
|
||||
attn_groups: The list of attention groups for this model
|
||||
cache_dtype: The KV cache dtype
|
||||
Returns:
|
||||
True if we should use a uniform KV cache layout.
|
||||
"""
|
||||
|
||||
if not has_kv_transfer_group():
|
||||
return False
|
||||
if not get_kv_transfer_group().prefer_cross_layer_blocks:
|
||||
return False
|
||||
|
||||
if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
|
||||
return False
|
||||
|
||||
attn_group = attn_groups[0][0]
|
||||
kv_cache_spec = attn_group.kv_cache_spec
|
||||
if not isinstance(kv_cache_spec, AttentionSpec):
|
||||
return False
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
1234,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
)
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True
|
||||
)
|
||||
except (AttributeError, NotImplementedError):
|
||||
return False
|
||||
|
||||
# check that attention backend include a layers dimension
|
||||
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
|
||||
|
||||
@staticmethod
|
||||
def allocate_uniform_kv_caches(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
cache_dtype: CacheDType,
|
||||
device: torch.device,
|
||||
kernel_block_sizes: list[int],
|
||||
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
|
||||
"""
|
||||
Initializes and reshapes KV caches for the simple case where all
|
||||
layers have the same layout.
|
||||
|
||||
This function assumes use_uniform_kv_cache() returned True.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
attn_groups: The list of attention groups for this model
|
||||
cache_dtype: The KV cache dtype
|
||||
device: The torch device to allocate on.
|
||||
kernel_block_sizes: The kernel block sizes for each KV cache group.
|
||||
Returns:
|
||||
A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
|
||||
kv_caches is a dict mapping between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
cross_layers_kv_cache is the cross layers kv cache tensor
|
||||
attn_backend is the attention backend matching this tensor
|
||||
"""
|
||||
attn_group = attn_groups[0][0]
|
||||
kv_cache_spec = attn_group.kv_cache_spec
|
||||
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||
|
||||
tensor_sizes = set(
|
||||
kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
|
||||
)
|
||||
assert len(tensor_sizes) == 1
|
||||
tensor_size = tensor_sizes.pop()
|
||||
|
||||
page_size = kv_cache_spec.page_size_bytes
|
||||
assert tensor_size % page_size == 0
|
||||
num_blocks = tensor_size // page_size
|
||||
num_layers = len(kv_cache_config.kv_cache_tensors)
|
||||
total_size = tensor_size * num_layers
|
||||
|
||||
assert len(kernel_block_sizes) == 1
|
||||
kernel_block_size = kernel_block_sizes[0]
|
||||
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
)
|
||||
|
||||
# prepend a num_layers dimension into the shape
|
||||
kv_cache_shape = (num_layers,) + kv_cache_shape
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
|
||||
logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
|
||||
|
||||
# allocate one contiguous buffer for all layers
|
||||
cross_layers_kv_cache = (
|
||||
torch.zeros(total_size, dtype=torch.int8, device=device)
|
||||
.view(kv_cache_spec.dtype)
|
||||
.view(kv_cache_shape)
|
||||
)
|
||||
|
||||
# Maintain original KV shape view.
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
|
||||
|
||||
kv_caches = {}
|
||||
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
|
||||
tensor = permuted_kv_cache[i]
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_caches[layer_name] = tensor
|
||||
|
||||
return kv_caches, cross_layers_kv_cache, attn_backend
|
||||
285
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
285
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping, LoRAMappingType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor.models import supports_lora
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||
|
||||
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a mixin for GPUModelRunner
|
||||
class LoRAModelRunnerMixin:
|
||||
def load_lora_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
) -> nn.Module:
|
||||
if not supports_lora(model):
|
||||
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
|
||||
|
||||
# Add LoRA Manager to the Model Runner
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
vllm_config,
|
||||
device,
|
||||
model.embedding_modules,
|
||||
)
|
||||
return self.lora_manager.create_lora_manager(model, vllm_config)
|
||||
|
||||
def _set_active_loras(
|
||||
self,
|
||||
prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest],
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
) -> None:
|
||||
self._ensure_lora_enabled()
|
||||
|
||||
# Set is_prefill to True, so we always use the SGMV kernels on
|
||||
# non-cuda platforms.
|
||||
# On cuda platforms we use the same kernels for prefill and
|
||||
# decode and this flag is generally ignored.
|
||||
lora_mapping = LoRAMapping(
|
||||
token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True,
|
||||
type=mapping_type,
|
||||
)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def _ensure_lora_enabled(self) -> None:
|
||||
if not hasattr(self, "lora_manager"):
|
||||
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
|
||||
|
||||
def set_active_loras(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_sampled_tokens: np.ndarray | None = None,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
) -> None:
|
||||
if num_sampled_tokens is None:
|
||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||
|
||||
prompt_lora_mapping: tuple[int, ...] # of size np.sum(num_sampled_tokens)
|
||||
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
|
||||
lora_requests: set[LoRARequest]
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests = (
|
||||
input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
|
||||
)
|
||||
return self._set_active_loras(
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def maybe_setup_dummy_loras(
|
||||
self, lora_config: LoRAConfig | None, remove_lora: bool = True
|
||||
):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_loras = lora_config.max_loras
|
||||
lora_warmup_rank = (
|
||||
lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
|
||||
)
|
||||
# Make dummy lora requests
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
# Add the dummy LoRAs here so _set_active_loras doesn't try to
|
||||
# load from disk.
|
||||
for lr in lora_requests:
|
||||
self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)
|
||||
|
||||
yield
|
||||
|
||||
# __exit__ code
|
||||
if remove_lora:
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
@contextmanager
|
||||
def maybe_select_dummy_loras(
|
||||
self,
|
||||
lora_config: LoRAConfig | None,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
num_sampled_tokens: np.ndarray | None = None,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager to select dummy LoRAs for capture/warmup.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration, or None if LoRA is disabled.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
- 0: No LoRA active (set up zero mappings).
|
||||
- >0: Use exactly this many distinct LoRAs.
|
||||
"""
|
||||
if num_sampled_tokens is None:
|
||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||
|
||||
# Skip LoRA setup entirely only if no LoRA config
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
max_loras = lora_config.max_loras
|
||||
|
||||
# Determine how many distinct LoRAs to use and whether to include
|
||||
# no-LoRA tokens (-1 entries).
|
||||
# When num_active_loras > max_loras (e.g., max_loras + 1), we need
|
||||
# to include -1 entries to simulate batches with both LoRA and
|
||||
# no-LoRA tokens. This ensures prepare_tensors computes the correct
|
||||
# num_active_loras that matches the cudagraph capture key.
|
||||
if num_active_loras == 0:
|
||||
# No LoRA active - use 0 mappings like the original code
|
||||
effective_num_loras = 0
|
||||
include_no_lora = False
|
||||
elif num_active_loras > max_loras:
|
||||
# num_active_loras > max_loras means we want max_loras adapters
|
||||
# PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
|
||||
effective_num_loras = max_loras
|
||||
include_no_lora = True
|
||||
else:
|
||||
# Specific number of active LoRAs requested
|
||||
effective_num_loras = min(num_active_loras, max_loras)
|
||||
include_no_lora = False
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
# LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
|
||||
# convert_mapping() will convert these to 0-indexed slot indices.
|
||||
if effective_num_loras > 0:
|
||||
if include_no_lora:
|
||||
# Include -1 (no-LoRA) entries by cycling through
|
||||
# -1, 1, 2, ..., effective_num_loras
|
||||
# This ensures prepare_tensors sees both LoRA and no-LoRA
|
||||
# tokens, computing num_active_loras = effective_num_loras+1
|
||||
cycle_values = np.array(
|
||||
list(range(1, effective_num_loras + 1)),
|
||||
dtype=np.int32,
|
||||
)
|
||||
prompt_lora_mapping = cycle_values[
|
||||
np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
|
||||
]
|
||||
else:
|
||||
# Use 1 to effective_num_loras (1-indexed lora IDs)
|
||||
prompt_lora_mapping = (
|
||||
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
|
||||
) + 1
|
||||
else:
|
||||
# No LoRA active - use 0 for all tokens (original behavior)
|
||||
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
||||
|
||||
# Make sample lora mapping
|
||||
sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
|
||||
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests (only for the active LoRAs)
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
for lora_id in range(1, effective_num_loras + 1)
|
||||
}
|
||||
|
||||
self._set_active_loras(
|
||||
tuple(sample_lora_mapping),
|
||||
tuple(token_lora_mapping),
|
||||
lora_requests,
|
||||
mapping_type,
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(
|
||||
self,
|
||||
lora_config: LoRAConfig | None,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_sampled_tokens: np.ndarray,
|
||||
remove_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
):
|
||||
"""
|
||||
Context manager for dummy runs with LoRA.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
remove_lora: Whether to remove LoRAs after the context exits.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
LoRA is activated when num_active_loras > 0.
|
||||
"""
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(
|
||||
lora_config,
|
||||
num_scheduled_tokens,
|
||||
mapping_type,
|
||||
num_sampled_tokens,
|
||||
num_active_loras,
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
def maybe_remove_all_loras(self, lora_config: LoRAConfig | None):
|
||||
if lora_config is None:
|
||||
return
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
self._ensure_lora_enabled()
|
||||
return self.lora_manager.list_adapters()
|
||||
242
vllm/v1/worker/mamba_utils.py
Normal file
242
vllm/v1/worker/mamba_utils.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
|
||||
|
||||
|
||||
@triton.jit
|
||||
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
src_ptr = tl.load(src_ptrs + pid)
|
||||
dst_ptr = tl.load(dst_ptrs + pid)
|
||||
size = tl.load(sizes + pid)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
for i in range(0, size, BLOCK_SIZE):
|
||||
mask = (i + offsets) < size
|
||||
|
||||
curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
|
||||
curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
|
||||
|
||||
data = tl.load(curr_src_ptr, mask=mask)
|
||||
tl.store(curr_dst_ptr, data, mask=mask)
|
||||
|
||||
|
||||
def batch_memcpy(src_ptrs, dst_ptrs, sizes):
|
||||
batch = src_ptrs.shape[0]
|
||||
assert dst_ptrs.shape[0] == batch
|
||||
assert sizes.shape[0] == batch
|
||||
|
||||
grid = (batch,)
|
||||
BLOCK_SIZE = 1024
|
||||
batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
|
||||
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
|
||||
mamba_group_ids: list[int] = []
|
||||
mamba_specs: list[MambaSpec] = []
|
||||
for i in range(len(kv_cache_config.kv_cache_groups)):
|
||||
kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
|
||||
if isinstance(kv_cache_spec, MambaSpec):
|
||||
mamba_group_ids.append(i)
|
||||
mamba_specs.append(kv_cache_spec)
|
||||
assert len(mamba_group_ids) > 0, "no mamba layers in the model"
|
||||
assert all(mamba_specs[0] == spec for spec in mamba_specs)
|
||||
return mamba_group_ids, mamba_specs[0]
|
||||
|
||||
|
||||
def collect_mamba_copy_meta(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
mamba_group_ids: list[int],
|
||||
src_block_idx: int,
|
||||
dest_block_idx: int,
|
||||
accept_token_bias: int,
|
||||
req_state: CachedRequestState,
|
||||
forward_context: dict[str, Any],
|
||||
):
|
||||
if src_block_idx == dest_block_idx and accept_token_bias == 0:
|
||||
return
|
||||
|
||||
for mamba_group_id in mamba_group_ids:
|
||||
block_ids = req_state.block_ids[mamba_group_id]
|
||||
dest_block_id = block_ids[dest_block_idx]
|
||||
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
|
||||
for layer_name in layer_names:
|
||||
attention = forward_context[layer_name]
|
||||
kv_caches: list[torch.Tensor] = attention.kv_cache[0]
|
||||
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
|
||||
copy_spec = state_copy_func(
|
||||
state, block_ids, src_block_idx, accept_token_bias + 1
|
||||
)
|
||||
|
||||
src_state_list.append(copy_spec.start_addr)
|
||||
dest_state_list.append(state[dest_block_id].data_ptr())
|
||||
num_elements_list.append(copy_spec.num_elements * state.element_size())
|
||||
|
||||
|
||||
def do_mamba_copy_block(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
):
|
||||
if len(src_state_list) == 0:
|
||||
return
|
||||
assert len(src_state_list) == len(dest_state_list)
|
||||
assert len(src_state_list) == len(num_elements_list)
|
||||
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
|
||||
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
|
||||
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
|
||||
|
||||
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
|
||||
|
||||
|
||||
def preprocess_mamba(
|
||||
scheduler_output: SchedulerOutput,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
cache_config: CacheConfig,
|
||||
mamba_state_idx: dict[str, int],
|
||||
input_batch: GPUInputBatch,
|
||||
requests: dict[str, CachedRequestState],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
):
|
||||
"""
|
||||
Copy the mamba state of previous step to the last
|
||||
(1 + num_speculative_blocks) block.
|
||||
"""
|
||||
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
num_speculative_blocks = mamba_spec.num_speculative_blocks
|
||||
# TODO(Chen): we need to optimize this function a lot
|
||||
assert cache_config.enable_prefix_caching
|
||||
block_size = mamba_spec.block_size
|
||||
finished_req_ids = scheduler_output.finished_req_ids
|
||||
preempted_req_ids = scheduler_output.preempted_req_ids or set()
|
||||
# We need to clear mamba_state_idx for resumed requests. When requests are
|
||||
# force-preempted (e.g., during reset_prefix_cache / KV cache flush),
|
||||
# they appear in resumed_req_ids without a corresponding entry in
|
||||
# preempted_req_ids, leaving stale mamba_state_idx entries that can
|
||||
# point to block indices beyond the new (smaller) block allocation.
|
||||
resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
|
||||
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
|
||||
mamba_state_idx.pop(req_id, None)
|
||||
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
prev_state_idx = mamba_state_idx.get(req_id)
|
||||
if prev_state_idx is None:
|
||||
# new / resumed request, no previous state
|
||||
# if num_computed_tokens is 0, prev_state_idx will be -1
|
||||
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
|
||||
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_blocks: int = (
|
||||
cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
|
||||
+ num_speculative_blocks
|
||||
)
|
||||
|
||||
# We always save the current running state at the last
|
||||
# (1 + num_speculative_blocks) block.
|
||||
# A corner case worth mention here: assume we have block_size = 4 and
|
||||
# num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
|
||||
# tokens [draft 1, draft 2]. Then we will have:
|
||||
# Block 0: [A, B, C, draft 1]
|
||||
# Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
|
||||
# Block 2: speculative block
|
||||
# Block 3: speculative block
|
||||
# And use block 1 to save the running state.
|
||||
curr_state_idx = num_blocks - 1 - num_speculative_blocks
|
||||
mamba_state_idx[req_id] = curr_state_idx
|
||||
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
prev_state_idx,
|
||||
curr_state_idx,
|
||||
input_batch.num_accepted_tokens_cpu[i] - 1,
|
||||
req_state,
|
||||
forward_context,
|
||||
)
|
||||
input_batch.num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
|
||||
|
||||
def postprocess_mamba(
|
||||
scheduler_output: SchedulerOutput,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
input_batch: GPUInputBatch,
|
||||
requests: dict[str, CachedRequestState],
|
||||
mamba_state_idx: dict[str, int],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
):
|
||||
"""
|
||||
If a blocks is converted from partial block to full block in this step, copy the
|
||||
state from the block for running state to the new full block.
|
||||
"""
|
||||
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
|
||||
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
|
||||
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
|
||||
# NOTE: can be optimized as this function always returns the same result
|
||||
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
|
||||
num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
|
||||
num_accepted_tokens = num_accepted_tokens_cpu[i]
|
||||
num_tokens_running_state = (
|
||||
num_computed_tokens + num_scheduled_tokens - num_draft_tokens
|
||||
)
|
||||
new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
|
||||
aligned_new_computed_tokens = (
|
||||
new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
|
||||
)
|
||||
# TODO: how to ensure all blocks that cache_blocks called are cached here?
|
||||
if aligned_new_computed_tokens >= num_tokens_running_state:
|
||||
accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
|
||||
src_block_idx = mamba_state_idx[req_id]
|
||||
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
src_block_idx,
|
||||
dest_block_idx,
|
||||
accept_token_bias,
|
||||
req_state,
|
||||
forward_context,
|
||||
)
|
||||
if src_block_idx == dest_block_idx:
|
||||
num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
574
vllm/v1/worker/tpu_input_batch.py
Normal file
574
vllm/v1/worker/tpu_input_batch.py
Normal file
@@ -0,0 +1,574 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Datastructures defining a TPU input batch
|
||||
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.collection_utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class InputBatch:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
kernel_block_sizes: list[int],
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[str | None] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs,),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float32, device=device
|
||||
)
|
||||
self.temperature_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
|
||||
self.top_p_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
|
||||
self.top_k_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
|
||||
self.min_p_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device=device
|
||||
)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
self.presence_penalties = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device=device
|
||||
)
|
||||
self.presence_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
self.repetition_penalties = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device=device
|
||||
)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
|
||||
)
|
||||
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[dict[int, float] | None] = [None] * max_num_reqs
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: torch.Tensor | None = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.req_output_token_ids: list[list[int] | None] = []
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: int | None = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds
|
||||
)
|
||||
# TODO: copy prompt_embeds
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None, "pooling requests not supported yet"
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Should avoid division by zero later when apply_temperature.
|
||||
self.temperature_cpu[req_index] = 0.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (
|
||||
sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu"
|
||||
)
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids
|
||||
] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> int | None:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
||||
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
||||
self.lora_id_to_request_ids.pop(lora_id)
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
|
||||
self.req_output_token_ids[i2],
|
||||
self.req_output_token_ids[i1],
|
||||
)
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
|
||||
self.req_id_to_index[old_id_i2],
|
||||
self.req_id_to_index[old_id_i1],
|
||||
)
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
|
||||
self.num_tokens_no_spec[i2],
|
||||
self.num_tokens_no_spec[i1],
|
||||
)
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
|
||||
self.num_prompt_tokens[i2],
|
||||
self.num_prompt_tokens[i1],
|
||||
)
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
|
||||
self.num_computed_tokens_cpu[i2],
|
||||
self.num_computed_tokens_cpu[i1],
|
||||
)
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] = (
|
||||
self.temperature_cpu[i2],
|
||||
self.temperature_cpu[i1],
|
||||
)
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
|
||||
self.frequency_penalties_cpu[i2],
|
||||
self.frequency_penalties_cpu[i1],
|
||||
)
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
|
||||
self.presence_penalties_cpu[i2],
|
||||
self.presence_penalties_cpu[i1],
|
||||
)
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
|
||||
self.repetition_penalties_cpu[i2],
|
||||
self.repetition_penalties_cpu[i1],
|
||||
)
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
|
||||
self.request_lora_mapping[i2],
|
||||
self.request_lora_mapping[i1],
|
||||
)
|
||||
self.logit_bias[i1], self.logit_bias[i2] = (
|
||||
self.logit_bias[i2],
|
||||
self.logit_bias[i1],
|
||||
)
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
(
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1],
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2],
|
||||
) = (
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2],
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1],
|
||||
)
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
"""Move non-empty requests down into lower, empty indices.
|
||||
|
||||
Args:
|
||||
empty_req_indices: empty batch indices, sorted descending.
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
# is sorted in descending order.
|
||||
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||
while empty_req_indices:
|
||||
# Find the largest non-empty index.
|
||||
while last_req_index in empty_req_indices:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = empty_req_indices.pop()
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
self._req_ids[empty_index] = req_id
|
||||
self._req_ids[last_req_index] = None
|
||||
self.req_output_token_ids[empty_index] = output_token_ids
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens_no_spec[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens
|
||||
]
|
||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||
last_req_index
|
||||
]
|
||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
|
||||
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
|
||||
last_req_index
|
||||
]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
|
||||
last_req_index
|
||||
]
|
||||
self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
|
||||
last_req_index
|
||||
]
|
||||
self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
|
||||
last_req_index
|
||||
]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index
|
||||
]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
|
||||
self.allowed_token_ids_mask_cpu_tensor[last_req_index]
|
||||
)
|
||||
|
||||
bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs :]
|
||||
del self.req_output_token_ids[self.num_reqs :]
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(self.num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
"""
|
||||
Given the num_scheduled_tokens for each request in the batch, return
|
||||
datastructures used to activate the current LoRAs.
|
||||
Returns:
|
||||
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
||||
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
||||
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
||||
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
||||
3. lora_requests: Set of relevant LoRA requests.
|
||||
"""
|
||||
|
||||
req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
|
||||
prompt_lora_mapping = tuple(req_lora_mapping)
|
||||
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
|
||||
active_lora_requests: set[LoRARequest] = set(
|
||||
self.lora_id_to_lora_request.values()
|
||||
)
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
@property
|
||||
def all_greedy(self) -> bool:
|
||||
return len(self.random_reqs) == 0
|
||||
|
||||
@property
|
||||
def all_random(self) -> bool:
|
||||
return len(self.greedy_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_p(self) -> bool:
|
||||
return len(self.top_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (
|
||||
len(self.presence_penalties_reqs) == 0
|
||||
and len(self.frequency_penalties_reqs) == 0
|
||||
and len(self.repetition_penalties_reqs) == 0
|
||||
)
|
||||
|
||||
@property
|
||||
def max_num_logprobs(self) -> int | None:
|
||||
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
||||
|
||||
@property
|
||||
def no_allowed_token_ids(self) -> bool:
|
||||
return len(self.has_allowed_token_ids) == 0
|
||||
243
vllm/v1/worker/ubatch_utils.py
Normal file
243
vllm/v1/worker/ubatch_utils.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class UBatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (
|
||||
self.request_slice.start == self.request_slice.stop
|
||||
or self.token_slice.start == self.token_slice.stop
|
||||
)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.token_slice.stop - self.token_slice.start
|
||||
|
||||
|
||||
UBatchSlices: TypeAlias = list[UBatchSlice]
|
||||
|
||||
|
||||
def is_last_ubatch_empty(
|
||||
orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
|
||||
) -> bool:
|
||||
return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens
|
||||
|
||||
|
||||
def check_ubatch_thresholds(
|
||||
config: ParallelConfig, num_tokens: int, uniform_decode: bool
|
||||
) -> bool:
|
||||
if not config.use_ubatching:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
# This pads the last ubatch slice out to the total number of tokens
|
||||
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
|
||||
def _pad_out_ubatch_slices(
|
||||
ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
|
||||
) -> UBatchSlices:
|
||||
last_slice = ubatch_slices[-1]
|
||||
padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
|
||||
padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
|
||||
|
||||
return ubatch_slices[:-1] + [
|
||||
UBatchSlice(padded_last_request_slice, padded_last_token_slice)
|
||||
]
|
||||
|
||||
|
||||
def maybe_create_ubatch_slices(
|
||||
should_ubatch: bool,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_tokens_padded: int,
|
||||
num_reqs_padded: int,
|
||||
num_ubatches: int,
|
||||
split_point: list[int] | int | None = None,
|
||||
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
|
||||
if not should_ubatch:
|
||||
return None, None
|
||||
|
||||
if split_point is None:
|
||||
split_point = int(num_tokens_padded) // num_ubatches
|
||||
|
||||
token_split_points = [split_point * i for i in range(1, num_ubatches)]
|
||||
|
||||
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
|
||||
# in cu_num_tokens directly (i.e. query_start_loc)
|
||||
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
|
||||
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
|
||||
|
||||
ubatch_slices = []
|
||||
start_token = 0
|
||||
|
||||
# Add the end point to the split points to make iteration easier
|
||||
all_points = token_split_points + [cu_num_tokens[-1]]
|
||||
|
||||
for end_token in all_points:
|
||||
token_slice = slice(start_token, end_token)
|
||||
|
||||
# Determine request slices using exclusive stop semantics
|
||||
# Ubatch includes requests whose tokens overlap [start_token, end_token)
|
||||
|
||||
# Start at the request that contains the start_token
|
||||
# or the request starting exactly at start_token (if on boundary)
|
||||
req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
|
||||
|
||||
# Stop at the request that starts at or after end_token
|
||||
req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
|
||||
|
||||
req_slice = slice(req_start, req_stop)
|
||||
ubatch_slices.append(UBatchSlice(req_slice, token_slice))
|
||||
|
||||
start_token = end_token
|
||||
|
||||
ubatch_slices_padded = _pad_out_ubatch_slices(
|
||||
ubatch_slices, num_tokens_padded, num_reqs_padded
|
||||
)
|
||||
|
||||
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
|
||||
|
||||
return ubatch_slices, ubatch_slices_padded
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Creates a new query_start_loc that corresponds to the requests in
|
||||
request_slice.
|
||||
|
||||
Note: This function creates a new tensor to hold the new query_start_locs.
|
||||
This will break cudagraph compatibility.
|
||||
"""
|
||||
return (
|
||||
query_start_loc[request_slice.start : request_slice.stop + 1]
|
||||
- query_start_loc[request_slice.start]
|
||||
)
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
|
||||
) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
start_locs = attn_metadata.query_start_loc_cpu
|
||||
first_req = request_slice.start
|
||||
first_tok = token_slice.start
|
||||
last_req = request_slice.stop - 1
|
||||
last_tok = token_slice.stop - 1
|
||||
|
||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
|
||||
"Token slice start outside of first request"
|
||||
)
|
||||
# NOTE: last token can be outside of the last request if we have CG padding.
|
||||
|
||||
# If the request is split across ubatches, we have to adjust the metadata.
|
||||
# splits_first_request: The first request in this slice is the continuation of
|
||||
# a request that started in a previous slice.
|
||||
# splits_last_request: The last request in this slice continues into the
|
||||
# next slice.
|
||||
splits_first_request = first_tok > start_locs[first_req]
|
||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||
|
||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||
query_start_loc = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc, request_slice
|
||||
)
|
||||
|
||||
assert len(query_start_loc) >= 2, (
|
||||
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
|
||||
)
|
||||
|
||||
if splits_first_request:
|
||||
tokens_skipped = first_tok - start_locs[first_req]
|
||||
query_start_loc[1:] -= tokens_skipped
|
||||
query_start_loc_cpu[1:] -= tokens_skipped
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
|
||||
if splits_last_request:
|
||||
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
|
||||
# the tokens skipped because query_start_loc_cpu might have been modified
|
||||
# if splits_first_request is True.
|
||||
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
|
||||
query_start_loc[-1] -= tokens_skipped
|
||||
query_start_loc_cpu[-1] -= tokens_skipped
|
||||
|
||||
# Make sure we don't modify the seq_lens tensors
|
||||
# (not cudagraph compatible)
|
||||
seq_lens = seq_lens.clone()
|
||||
seq_lens_cpu = seq_lens_cpu.clone()
|
||||
seq_lens[-1] -= tokens_skipped
|
||||
seq_lens_cpu[-1] -= tokens_skipped
|
||||
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
|
||||
|
||||
num_requests = request_slice.stop - request_slice.start
|
||||
num_actual_tokens = token_slice.stop - token_slice.start
|
||||
max_query_len = int(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
|
||||
)
|
||||
|
||||
# This is to account for the case where we are in a dummy
|
||||
# run and query_start_loc_cpu is full of 0s
|
||||
if max_query_len == 0:
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_requests,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
)
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UBatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UBatchSlice in ubatch_slices.
|
||||
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
results = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
|
||||
return results
|
||||
241
vllm/v1/worker/ubatching.py
Normal file
241
vllm/v1/worker/ubatching.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_THREAD_ID_TO_CONTEXT: dict = {}
|
||||
# Here we hardcode the number of microbatches to 2 for default.
|
||||
_NUM_UBATCHES: int = 2
|
||||
_CURRENT_CONTEXTS: list["UBatchContext | None"] = []
|
||||
|
||||
|
||||
class UBatchContext:
|
||||
"""
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
forward_context: ForwardContext,
|
||||
ready_barrier: threading.Barrier,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.Event,
|
||||
gpu_compute_done_event: torch.Event,
|
||||
schedule: str = "default",
|
||||
):
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
self.forward_context = forward_context
|
||||
self.ready_barrier = ready_barrier
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
self.current_stream = compute_stream
|
||||
self.gpu_comm_done_event = gpu_comm_done_event
|
||||
self.gpu_compute_done_event = gpu_compute_done_event
|
||||
self.schedule = schedule
|
||||
self.recv_hook = None
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
|
||||
_CURRENT_CONTEXTS[self.id] = self
|
||||
# _NUM_UBATCHES is set in make_ubatch_contexts
|
||||
self.ready_barrier.wait()
|
||||
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
# Assume we want to start on the compute stream
|
||||
self.update_stream(self.compute_stream)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_CURRENT_CONTEXTS[self.id] = None
|
||||
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
self.maybe_run_recv_hook()
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def update_stream(self, stream):
|
||||
self.current_stream = stream
|
||||
if current_stream() != self.current_stream:
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def _signal_comm_done(self):
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
|
||||
def _cpu_yield(self):
|
||||
# It is critical for correctness that only one thread is running
|
||||
# at a time. These asserts just make sure that this is the only
|
||||
# thread running before waking the other one up and going to sleep
|
||||
assert forward_context._forward_context == self.forward_context
|
||||
assert current_stream() == self.current_stream
|
||||
assert not self.cpu_wait_event.is_set()
|
||||
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
def switch_to_comm(self):
|
||||
self.update_stream(self.comm_stream)
|
||||
|
||||
def switch_to_compute(self):
|
||||
self.update_stream(self.compute_stream)
|
||||
|
||||
def switch_to_comm_sync(self):
|
||||
self._signal_compute_done()
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def switch_to_compute_sync(self):
|
||||
self._signal_comm_done()
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
def maybe_run_recv_hook(self):
|
||||
if self.recv_hook is not None:
|
||||
self.recv_hook()
|
||||
self.recv_hook = None
|
||||
|
||||
def yield_(self):
|
||||
self.current_stream = current_stream()
|
||||
self._cpu_yield()
|
||||
self.update_stream(self.current_stream)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.compute_stream
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert current_stream() == self.comm_stream
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
def dbo_enabled() -> bool:
|
||||
return len(_THREAD_ID_TO_CONTEXT) > 0
|
||||
|
||||
|
||||
def dbo_current_ubatch_id() -> int:
|
||||
if len(_THREAD_ID_TO_CONTEXT) == 0:
|
||||
return 0
|
||||
return _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
|
||||
|
||||
def _register_ubatch_function(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(_THREAD_ID_TO_CONTEXT) > 0:
|
||||
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
ctx = _CURRENT_CONTEXTS[ctx_idx]
|
||||
func(ctx, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook)
|
||||
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
|
||||
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_compute_to_comm
|
||||
)
|
||||
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_comm_to_compute
|
||||
)
|
||||
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
|
||||
dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute)
|
||||
dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync)
|
||||
dbo_switch_to_compute_sync = _register_ubatch_function(
|
||||
UBatchContext.switch_to_compute_sync
|
||||
)
|
||||
|
||||
|
||||
def dbo_register_recv_hook(recv_hook):
|
||||
if len(_THREAD_ID_TO_CONTEXT) > 0:
|
||||
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
|
||||
next_ctx.recv_hook = recv_hook
|
||||
|
||||
|
||||
def dbo_get_previous_event(func, *args, **kwargs):
|
||||
if len(_THREAD_ID_TO_CONTEXT) > 0:
|
||||
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
ctx = _CURRENT_CONTEXTS[ctx_idx]
|
||||
# execute callable on the ubatch compute stream to record/wait events there
|
||||
with torch.cuda.stream(ctx.compute_stream):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
forward_contexts: list[ForwardContext],
|
||||
ready_barrier: threading.Barrier,
|
||||
schedule: str = "default",
|
||||
) -> list[UBatchContext]:
|
||||
global _NUM_UBATCHES, _CURRENT_CONTEXTS
|
||||
assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
|
||||
|
||||
_NUM_UBATCHES = num_micro_batches
|
||||
# Ensure the global context list is large enough
|
||||
if len(_CURRENT_CONTEXTS) < num_micro_batches:
|
||||
_CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
|
||||
|
||||
"""
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
|
||||
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
|
||||
|
||||
ctxs = []
|
||||
for i in range(num_micro_batches):
|
||||
ctx = UBatchContext(
|
||||
id=i,
|
||||
compute_stream=compute_stream,
|
||||
comm_stream=comm_stream,
|
||||
forward_context=forward_contexts[i],
|
||||
ready_barrier=ready_barrier,
|
||||
cpu_wait_event=cpu_events[i],
|
||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||
schedule=schedule,
|
||||
)
|
||||
ctxs.append(ctx)
|
||||
|
||||
return ctxs
|
||||
239
vllm/v1/worker/utils.py
Normal file
239
vllm/v1/worker/utils.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
kv_cache_group_id: int
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistent buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder] = field(
|
||||
default_factory=lambda: []
|
||||
)
|
||||
|
||||
def create_metadata_builders(
|
||||
self,
|
||||
vllm_config,
|
||||
device,
|
||||
kernel_block_size: int | None,
|
||||
num_metadata_builders: int = 1,
|
||||
):
|
||||
kv_cache_spec_builder = (
|
||||
self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
|
||||
if kernel_block_size is not None
|
||||
else self.kv_cache_spec
|
||||
)
|
||||
self.metadata_builders = [
|
||||
self.backend.get_builder_cls()(
|
||||
kv_cache_spec_builder,
|
||||
self.layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
|
||||
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
expected_num_items: int,
|
||||
) -> None:
|
||||
"""
|
||||
Perform sanity checks for the result of
|
||||
[`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
|
||||
"""
|
||||
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
|
||||
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
|
||||
f"or a single 3D tensor, but got {type(mm_embeddings)} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `embed_multimodal` method."
|
||||
)
|
||||
|
||||
assert len(mm_embeddings) == expected_num_items, (
|
||||
"Expected number of multimodal embeddings to match number of "
|
||||
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `embed_multimodal` method."
|
||||
)
|
||||
|
||||
assert all(e.ndim == 2 for e in mm_embeddings), (
|
||||
"Expected multimodal embeddings to be a sequence of 2D tensors, "
|
||||
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `embed_multimodal` method."
|
||||
)
|
||||
|
||||
|
||||
def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> int:
|
||||
"""
|
||||
Calculate the amount of memory required by vLLM, then validate
|
||||
that the current amount of free memory is sufficient for that.
|
||||
"""
|
||||
requested_memory = math.ceil(
|
||||
init_snapshot.total_memory * cache_config.gpu_memory_utilization
|
||||
)
|
||||
|
||||
if init_snapshot.free_memory < requested_memory:
|
||||
raise ValueError(
|
||||
f"Free memory on device {init_snapshot.device_} "
|
||||
f"({format_gib(init_snapshot.free_memory)}/"
|
||||
f"{format_gib(init_snapshot.total_memory)} GiB) on startup "
|
||||
f"is less than desired GPU memory utilization "
|
||||
f"({cache_config.gpu_memory_utilization}, "
|
||||
f"{format_gib(requested_memory)} GiB). Decrease GPU memory "
|
||||
f"utilization or reduce GPU memory used by other processes."
|
||||
)
|
||||
|
||||
return requested_memory
|
||||
|
||||
|
||||
def add_kv_sharing_layers_to_kv_cache_groups(
|
||||
shared_kv_cache_layers: dict[str, str],
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
runner_only_attn_layers: set[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
|
||||
for layers that do not allocate its own KV cache, based on the mapping in
|
||||
`shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
|
||||
group, which is needed to ensure that attention metadata is assigned later.
|
||||
|
||||
Args:
|
||||
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
|
||||
If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
means this layer will perform attention using the keys and values
|
||||
from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
kv_cache_groups: The KV cache groups of the model.
|
||||
"""
|
||||
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
|
||||
for kv_cache_group in kv_cache_groups:
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
layer_to_kv_cache_group[layer_name] = kv_cache_group
|
||||
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
|
||||
tgt_kv_cache_group.layer_names.append(layer_name)
|
||||
|
||||
if runner_only_attn_layers is not None:
|
||||
runner_only_attn_layers.add(layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, Attention],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
num_attn_module: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
|
||||
# TODO - analyze where runner_kv_caches is used and the right
|
||||
# way to ensure it properly reflects multiple attention layers
|
||||
# in the same decoder block.
|
||||
if (
|
||||
current_platform.is_cuda_alike()
|
||||
or current_platform.is_xpu()
|
||||
or current_platform.is_cpu()
|
||||
):
|
||||
# We know that the GPU / CPU runner is not impacted by this
|
||||
# case. Some test code depends on runner_kv_caches, but
|
||||
# not in a way that's impacted by ignoring this.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
for layer_name in layer_names:
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
def is_residual_scattered_for_sp(
|
||||
vllm_config: VllmConfig, num_input_tokens: int
|
||||
) -> bool:
|
||||
"""Check if the residual tensor is scattered for sequence parallelism.
|
||||
|
||||
The residual tensor is scattered across tensor parallel ranks when sequence
|
||||
parallelism and tensor parallelism is enabled.
|
||||
|
||||
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
|
||||
- In full-graph compilation mode (no splitting ops or using inductor graph
|
||||
partition), SP is always applied
|
||||
- Otherwise, SP is only applied for specific shapes in compile_sizes
|
||||
"""
|
||||
if not vllm_config.compilation_config.pass_config.enable_sp:
|
||||
return False
|
||||
|
||||
tp = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
if tp == 1:
|
||||
return False
|
||||
|
||||
# When sequence parallelism is enabled, we always pad num_input_tokens
|
||||
# to be a multiple of tensor_parallel_size (tp) earlier.
|
||||
assert num_input_tokens % tp == 0
|
||||
|
||||
if (
|
||||
not vllm_config.compilation_config.splitting_ops
|
||||
or vllm_config.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
compile_sizes = vllm_config.compilation_config.compile_sizes
|
||||
if compile_sizes is None:
|
||||
return False
|
||||
return num_input_tokens in compile_sizes
|
||||
373
vllm/v1/worker/worker_base.py
Normal file
373
vllm/v1/worker/worker_base.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.tracing import instrument
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.serial_utils import run_method
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
|
||||
else:
|
||||
SchedulerOutput = object
|
||||
GrammarOutput = object
|
||||
AsyncModelRunnerOutput = object
|
||||
ModelRunnerOutput = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
class WorkerBase:
|
||||
"""Worker interface that allows vLLM to cleanly separate implementations for
|
||||
different hardware. Also abstracts control plane communication, e.g., to
|
||||
communicate request metadata to other workers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize common worker components.
|
||||
|
||||
Args:
|
||||
vllm_config: Complete vLLM configuration
|
||||
local_rank: Local device index
|
||||
rank: Global rank in distributed setup
|
||||
distributed_init_method: Distributed initialization method
|
||||
is_driver_worker: Whether this worker handles driver
|
||||
responsibilities
|
||||
"""
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.current_platform = current_platform
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# Device and model state
|
||||
self.device: torch.device | None = None
|
||||
self.model_runner: nn.Module | None = None
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
"""Prepare model for execution through compilation/warmup."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Basic health check (override for device-specific checks)."""
|
||||
return
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize device state, such as loading the model or other on-device
|
||||
memory allocations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache with the given size in blocks."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
|
||||
if callable(reset_fn):
|
||||
reset_fn()
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
||||
"""Apply a function on the model inside this worker."""
|
||||
return fn(self.get_model())
|
||||
|
||||
def get_model_inspection(self) -> str:
|
||||
"""Return a transformers-style hierarchical view of the model."""
|
||||
from vllm.model_inspection import format_model_inspection
|
||||
|
||||
return format_model_inspection(self.get_model())
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load model onto target device."""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_model(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
"""If this method returns None, sample_tokens should be called immediately after
|
||||
to obtain the ModelRunnerOutput.
|
||||
|
||||
Note that this design may be changed in future if/when structured outputs
|
||||
parallelism is re-architected.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
|
||||
"""Should be called immediately after execute_model iff it returned None."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Return the size of a single cache block, in bytes. Used in
|
||||
speculative decoding.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Get vocabulary size from model configuration."""
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Clean up resources held by the worker."""
|
||||
return
|
||||
|
||||
|
||||
class WorkerWrapperBase:
|
||||
"""
|
||||
This class represents one process in an executor/engine. It is responsible
|
||||
for lazily initializing the worker and handling the worker's lifecycle.
|
||||
We first instantiate the WorkerWrapper, which remembers the worker module
|
||||
and class name. Then, when we call `update_environment_variables`, and the
|
||||
real initialization happens in `init_worker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rpc_rank: int = 0,
|
||||
global_rank: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker wrapper with the given vllm_config and rpc_rank.
|
||||
Note: rpc_rank is the rank of the worker in the executor. In most cases,
|
||||
it is also the rank of the worker in the distributed group. However,
|
||||
when multiple executors work together, they can be different.
|
||||
e.g. in the case of SPMD-style offline inference with TP=2,
|
||||
users can launch 2 engines/executors, each with only 1 worker.
|
||||
All workers have rpc_rank=0, but they have different ranks in the TP
|
||||
group.
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.global_rank = self.rpc_rank if global_rank is None else global_rank
|
||||
|
||||
# Initialized after init_worker is called
|
||||
self.worker: WorkerBase
|
||||
self.vllm_config: VllmConfig
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.worker is not None:
|
||||
self.worker.shutdown()
|
||||
|
||||
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rpc_rank based on the given mapping.
|
||||
It is only used during the initialization of the executor,
|
||||
to adjust the rpc_rank of workers after we create all workers.
|
||||
"""
|
||||
# if self.rpc_rank in rank_mapping:
|
||||
# self.rpc_rank = rank_mapping[self.rpc_rank]
|
||||
old_rank = self.rpc_rank
|
||||
if old_rank in rank_mapping:
|
||||
self.rpc_rank = rank_mapping[old_rank]
|
||||
if self.global_rank == old_rank:
|
||||
self.global_rank = rank_mapping[old_rank]
|
||||
|
||||
def update_environment_variables(
|
||||
self,
|
||||
envs_list: list[dict[str, str]],
|
||||
) -> None:
|
||||
envs = envs_list[self.rpc_rank]
|
||||
update_environment_variables(envs)
|
||||
|
||||
@instrument(span_name="Worker init")
|
||||
def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
|
||||
"""
|
||||
Here we inject some common logic before initializing the worker.
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
kwargs = all_kwargs[self.rpc_rank]
|
||||
|
||||
vllm_config: VllmConfig | None = kwargs.get("vllm_config")
|
||||
assert vllm_config is not None, (
|
||||
"vllm_config is required to initialize the worker"
|
||||
)
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
vllm_config.enable_trace_function_call_for_thread()
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
load_general_plugins()
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if isinstance(parallel_config.worker_cls, str):
|
||||
worker_class: type[WorkerBase] = resolve_obj_by_qualname(
|
||||
parallel_config.worker_cls
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"passing worker_cls is no longer supported. "
|
||||
"Please pass keep the class in a separate module "
|
||||
"and pass the qualified name of the class as a string."
|
||||
)
|
||||
|
||||
if parallel_config.worker_extension_cls:
|
||||
worker_extension_cls = resolve_obj_by_qualname(
|
||||
parallel_config.worker_extension_cls
|
||||
)
|
||||
extended_calls = []
|
||||
if worker_extension_cls not in worker_class.__bases__:
|
||||
# check any conflicts between worker and worker_extension_cls
|
||||
for attr in dir(worker_extension_cls):
|
||||
if attr.startswith("__"):
|
||||
continue
|
||||
assert not hasattr(worker_class, attr), (
|
||||
f"Worker class {worker_class} already has an attribute"
|
||||
f" {attr}, which conflicts with the worker"
|
||||
f" extension class {worker_extension_cls}."
|
||||
)
|
||||
if callable(getattr(worker_extension_cls, attr)):
|
||||
extended_calls.append(attr)
|
||||
# dynamically inherit the worker extension class
|
||||
worker_class.__bases__ = worker_class.__bases__ + (
|
||||
worker_extension_cls,
|
||||
)
|
||||
logger.info(
|
||||
"Injected %s into %s for extended collective_rpc calls %s",
|
||||
worker_extension_cls,
|
||||
worker_class,
|
||||
extended_calls,
|
||||
)
|
||||
|
||||
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
|
||||
if shared_worker_lock is None:
|
||||
msg = (
|
||||
"Missing `shared_worker_lock` argument from executor. "
|
||||
"This argument is needed for mm_processor_cache_type='shm'."
|
||||
)
|
||||
|
||||
mm_config = vllm_config.model_config.multimodal_config
|
||||
if mm_config and mm_config.mm_processor_cache_type == "shm":
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
logger.warning_once(msg)
|
||||
|
||||
self.mm_receiver_cache = None
|
||||
else:
|
||||
self.mm_receiver_cache = (
|
||||
MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
|
||||
vllm_config,
|
||||
shared_worker_lock,
|
||||
)
|
||||
)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
self.worker = worker_class(**kwargs)
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.global_rank]
|
||||
assert self.vllm_config is not None
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
|
||||
def init_device(self):
|
||||
assert self.vllm_config is not None
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during device initialization
|
||||
self.worker.init_device() # type: ignore
|
||||
|
||||
def execute_method(self, method: str | bytes, *args, **kwargs):
|
||||
try:
|
||||
# method resolution order:
|
||||
# if a method is defined in this class, it will be called directly.
|
||||
# otherwise, since we define `__getattr__` and redirect attribute
|
||||
# query to `self.worker`, the method will be called on the worker.
|
||||
return run_method(self, method, args, kwargs)
|
||||
except Exception as e:
|
||||
# if the driver worker also execute methods,
|
||||
# exceptions in the rest worker may cause deadlock in rpc like ray
|
||||
# see https://github.com/vllm-project/vllm/issues/3455
|
||||
# print the error and inform the user to solve the error
|
||||
msg = (
|
||||
f"Error executing method {method!r}. "
|
||||
"This might cause deadlock in distributed execution."
|
||||
)
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
|
||||
def __getattr__(self, attr: str):
|
||||
return getattr(self.worker, attr)
|
||||
|
||||
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
|
||||
mm_cache = self.mm_receiver_cache
|
||||
if mm_cache is None:
|
||||
return
|
||||
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_data.mm_features = mm_cache.get_and_update_features(
|
||||
req_data.mm_features
|
||||
)
|
||||
|
||||
def execute_model(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
self._apply_mm_cache(scheduler_output)
|
||||
|
||||
return self.worker.execute_model(scheduler_output)
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
mm_receiver_cache = self.mm_receiver_cache
|
||||
if mm_receiver_cache is not None:
|
||||
mm_receiver_cache.clear_cache()
|
||||
|
||||
self.worker.reset_mm_cache()
|
||||
252
vllm/v1/worker/workspace.py
Normal file
252
vllm/v1/worker/workspace.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from itertools import accumulate
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
|
||||
return prod(shape) * dtype.itemsize
|
||||
|
||||
|
||||
# Constants
|
||||
_MB = 1024**2
|
||||
_GiB = 1024**3
|
||||
|
||||
# Global workspace manager instance
|
||||
_manager: "WorkspaceManager | None" = None
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
"""Manager for workspace allocation.
|
||||
|
||||
Manages workspace buffers for DBO (Dual Batch Overlap) execution.
|
||||
Can be locked to prevent further growth during execution.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device, num_ubatches: int | None = None):
|
||||
self._device = device
|
||||
# Cache num ubatches at init based on configuration (default to 1)
|
||||
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
|
||||
self._current_workspaces: list[torch.Tensor | None] = [None, None]
|
||||
self._locked: bool = False
|
||||
|
||||
@staticmethod
|
||||
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
|
||||
"""Get size of workspace in bytes."""
|
||||
if workspace is None:
|
||||
return 0
|
||||
return workspace.numel() * workspace.element_size()
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock the workspace to prevent further growth.
|
||||
|
||||
After locking, any attempt to allocate a larger workspace will raise
|
||||
an assertion error. This ensures workspace size is fixed during execution.
|
||||
"""
|
||||
self._locked = True
|
||||
if envs.VLLM_DEBUG_WORKSPACE:
|
||||
logger.info(
|
||||
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
|
||||
[
|
||||
self._workspace_size_bytes(ws) / _MB
|
||||
for ws in self._current_workspaces
|
||||
if ws is not None
|
||||
],
|
||||
)
|
||||
|
||||
def is_locked(self) -> bool:
|
||||
"""Check if workspace is locked."""
|
||||
return self._locked
|
||||
|
||||
def get_simultaneous(
|
||||
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
|
||||
) -> list[torch.Tensor]:
|
||||
"""Get multiple workspace tensors simultaneously from a single allocation.
|
||||
|
||||
Args:
|
||||
*shapes_and_dtypes: One or more (shape, dtype) tuples.
|
||||
|
||||
Returns:
|
||||
List of tensor views into the workspace buffer, one per shape/dtype pair.
|
||||
"""
|
||||
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
|
||||
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
|
||||
total_bytes = sum(aligned_bytes)
|
||||
|
||||
# Calculate cumulative offsets using itertools.accumulate
|
||||
offsets = list(accumulate([0] + aligned_bytes[:-1]))
|
||||
|
||||
current_workspace = self._ensure_workspace_size(total_bytes)
|
||||
|
||||
return [
|
||||
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
|
||||
.view(shapes_and_dtypes[i][1])
|
||||
.reshape(shapes_and_dtypes[i][0])
|
||||
for i in range(len(shapes_and_dtypes))
|
||||
]
|
||||
|
||||
def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
|
||||
"""Ensure workspace is allocated and large enough, return current workspace.
|
||||
|
||||
Args:
|
||||
required_bytes: The number of bytes required.
|
||||
|
||||
Returns:
|
||||
The current workspace tensor.
|
||||
"""
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
current_workspace = self._current_workspaces[ubatch_id]
|
||||
current_size = self._workspace_size_bytes(current_workspace)
|
||||
|
||||
if current_size < required_bytes:
|
||||
|
||||
def get_caller_info() -> str:
|
||||
"""Find first frame outside WorkspaceManager."""
|
||||
curr_frame = inspect.currentframe()
|
||||
if curr_frame is None:
|
||||
return "unknown"
|
||||
# Walk up the stack skipping WorkspaceManager frames
|
||||
curr_frame = curr_frame.f_back
|
||||
while curr_frame is not None:
|
||||
# TODO: This only catches instance methods (self), missing
|
||||
# classmethods and staticmethods. Once Python 3.11+ is the
|
||||
# minimum supported version, use co_qualname instead:
|
||||
# qualname = curr_frame.f_code.co_qualname
|
||||
# if qualname.startswith("WorkspaceManager."):
|
||||
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
|
||||
curr_frame = curr_frame.f_back
|
||||
continue
|
||||
filename = os.path.basename(curr_frame.f_code.co_filename)
|
||||
return (
|
||||
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
|
||||
)
|
||||
return "unknown"
|
||||
|
||||
if self._locked:
|
||||
raise AssertionError(
|
||||
f"Workspace is locked but allocation from '{get_caller_info()}' "
|
||||
f"requires {required_bytes / _MB:.2f} MB, current size is "
|
||||
f"{current_size / _MB:.2f} MB. "
|
||||
"Workspace growth is not allowed after locking."
|
||||
)
|
||||
|
||||
for ubatch_id in range(self._num_ubatches):
|
||||
current_workspace = self._current_workspaces[ubatch_id]
|
||||
if (
|
||||
current_workspace is None
|
||||
or self._workspace_size_bytes(current_workspace) < required_bytes
|
||||
):
|
||||
# Delete old tensor before allocating new one to avoid
|
||||
# memory spike from resize_(). resize_() allocates new
|
||||
# memory before freeing old, which can cause OOM.
|
||||
# Must clear the list reference first since local var
|
||||
# is just a copy of the reference.
|
||||
self._current_workspaces[ubatch_id] = None
|
||||
del current_workspace
|
||||
self._current_workspaces[ubatch_id] = torch.empty(
|
||||
(required_bytes,), dtype=torch.uint8, device=self._device
|
||||
)
|
||||
|
||||
if envs.VLLM_DEBUG_WORKSPACE:
|
||||
logger.info(
|
||||
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
|
||||
"%.2f MB (%d ubatches, total memory %.2f MB)",
|
||||
get_caller_info(),
|
||||
current_size / _MB,
|
||||
required_bytes / _MB,
|
||||
self._num_ubatches,
|
||||
required_bytes * self._num_ubatches / _MB,
|
||||
)
|
||||
|
||||
current_workspace = self._current_workspaces[dbo_current_ubatch_id()]
|
||||
|
||||
return current_workspace
|
||||
|
||||
|
||||
def is_workspace_manager_initialized() -> bool:
|
||||
"""Check if workspace manager has been initialized.
|
||||
|
||||
Returns:
|
||||
True if workspace manager is initialized, False otherwise.
|
||||
"""
|
||||
return _manager is not None
|
||||
|
||||
|
||||
def current_workspace_manager() -> "WorkspaceManager":
|
||||
"""Get the current workspace manager instance.
|
||||
|
||||
Raises:
|
||||
AssertionError: If workspace manager has not been initialized.
|
||||
"""
|
||||
assert _manager is not None, (
|
||||
"WorkspaceManager not initialized. Call init_workspace_manager() "
|
||||
"with a device before using workspace functions."
|
||||
)
|
||||
return _manager
|
||||
|
||||
|
||||
def init_workspace_manager(
|
||||
device: torch.device, num_ubatches: int | None = None
|
||||
) -> None:
|
||||
"""Initialize the workspace manager with a device.
|
||||
|
||||
Must be called before using any workspace functions. Typically called
|
||||
from GPUModelRunner.__init__.
|
||||
|
||||
Args:
|
||||
device: The device to allocate workspace on.
|
||||
num_ubatches: Number of micro-batches. Defaults to 1.
|
||||
"""
|
||||
global _manager
|
||||
if _manager is not None:
|
||||
logger.warning(
|
||||
"WorkspaceManager already initialized on device %s, "
|
||||
"reinitializing on device %s",
|
||||
_manager._device,
|
||||
device,
|
||||
)
|
||||
_manager = WorkspaceManager(device, num_ubatches)
|
||||
|
||||
|
||||
def lock_workspace() -> None:
|
||||
"""Lock the workspace to prevent further growth.
|
||||
|
||||
After calling this function, any attempt to allocate a workspace larger
|
||||
than the current size will raise an AssertionError. This ensures that
|
||||
workspace size is fixed during execution and prevents unexpected memory
|
||||
allocations in the hot path.
|
||||
|
||||
Example:
|
||||
# During initialization
|
||||
init_workspace_manager(device)
|
||||
reserve_workspace(shape1, dtype1)
|
||||
reserve_workspace(shape2, dtype2)
|
||||
|
||||
# Lock after warmup/profiling
|
||||
lock_workspace()
|
||||
|
||||
# Now all get_workspace calls must fit in pre-allocated size
|
||||
"""
|
||||
current_workspace_manager().lock()
|
||||
|
||||
|
||||
def reset_workspace_manager() -> None:
|
||||
"""Reset the workspace manager to uninitialized state.
|
||||
|
||||
This is primarily intended for testing purposes to allow tests
|
||||
to reinitialize the workspace manager cleanly.
|
||||
"""
|
||||
global _manager
|
||||
_manager = None
|
||||
52
vllm/v1/worker/xpu_model_runner.py
Normal file
52
vllm/v1/worker/xpu_model_runner.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import supports_xpu_graph
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUModelRunner(GPUModelRunner):
|
||||
"""A model runner for XPU devices."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
# FIXME: To be verified.
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
torch.xpu.synchronize()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
try:
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
if supports_xpu_graph():
|
||||
torch.cuda.graph = torch.xpu.graph
|
||||
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
114
vllm/v1/worker/xpu_worker.py
Normal file
114
vllm/v1/worker/xpu_worker.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.wrapper import TorchProfilerWrapper
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
|
||||
|
||||
from .utils import request_memory
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUWorker(Worker):
|
||||
"""A XPU worker class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
|
||||
)
|
||||
device_config = self.device_config
|
||||
assert device_config.device_type == "xpu"
|
||||
assert current_platform.is_xpu()
|
||||
|
||||
# Torch profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU", "XPU"],
|
||||
)
|
||||
|
||||
def init_device(self):
|
||||
device = self.device_config.device
|
||||
if (
|
||||
isinstance(device, torch.device)
|
||||
and device.type == "xpu"
|
||||
and current_platform.is_xpu()
|
||||
):
|
||||
self.device = torch.device(f"xpu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
torch.xpu.empty_cache()
|
||||
self.init_gpu_memory = torch.xpu.get_device_properties(
|
||||
self.local_rank
|
||||
).total_memory
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
|
||||
ENV_LOCAL_WORLD_SIZE = os.getenv(
|
||||
"LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
|
||||
)
|
||||
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
|
||||
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
|
||||
os.environ["LOCAL_RANK"] = str(self.local_rank)
|
||||
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend,
|
||||
)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Now take memory snapshot after NCCL is initialized
|
||||
gc.collect()
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# take current memory snapshot
|
||||
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
|
||||
self.requested_memory = request_memory(init_snapshot, self.cache_config)
|
||||
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
|
||||
logger.debug(
|
||||
"worker requested memory: %sGiB", format_gib(self.requested_memory)
|
||||
)
|
||||
|
||||
# Initialize workspace manager
|
||||
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
|
||||
init_workspace_manager(self.device, num_ubatches)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner = XPUModelRunner( # type: ignore
|
||||
self.vllm_config, self.device
|
||||
)
|
||||
|
||||
if self.rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
Reference in New Issue
Block a user