init src 0.9.2

2026-01-09 15:09:53 +08:00
parent 0eb2c0a4b3
commit 41d98d4359
1438 changed files with 417605 additions and 683 deletions

0
vllm/worker/__init__.py Normal file

155
vllm/worker/cache_engine.py Normal file

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CacheEngine class for managing the KV cache."""
from typing import List
import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
get_dtype_size, is_pin_memory_available)
from vllm.attention.backends.tree_decoding_utils import move_cache
logger = init_logger(__name__)
class CacheEngine:
"""Manages the KV cache.
This class is responsible for initializing and managing the GPU and CPU KV
caches. It also provides methods for performing KV cache operations, such
as swapping and copying.
"""
def __init__(
self,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
) -> None:
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.device_config = device_config
self.head_size = model_config.get_head_size()
        # Models like Jamba have mixed-type layers, e.g. Mamba
self.num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(self.head_size,
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
model_config.is_attention_free,
use_mla=model_config.use_mla)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(
self.num_gpu_blocks, self.device_config.device_type)
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
# The allocation respects the backend-defined stride order to ensure
# the semantic remains consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order which could result in a non-contiguous tensor.
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
for i in kv_cache_stride_order)
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(
kv_cache_allocation_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device).permute(*kv_cache_stride_order)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache.append(layer_kv_cache)
return kv_cache
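    # Hypothetical illustration (not part of the original file): for a generic
    # KV cache shape (2, num_blocks, block_size, num_kv_heads, head_size) and a
    # backend stride order of (1, 0, 2, 3, 4), the allocation shape becomes
    # (num_blocks, 2, block_size, num_kv_heads, head_size); permuting back
    # yields a tensor with the generic shape whose underlying memory follows
    # the backend's preferred layout, hence possibly non-contiguous.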
def swap_in(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
def swap_out(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
def move_caches(self, kv_caches: List[torch.Tensor],
src_to_dsts: torch.Tensor) -> None:
move_cache(self.attn_backend,
kv_caches,
src_to_dsts,
self.cache_config.cache_dtype,
self.num_kv_heads,
self.head_size)
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
key_cache_entry = num_heads * head_size
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
value_cache_entry = key_cache_entry if not model_config.use_mla else 0
total = num_attention_layers * cache_config.block_size * \
(key_cache_entry + value_cache_entry)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
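As a back-of-the-envelope check of the arithmetic in get_cache_block_size, the sketch below uses made-up model dimensions (32 attention layers, 8 KV heads of size 128, block size 16, fp16 cache); none of these numbers come from this commit.

# Hypothetical numbers, only to illustrate CacheEngine.get_cache_block_size.
num_attention_layers = 32
block_size = 16
num_heads = 8
head_size = 128
dtype_size = 2                                   # bytes per fp16 element
key_cache_entry = num_heads * head_size          # 1024 elements per token
value_cache_entry = key_cache_entry              # would be 0 with use_mla=True
total = num_attention_layers * block_size * (key_cache_entry + value_cache_entry)
print(total * dtype_size)                        # 2097152 bytes = 2 MiB per block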

326
vllm/worker/cpu_enc_dec_model_runner.py Normal file

@@ -0,0 +1,326 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInputForCPU":
return cast(
EncoderDecoderModelInputForCPU,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str,
Any]) -> EncoderDecoderModelInputForCPU:
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
return dataclasses.replace(
model_input,
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
virtual_engine=virtual_engine,
)
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInputForCPU,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
        # Since we are not supporting chunked prefill, either the entire
        # batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
@torch.no_grad()
def execute_model(
self,
model_input: EncoderDecoderModelInputForCPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"encoder_input_ids":
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
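The cross-attention slot mapping assembled in _prepare_encoder_model_input_tensors uses the usual paged-attention addressing, block_number * block_size + block_offset; the stand-alone sketch below replays that loop with a made-up block table and sequence length.

# Stand-alone sketch of the slot-mapping loop above; all values hypothetical.
block_size = 16
cross_block_table = [7, 3, 12]        # physical block ids for one sequence
seq_len = 40                          # number of encoder tokens

cross_slot_mapping = []
for i in range(seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)
# Token 0 maps to slot 7 * 16 + 0 = 112; token 39 maps to 12 * 16 + 7 = 199.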

671
vllm/worker/cpu_model_runner.py Normal file

@@ -0,0 +1,671 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
TypeVar, Union)
import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
"""
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForCPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> TModelInputForCPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForCPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
class ModelInputData:
def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope
self.input_tokens: List[int] = []
self.input_positions: List[int] = []
self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = []
self.decode_block_tables: List[List[int]] = []
self.max_decode_seq_len: int = 0
self.num_prefills: int = 0
self.num_prefill_tokens: int = 0
self.num_decode_tokens: int = 0
self.slot_mapping: List[int] = []
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
self.input_mrope_positions: List[List[int]] = [[]
for _ in range(3)]
def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.enable_lora = self.runner.lora_config is not None
if self.runner.attn_backend is not None:
            # spec decode (e.g. Medusa) does not have an attn backend
attn_backend = self.runner.attn_backend
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope)
self.att_metadata_builder.prepare()
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def set_seq_group_list(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
self.seq_group_metadata_list = seq_group_metadata_list
def build(self) -> ModelInputForCPU:
self._build_input_data()
input_data = self.input_data
input_tokens = torch.tensor(input_data.input_tokens,
dtype=torch.long,
device="cpu")
input_positions = torch.tensor(
input_data.input_positions
if not any(input_data.input_mrope_positions) else
input_data.input_mrope_positions,
dtype=torch.long,
device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids,
dtype=torch.long,
device="cpu") \
if input_data.token_type_ids else None
# For multi-modal models
multi_modal_kwargs = None
if len(input_data.multi_modal_inputs_list) != 0:
multi_modal_kwargs = MultiModalKwargs.batch(
input_data.multi_modal_inputs_list)
attn_metadata = self.att_metadata_builder.build(
input_data.seq_lens, input_data.query_lens, -1, -1)
is_prompt = (self.seq_group_metadata_list[0].is_prompt
if self.seq_group_metadata_list else None)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(seq.lora_request
for seq in self.seq_group_metadata_list
if seq.lora_request is not None)
lora_mapping = self._prepare_lora_input(
self.seq_group_metadata_list, is_prompt)
return self.model_input_cls(input_tokens=input_tokens,
input_positions=input_positions,
token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
lora_mapping=lora_mapping,
lora_requests=lora_requests)
def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
for seq_id, seq_data in seq_group_metadata.seq_data.items():
if seq_group_metadata.is_prompt:
self._compute_prompt_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
if seq_group_metadata.multi_modal_data:
self._compute_multi_modal_input(
seq_group_metadata, seq_data)
else:
self._compute_decode_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
def _compute_decode_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute decode input tokens, positions, block table and slot mapping.
"""
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
tokens = seq_data.get_last_token_id()
token_positions = seq_len - 1
block_number = block_table[token_positions // block_size]
block_offset = token_positions % block_size
slot = block_number * block_size + block_offset
# For paged_attention kernel
if self.runner.sliding_window:
start_idx = max(0, seq_len - self.runner.sliding_window)
start_block = start_idx // block_size
start_idx = start_block * block_size
seq_len = seq_len - start_idx
block_table = block_table[start_block:]
# For MRotaryEmbedding
if seq_data.mrope_position_delta is not None:
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
data.input_mrope_positions[idx].extend( # type: ignore
next_pos[idx])
else:
data.input_positions.append(token_positions) # type: ignore
# Update fields
data.input_tokens.append(tokens)
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
data.num_decode_tokens += 1
data.slot_mapping.append(slot)
data.decode_block_tables.append(block_table)
data.query_lens.append(1)
data.seq_lens.append(seq_len)
def _compute_prompt_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute prompt input tokens, positions, block table and slot mapping.
"""
token_chunk_size = seq_group_metadata.token_chunk_size
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
# For prefix caching
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
if prefix_cache_block_num > 0:
prefix_cache_len = (prefix_cache_block_num *
self.runner.block_size)
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
context_len = prefix_cache_len
token_chunk_size = seq_len - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
context_len = seq_len - 1
token_chunk_size = 1
tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len)
token_types = seq_group_metadata.token_type_ids
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
for i, pos in enumerate(token_positions):
block_number = block_table[pos // block_size]
block_offset = pos % block_size
slot = block_number * block_size + block_offset
slot_mapping[i] = slot
data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input
data.input_positions.extend(token_positions)
if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else [])
# Update fields
data.input_tokens.extend(tokens)
data.num_prefills += 1
data.num_prefill_tokens += len(tokens)
data.query_lens.append(len(tokens))
data.prefill_block_tables.append(block_table)
data.seq_lens.append(seq_len)
def _compute_multi_modal_input(self,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData):
computed_len = seq_data.get_num_computed_tokens()
seq_len = self.input_data.seq_lens[-1]
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata, range(computed_len, seq_len))
if not mm_kwargs:
return
# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
"MROPE on CPU does not support chunked-prefill."
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
None)
assert (
image_grid_thw is not None or video_grid_thw is not None
or audio_feature_lengths is not None), (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=computed_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
seq_data.mrope_position_delta = mrope_position_delta
for i in range(3):
self.input_data.input_mrope_positions[ # type: ignore
i].extend(mrope_positions[i])
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map)
def _prepare_lora_input(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
is_prefill: bool) -> LoRAMapping:
index_mapping = []
prompt_mapping = []
for seq in seq_group_metadata_list:
lora_id = seq.lora_int_id
query_len = seq.token_chunk_size
index_mapping += [lora_id] * query_len
prompt_mapping += [lora_id] * (
query_len if seq.sampling_params
and seq.sampling_params.prompt_logprobs is not None else 1)
return LoRAMapping(index_mapping=tuple(index_mapping),
prompt_mapping=tuple(prompt_mapping),
is_prefill=is_prefill)
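    # Hypothetical illustration of _prepare_lora_input (not part of the
    # original file): with two sequence groups -- group A using LoRA id 3, a
    # 4-token chunk and prompt_logprobs requested, group B with no LoRA (id 0)
    # and a 2-token chunk -- the resulting mappings would be
    #   index_mapping  = (3, 3, 3, 3, 0, 0)   # one entry per scheduled token
    #   prompt_mapping = (3, 3, 3, 3, 0)      # per token for A, one entry for B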
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]
builder: ModelInputForCPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
*args,
**kwargs,
):
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.pin_memory = False
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
# Lazy initialization.
        self.model: nn.Module  # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForCPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
self.builder.prepare(finished_requests_ids)
self.builder.set_seq_group_list(seq_group_metadata_list)
return self.builder.build() # type: ignore
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine,
is_prompt=is_prompt)
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
previous_hidden_states: Optional[torch.Tensor] = None,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
model_executable = self.model
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**execute_model_kwargs,
**multimodal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
return [output]
def generate_proposals(self, *args, **kwargs):
return self.model.generate_proposals(*args, **kwargs)
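The prefix-caching branch in _compute_prompt_input_tokens recomputes only the part of a prompt that is not covered by already-cached blocks; the numeric walk-through below uses invented values to show the partial-hit case.

# Hypothetical walk-through of the prefix-cache adjustment above.
block_size = 16
seq_len = 50                    # prompt tokens scheduled for this sequence
context_len = 0                 # nothing computed yet
prefix_cache_block_num = 2      # two full blocks already cached

prefix_cache_len = prefix_cache_block_num * block_size    # 32 cached tokens
if context_len < prefix_cache_len < seq_len:
    # Partial hit: skip the cached prefix and compute tokens 32..49 only.
    context_len = prefix_cache_len
    token_chunk_size = seq_len - context_len               # 18 tokens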

125
vllm/worker/cpu_pooling_model_runner.py Normal file

@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
ModelInputForCPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUPoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForCPUWithPoolingMetadata:
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
virtual_engine=virtual_engine,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
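For reference, _prepare_pooling emits one (seq_ids, pooling_params) pair per sequence group and one prompt length per sequence; the schematic below shows the resulting PoolingMetadata for two groups, with placeholder strings standing in for the real vLLM objects.

# Illustrative contents only; the real code passes PoolingParams/SequenceData.
seq_groups = [
    ([0], "<PoolingParams for group 0>"),
    ([1, 2], "<PoolingParams for group 1>"),
]
prompt_lens = [13, 7, 7]        # one length per sequence, from model_input.seq_lens
seq_data = {0: "<SequenceData>", 1: "<SequenceData>", 2: "<SequenceData>"}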

457
vllm/worker/cpu_worker.py Normal file

@@ -0,0 +1,457 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import os
from importlib import util
from typing import List, Optional, Set, Tuple, Type
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.worker.cache_engine import CacheEngine
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class CPUCacheEngine:
"""Manages the KV cache for CPU backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
# for CPU backend, because we want to reuse KV cache management
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
model_config)
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
def _allocate_kv_cache(
self,
num_blocks: int,
) -> List[torch.Tensor]:
"""Allocates KV cache on CPU."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
def get_kv_cache_dtype(cache_config: CacheConfig,
model_config: ModelConfig):
if cache_config.cache_dtype == "auto":
return model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
return torch.float8_e5m2
else:
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_layers * (key_cache_block + value_cache_block)
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
dtype_size = torch.tensor([], dtype=dtype).element_size()
return dtype_size * total
class CPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[CPUModelRunner]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
self.local_omp_cpuid = "all"
if omp_cpuids == "auto":
self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
)
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.runner_type == "pooling":
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CPUCacheEngine]
# Initialize cpu_cache as pooling models don't initialize kv_caches
self.cpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu")
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size = self.get_cache_block_size_bytes()
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks = num_cpu_blocks
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert (num_cpu_blocks == 0
), f"{type(self)} does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks = num_gpu_blocks
self._validate_num_cpu_blocks(num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_cpu_blocks
self.cache_config.num_cpu_blocks = 0
# Initialize the cache.
self._init_cache_engine()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid.
"""
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.cpu_cache)
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
def execute_worker(
self,
worker_input: WorkerInput,
) -> None:
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine: int = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.
"""
return CPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"""Return CPUs id binding based on NUMA nodes.
"""
rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
libnuma_found = util.find_spec("numa") is not None
psutil_found = util.find_spec("psutil") is not None
if libnuma_found and psutil_found:
import psutil
from numa import info
cpu_count = psutil.cpu_count(logical=False)
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()
cpu_count_per_numa = cpu_count // numa_size
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
cpu_count_per_numa // 2)
            # Build the list of CPUs allowed on each NUMA node
node_to_cpus = []
for i in range(numa_size):
node_intersect = set(
info.node_to_cpus(i)).intersection(cpus_allow_list)
if bool(node_intersect):
node_to_cpus.append(list(node_intersect))
if world_size > len(node_to_cpus):
logger.error(
"Auto thread-binding failed due to "
"world size: %d is larger than "
"allowed NUMA nodes number: %d."
"Please try to bind threads manually.", world_size,
len(node_to_cpus))
else:
end = cpu_count_per_numa - num_of_reserved_cpu
rank_to_cpus_list = node_to_cpus[self.rank][:end]
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
logger.info("auto thread-binding list: %s", rank_to_cpus)
else:
logger.warning(
"Auto thread-binding is not supported due to "
"the lack of package numa and psutil,"
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.")
return rank_to_cpus
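A rough sizing illustration for determine_num_available_blocks (numbers invented): with 4 GiB of CPU KV cache space and the 2 MiB-per-block figure from the earlier cache_engine sketch, the worker reports every block as a "GPU" block so the existing scheduler can be reused unchanged.

# Hypothetical sizing example for determine_num_available_blocks.
cpu_kvcache_space_bytes = 4 * 1024**3       # assumed VLLM_CPU_KVCACHE_SPACE
cache_block_size = 2 * 1024**2              # assumed bytes per cache block
num_cpu_blocks = max(cpu_kvcache_space_bytes // cache_block_size, 0)   # 2048
num_gpu_blocks, num_cpu_blocks = num_cpu_blocks, 0    # reported as (2048, 0)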

555
vllm/worker/enc_dec_model_runner.py Normal file

@@ -0,0 +1,555 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from typing import Any, Dict, List, Optional, Tuple, Type, cast
import torch
import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (get_env_variable_attn_backend,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInput":
return cast(
EncoderDecoderModelInput,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
_model_input_cls: Type[EncoderDecoderModelInput] = (
EncoderDecoderModelInput)
_builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
'''
EncoderDecoderModelRunner constructor.
`lora_config` and `prompt_adapter_config` are
unused (since these features are not yet supported for encoder/decoder
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend()
super().__init__(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
input_registry=input_registry,
mm_registry=mm_registry,
)
        # Crash for unsupported encoder/decoder scenarios
assert_enc_dec_mr_supported_scenario(self)
def _maybe_force_supported_attention_backend(self):
'''
        Check any globally or environment-variable forced attention backend
        and raise an error if it is not one of the backends currently
        supported for encoder/decoder models (XFormers or FlashAttention).
'''
def raise_backend_err():
# The user has specified an attention backend override
# which is invalid for encoder/decoder models
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
maybe_env_var_forced_backend = get_env_variable_attn_backend()
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None
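        # For example (assuming the standard vLLM environment variable):
        # setting VLLM_ATTENTION_BACKEND=XFORMERS makes
        # get_env_variable_attn_backend() return _Backend.XFORMERS, which
        # passes the checks below; any other value (e.g. FLASHINFER) would
        # raise NotImplementedError here.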
if is_forced_by_global: # noqa: SIM102
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
elif is_forced_by_env_var: # noqa: SIM102
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
@torch.inference_mode()
def execute_model(
self,
model_input: EncoderDecoderModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
device=self.device,
),
**seqlen_agnostic_kwargs,
)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInput:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
Since chunked prefill is not supported for encoder/decoder models,
`input_tokens` is assumed to be either entirely prefill tokens or
entirely decode tokens.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input))
# Inject attn_metadata encoder/cross-attention fields &
# encoder input tokens/positions into model_input.
# Frozen dataclass fields cannot be modified, so use
# dataclasses.replace to construct a new model input
# instance.
model_input = dataclasses.replace(
model_input,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
)
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory,
generators=generators)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
dummy_lora_requests = self._add_dummy_loras(
self.lora_config.max_loras)
assert len(dummy_lora_requests) == self.lora_config.max_loras
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
logger.info("Starting profile run for multi-modal models.")
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
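            # Illustrative example: max_num_batched_tokens=8, max_num_seqs=3
            # yields seq_len 3, 3, 2 for group_id 0, 1, 2 (8 tokens total).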
batch_size += seq_len
decoder_dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=False)
encoder_dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine
assert len(
decoder_dummy_data.seq_data.prompt_token_ids
) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
)
assert decoder_dummy_data.multi_modal_data is None or \
encoder_dummy_data.multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder"
)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: decoder_dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data.
multi_modal_placeholders
or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq)
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
self.execute_model(model_input, None, intermediate_tensors)
torch.cuda.synchronize()
return
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInput,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
        # Since chunked prefill is not supported, the entire batch is either
        # all prefill or all decode.
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
is_profile_run = (seq_group_metadata.block_tables is None)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
else:
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
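                    # Illustrative example: with block_size=16, token i=20
                    # lives in cross_block_table[1] at offset 4, so its slot
                    # is cross_block_table[1] * 16 + 4.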
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
if (model_input.attn_metadata is not None
and model_input.attn_metadata.use_cuda_graph):
# We will be using CUDA graph replay for this decode.
max_len_of_block_table = self.get_max_block_per_batch()
batch_size = len(encoder_seq_lens)
graph_batch_size = self.vllm_config.pad_for_cudagraph(
batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
# extend the cross_block_tables and encoder_seq_lens to match
# the graph_batch_size.
cross_block_tables.extend([[]
for _ in range(cuda_graph_pad_size)
])
encoder_seq_lens.extend(
itertools.repeat(1, cuda_graph_pad_size))
else:
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
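        # Illustrative example: encoder_seq_lens [3, 5, 2] produces
        # encoder_seq_start_loc [0, 3, 8, 10]; entry i is the start offset of
        # sequence i in the flattened encoder token tensor.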
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
encoder_seq_start_loc,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)

File diff suppressed because it is too large

484
vllm/worker/hpu_worker.py Normal file
View File

@@ -0,0 +1,484 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import gc
import os
from typing import List, Optional, Set, Tuple, Type
import habana_frameworks.torch as htorch # noqa:F401
import torch
import torch.distributed
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class HPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a HPU.
Each worker is associated with a single HPU. The worker is responsible for
maintaining the KV cache and executing the model on the HPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner: HPUModelRunner = HPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
        # Initialize hpu_cache as pooling models don't initialize kv_caches
self.hpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.HPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def _set_env_vars(self):
local_rank = self.local_rank
if self.parallel_config.world_size == 1:
local_rank = -1
import os
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["ID"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size)
os.environ["RANK"] = str(self.rank)
def init_device(self) -> None:
if self.device_config.device.type == "hpu":
self.device = torch.device("hpu")
torch.hpu.set_device(self.device)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
if self.model_config.quantization == 'inc':
self._set_env_vars()
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if (log_graph_compilation or log_cpu_fallbacks) and \
execute_model_req is not None:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
            ])  # longest prompt+output length across all scheduled sequences
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
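            # Illustrative example: max_context_len=33 with block_size=16
            # needs ((33 - 1) // 16) + 1 = 3 KV cache blocks.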
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1]
> 0) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
> 0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
return output
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of HPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with HabanaMemoryProfiler() as m:
self.model_runner.profile_run()
torch.hpu.synchronize()
msg = ("Model profiling run "
f"took {m.get_summary_string()}")
logger.info(msg)
        # At this point we should've allocated the maximum workspace for all
        # recipes; the remaining free memory is used for graphs and KV blocks.
free_hpu_memory = torch.hpu.mem_get_info()[0]
cache_block_size = self.get_cache_block_size_bytes()
graph_reserved_mem = (float(
os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1'))
if not self.model_config.enforce_eager else 0)
graph_headroom = 1 - graph_reserved_mem
available_hpu_memory = free_hpu_memory * \
self.cache_config.gpu_memory_utilization
hpu_memory_margin = free_hpu_memory * (
1 - self.cache_config.gpu_memory_utilization)
self.model_runner.mem_margin = hpu_memory_margin
cache_size_bytes = available_hpu_memory * graph_headroom
graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
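        # Illustrative example (assumed numbers): with 64 GiB free,
        # gpu_memory_utilization=0.9 and VLLM_GRAPH_RESERVED_MEM=0.1,
        # 57.6 GiB is usable, of which ~5.76 GiB is reserved for HPUGraphs
        # and ~51.84 GiB for the KV cache.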
msg = (
f"Free device memory: {format_bytes(free_hpu_memory)}, "
f"{format_bytes(available_hpu_memory)} usable "
f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
f"{format_bytes(cache_size_bytes)} reserved for KV cache")
logger.info(msg)
num_hpu_blocks = int(cache_size_bytes // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
return num_hpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
        This also warms up the model, which may record HPU graphs.
        """
"""
raise_if_cache_size_invalid(
num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
with HabanaMemoryProfiler() as m:
self._init_cache_engine()
torch.hpu.synchronize()
msg = ("Initializing cache engine "
f"took {m.get_summary_string()}")
logger.info(msg)
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
HPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.hpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.hpu_cache)
def _warm_up_model(self) -> None:
# NOTE(kzawora): We should use virtual engine index here
# for pipeline parallelism. Using 0 for now.
assert self.hpu_cache is not None
self.model_runner.warmup_model(self.hpu_cache[0])
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def finish_measurements(self):
self.model_runner.finish_measurements()
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.hpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch asynchronous host-device copies.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def shutdown_inc(self):
self.model_runner.shutdown_inc()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return HPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
init_distributed_environment(parallel_config.world_size,
rank,
distributed_init_method,
local_rank,
backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="hccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup & checking conformance.
dummy_tensor_hpu = torch.ones(1).to('hpu')
torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
pipeline_parallel_size) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
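    # Illustrative example: 1024 GPU blocks with block_size=128 and
    # pipeline_parallel_size=2 give max_seq_len = 128 * 512 = 65536 tokens.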
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
class HPUCacheEngine(CacheEngine):
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(self.num_attention_layers):
key_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
value_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
kv_layer = (key_cache, value_cache)
kv_cache.append(kv_layer)
return kv_cache

2215
vllm/worker/model_runner.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
logger = init_logger(__name__)
T = TypeVar('T', bound="BroadcastableModelInput")
def _add_attn_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
attn_metadata: Optional["AttentionMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
AttentionMetadata fields.
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())
def _init_attn_metadata_from_tensor_dict(
attn_backend: "AttentionBackend",
tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict:
if field.name == "input_positions":
valid_attn_kwargs[field.name] = tensor_dict[field.name]
else:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict
def _init_sampling_metadata_from_tensor_dict( # type: ignore
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize SamplingMetadata based on broadcastable
SamplingMetadata fields.
"""
from vllm.model_executor import SamplingMetadata
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
# An empty SamplingMetadata to signal that the worker should skip
# sampling.
if selected_token_indices is not None:
tensor_dict["sampling_metadata"] = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return tensor_dict
def _add_sampling_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
sampling_metadata: Optional["SamplingMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
SamplingMetadata fields.
"""
if sampling_metadata is not None:
tensor_dict["selected_token_indices"] = (
sampling_metadata.selected_token_indices)
def _init_frozen_model_input_from_tensor_dict(
frozen_model_input_cls: Type["ModelRunnerInputBase"],
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
    Helper method to initialize a frozen ModelInput from broadcastable fields.
"""
valid_tensor_kwargs = {}
for field in dataclasses.fields(frozen_model_input_cls):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_tensor_kwargs[field.name] = val
frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
tensor_dict["frozen_model_input"] = frozen_model_input
return tensor_dict
class BroadcastableModelInput(ABC):
@abstractmethod
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
custom deserialization.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def from_broadcasted_tensor_dict(
cls: Type[T],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
BroadcastableModelInput.
"""
raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
pass
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
raise NotImplementedError
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
model. Model execution may communicate data with model runners in other
processes, but it should not include control plane metadata communication.
Each ModelRunnerBase subclass should define a corresponding
ModelRunnerInputBase subclass.
"""
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling
        self.generators: Dict[str, torch.Generator] = {}
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> T:
"""
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
dict.
"""
raise NotImplementedError
@abstractmethod
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.
"""
raise NotImplementedError
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)
return self.generators
class ModelRunnerWrapperBase:
"""
The whole point of this class is to lazily initialize the model_runner.
"""
def __init__(
self,
model_runner: ModelRunnerBase,
) -> None:
self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr):
return getattr(self.model_runner, attr)
class InputProcessingError(Exception):
"""This exception is raised when an error occurs preparing the inputs for
a single sequence group.
This allows the engine to gracefully handle errors with a single sequence
group without having to fail the entire batch.
"""
def __init__(self, request_id, message):
"""request_id is the id of the offending sequence group"""
self.request_id = request_id
self.message = message
super().__init__(self.message)
def __str__(self):
return "Failed to prepare inputs for sequence group with request id: " \
f"{self.request_id}, Error: {self.message}"

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.hpu_model_runner import ModelInputForHPU
from vllm.worker.hpu_worker import HPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepHPUWorker(HPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForHPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
"""
Get the driver input and broadcast it to other workers.
"""
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
# on first step we prepare the worker input and model input normally
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForHPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
# Returning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}

View File

@@ -0,0 +1,911 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import functools
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
import torch
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
_init_frozen_model_input_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from ..model_executor.model_loader.tensorizer import TensorizerConfig
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = [
"FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]:
if chunked_prefill_enabled:
return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
else:
return MULTI_STEP_ATTENTION_BACKENDS
def seq_output_builder():
return SequenceOutput(
0, 0,
{0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
def completion_seq_group_output_builder():
return CompletionSequenceGroupOutput([], None)
# Used by pythonization to reduce python object allocations
class PythonizationCache:
def __init__(self):
self.cached_seq_output = PyObjectCache(seq_output_builder)
self.cached_completion_seq_group_output = PyObjectCache(
completion_seq_group_output_builder)
def reset(self):
self.cached_seq_output.reset()
self.cached_completion_seq_group_output.reset()
@dataclass
class ModelOutput:
"""The output of a single model forward pass.
The sampler_output_ready_event is set when the tensors in
sampler_output are ready (the model+sampler forward pass has
completed). We use the event to synchronize the GPU->CPU transfer,
which we want to only run when the data has been written to the
GPU tensors. Until the event is ready, the tensors in sampler_output
will have garbage data.
There are two scenarios:
1. The output tensors are ready and we can pythonize them immediately.
2. The output tensors are not ready and we need to wait for the event to be
ready.
"""
sampler_output: SamplerOutput
sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
pythonization_cache: Optional[PythonizationCache] = None
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output. Blocking."""
if not self.pythonized:
self._pythonize_sampler_output(input_metadata, copy_stream,
pinned_sampled_token_buffer, True)
self.pythonized = True
def maybe_pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output if ready, else return None. Non-blocking."""
if not self.pythonized:
self.pythonized = self._pythonize_sampler_output(
input_metadata, copy_stream, pinned_sampled_token_buffer,
False)
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor,
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output. Upon completing Pythonization, erases
        self.logprobs (a non-blocking call made while the sampler output is
        not yet ready will not erase self.logprobs).
"""
assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query():
return False
if blocking:
self.sampler_output_ready_event.synchronize()
with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids, self.logprobs,
self.pythonization_cache)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True
@dataclass(frozen=False)
class StatefulModelInput(BroadcastableModelInput):
# actual frozen model input dataclass passed to _base_model_runner
frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
# list of model outputs for each step, may not be all pythonized
cached_outputs: List[ModelOutput] = field(default_factory=list)
# used to pass sampled token ids from the last step to the current step for
# TP workers. Used to append to end of outputs and used by advance_step
last_sampled_token_ids: Optional[torch.Tensor] = None
current_step: int = 0
is_multi_step: bool = True
is_last_step: bool = False
is_first_multi_step: bool = False
base_output_proc_callback: Optional[Callable] = None
# ping-pong data structures for multi-step to wait on the previous step
step_cuda_events: List[current_platform.Event] = field(
default_factory=lambda: [current_platform.Event(blocking=True)] * 2)
num_seqs: int = -1
num_queries: int = -1
num_single_step_prefills: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
assert self.frozen_model_input is not None
tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
new_tensor_dict = {
'last_sampled_token_ids': self.last_sampled_token_ids,
'current_step': self.current_step,
'is_multi_step': self.is_multi_step,
'is_last_step': self.is_last_step,
'is_first_multi_step': self.is_first_multi_step,
'num_seqs': self.num_seqs,
'num_queries': self.num_queries,
'num_single_step_prefills': self.num_single_step_prefills,
}
tensor_dict.update(new_tensor_dict)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "StatefulModelInput":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
tensor_dict = _init_frozen_model_input_from_tensor_dict(
ModelInputForGPUWithSamplingMetadata, tensor_dict)
return cls(**tensor_dict)
def record_step_event(self, current_stream: torch.cuda.Stream):
# record the event for the current step so that the next step can sync
# on it. We modulo by 2 to keep the events in a circular buffer and
# support any attn backends that may be supported in the future. ie
# Flashinfer would want two DecodeWrappers to overlap the CPU and GPU.
self.step_cuda_events[self.current_step & 1] = \
torch.cuda.Event(blocking=True)
self.step_cuda_events[self.current_step & 1].record(current_stream)
def wait_previous_step(self):
# These cuda events are an explicit synchronization to ensure that
# advance_step() (for other attn backends that may be supported in the
        # future) do not clobber any data structures that are also used by any
        # enqueued forward steps. For the distributed case, only a single event is
# needed, but for single GPU case, since we can let the CPU run much
# further ahead, two events allow us to overlap the advance_step with
# the previous forward (ie using two DecodeWrappers for flashinfer
# backend)
self.step_cuda_events[(self.current_step + 1) & 1].wait()
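        # Illustrative example: at current_step=2, wait_previous_step() waits
        # on step_cuda_events[(2 + 1) & 1] == step_cuda_events[1], i.e. the
        # event recorded by record_step_event() during step 1.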
def add_sampler_output(self,
sampler_output: SamplerOutput,
sampled_token_ids: Optional[torch.Tensor] = None):
self.cached_outputs.append(
ModelOutput(sampler_output=sampler_output,
sampler_output_ready_event=None,
sampled_token_ids=sampled_token_ids,
pythonized=False))
def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool):
"""
sampling_metadata.selected_token_indices is constructed for the
first-step in Multi-Step. However, when chunked-prefill is enabled with
multi-step, the scheduled prompts are fully processed in the
first-step and are processed as decodes in the rest of the steps.
This function updates the sampling_metadata.selected_token_indices
to account for this conversion.
Example:
Let 2 prompts and 2 decodes be scheduled together. Let the
num-tokens to process for the 2 prompts be 5 and 8 respectively.
        In that case, sampling_metadata.selected_token_indices will be
[4, 12, 13, 14] as it is constructed for the first-step in
multi-step.
        However, the prompts turn into decodes after the first step,
        and the num-tokens for the previously-prompt sequences will
        be 1 and 1 as they are decodes now. The selected_token_indices
must be updated to [0,1,2,3].
"""
assert self.current_step == 1 and self.num_single_step_prefills > 0
if not get_pp_group().is_last_rank:
return
assert self.frozen_model_input is not None
assert self.frozen_model_input.sampling_metadata is not None
self.frozen_model_input.sampling_metadata.selected_token_indices = \
async_tensor_h2d(list(range(self.num_queries)),
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
"""
Advancing the datastructures of StatefulModelInput::frozen_model_input
is only required when prefills are scheduled with decodes to run in
multi-step. This advancement/correction is required to account for
the conversion of Prefills to Decodes after the first multi-step.
"""
if self.current_step != 1 or self.num_single_step_prefills == 0:
return
assert self.frozen_model_input is not None
fmi = self.frozen_model_input
# Truncate input_tokens
assert fmi.input_tokens is not None
assert fmi.input_tokens.shape[0] >= self.num_seqs
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
# Update frozen_model_input::input_positions.
assert fmi.input_positions is not None
assert fmi.input_positions.shape[0] >= self.num_seqs
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
num_seqs]
# Assert unsupported
assert fmi.lora_mapping is None
assert fmi.lora_requests is not None
assert len(fmi.lora_requests) == 0
assert fmi.attn_metadata is not None
assert fmi.prompt_adapter_mapping is None
assert fmi.prompt_adapter_requests is not None
assert len(fmi.prompt_adapter_requests) == 0
assert fmi.multi_modal_kwargs is not None
assert len(fmi.multi_modal_kwargs) == 0
self.frozen_model_input = dataclasses.replace(
self.frozen_model_input,
input_tokens=fmi_new_input_tokens,
input_positions=fmi_new_input_positions)
self.maybe_advance_sampling_metadata(device, pin_memory)
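        # Illustrative continuation of the docstring example (comments only):
        # with 2 prompts of 5 and 8 tokens plus 2 decodes scheduled together,
        # the first step runs on 15 input tokens; after this truncation only
        # num_seqs (= 4) entries remain, and the subsequent
        # attn_metadata.advance_step(..., turn_prefills_into_decodes=True)
        # call is expected to refill them with the tokens sampled in the
        # first step.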
# MutableModelInputForGPUWithMultiStepMetadata is not a subclass of
# ModelInputForGPU, but it wraps the actual input dataclass and adds multi-step
# metadata.
# mypy: disable-error-code=type-var
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check attention backend support.
supported_attention_backends: List[str] = \
_get_supported_attention_backends(
self.scheduler_config.chunked_prefill_enabled)
if self.attn_backend.get_name() not in supported_attention_backends:
ms_config_str: str = "Multi-Step + Chunked-Prefill" \
if self.scheduler_config.chunked_prefill_enabled \
else "Multi-Step"
raise ValueError(
f"{ms_config_str} not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {supported_attention_backends}.")
# uses the base model runner to execute the model and wraps it with
# multi-step logic
self._base_model_runner: GPUModelRunnerBase = base_model_runner
self.is_multi_step = self.scheduler_config.is_multi_step
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceOutput and CompletionSequenceGroupOutput objects.
        # When a cache reset happens at the last step of a multi-step
        # execution, there may be other on-going single-step/multi-step
        # executions. The current caching implementation does not check
        # for this.
self.pythonization_cache = PythonizationCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
@functools.cached_property
def _copy_stream(self):
# used to copy tensors from GPU to CPU asynchronously
return torch.cuda.Stream()
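    # Illustrative use of the side stream above (a sketch in comments, not
    # part of this class; `pinned_cpu_buf` and `gpu_tensor` are hypothetical
    # names): D2H copies issued on _copy_stream can overlap with the forward
    # pass running on the default stream, e.g.
    #   with torch.cuda.stream(self._copy_stream):
    #       pinned_cpu_buf.copy_(gpu_tensor, non_blocking=True)
    # followed by self._copy_stream.synchronize() (or an event wait) before
    # the CPU reads pinned_cpu_buf.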
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
return model_input
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> StatefulModelInput:
frozen_model_input: ModelInputForGPUWithSamplingMetadata = \
self._base_model_runner.prepare_model_input(
seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
assert frozen_model_input.query_lens is not None
assert frozen_model_input.seq_lens is not None
assert frozen_model_input.attn_metadata is not None
num_queries = len(frozen_model_input.query_lens)
num_seqs = len(frozen_model_input.seq_lens)
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
model_input = StatefulModelInput(
frozen_model_input=frozen_model_input,
num_seqs=num_seqs,
num_queries=num_queries,
num_single_step_prefills=num_single_step_prefills)
return model_input
def _async_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: Callable):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
output_proc_callback()
cont = True
for step_num, model_output in enumerate(model_input.cached_outputs):
if not model_output.pythonized:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_output.pythonized:
ctx = output_proc_callback.keywords["ctx"]
ctx.append_output(
outputs=[model_output.sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=step_num == 0)
output_proc_callback()
else:
cont = False
if not cont:
break
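        # Note (added for clarity): stopping at the first output that is not
        # yet pythonized keeps the append_output() calls in step order; the
        # outputs skipped here are picked up by a later invocation or by
        # _final_process_outputs().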
def _final_process_outputs(
self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
assert model_input.frozen_model_input is not None
has_async_callback = output_proc_callback is not None
outputs = []
for step_num, output in enumerate(model_input.cached_outputs):
is_last_step = step_num == len(model_input.cached_outputs) - 1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if has_async_callback:
assert output_proc_callback is not None
# Invoke callback before pythonize (to overlap with GPU)
output_proc_callback()
# Pythonize
if not output.pythonized:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
                # For a non-last step, add to the callback queue to chain
                # callback=>pythonize pairs (for GPU overlap)
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
ctx.append_output(
outputs=[output.sampler_output],
seq_group_metadata_list=ctx.
seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=step_num == 0)
else:
outputs.append(output.sampler_output)
else:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
return outputs
@torch.inference_mode()
def execute_model(
self,
model_input: StatefulModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
"""
Execute the model for a single step and update multi-step
metadata
"""
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# path for warm up runs
if not model_input.is_multi_step:
return self._base_model_runner.execute_model(
frozen_model_input, None, intermediate_tensors, num_steps)
        # make sure we skip the sampler's CPU output on the last rank and only
        # pythonize if the CPU is ahead.
if self.is_driver_worker and get_pp_group().is_last_rank:
if self.pinned_sampled_token_ids is None:
self.pinned_sampled_token_ids = torch.zeros(
(self.scheduler_config.max_num_seqs, 1),
dtype=torch.long,
device="cpu",
pin_memory=True)
self._base_model_runner.sampler.include_gpu_probs_tensor = True
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True)
# some pre-execute model logic for multi-step:
# - if it's the first step, we need to reset the sampling tensors
# - if it's not the first step, we need to advance the step using the
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU
stream = current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
            # This is not needed for the flashattn backend, but for other attn
            # backends such as flashinfer that perform extra CPU operations on
            # input metadata, we may need to synchronize any CPU operations
            # that might clobber enqueued forwards (prevents the CPU from
            # running too far ahead if needed).
model_input.wait_previous_step()
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
# frozen_model_input may have been updated
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
if model_input.base_output_proc_callback is None:
assert frozen_model_input is not None
model_input.base_output_proc_callback = \
frozen_model_input.async_callback
if frozen_model_input.async_callback is not None:
assert model_input.base_output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=model_input.base_output_proc_callback)
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
# Update the local instance
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
None,
intermediate_tensors,
num_steps=1)
# record the event for the current step so that the next step can sync
model_input.record_step_event(stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
            assert len(output) == 1, (
                "MultiStepModelRunner requires single-step base_models")
            # event for the pythonization so that we only pythonize if the
            # tensors are ready. This may be combinable with the step event.
output_ready_event = torch.cuda.Event()
output_ready_event.record(stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False,
output[0].logprobs, self.pythonization_cache))
# These GPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None
output[0].sampled_token_probs = None
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
if frozen_model_input.async_callback is None:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input,
self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1
if not get_pp_group().is_last_rank:
# Should be IntermediateTensors
assert isinstance(output, IntermediateTensors)
return output
if not self.is_driver_worker:
return []
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = self._final_process_outputs(
model_input, model_input.base_output_proc_callback)
if self.pythonization_cache:
self.pythonization_cache.reset()
return outputs
# should be [SamplerOutput]
return output
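    # Illustrative per-call timeline (comments only, a sketch assuming a
    # 3-step schedule on a single GPU):
    #   call 1 (is_first_multi_step): forward, record the step event, cache
    #           the ModelOutput.
    #   call 2: wait_previous_step(), _advance_step() on the GPU, forward,
    #           opportunistically maybe_pythonize() call-1's output if its
    #           ready-event has fired.
    #   call 3 (is_last_step): as call 2, then _final_process_outputs()
    #           pythonizes whatever is still pending and returns the list of
    #           SamplerOutputs.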
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
num_seqs: Optional[int], num_queries: int):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
model_input.maybe_advance_frozen_model_input(self.device,
self.pin_memory)
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.input_tokens is not None
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
assert frozen_model_input.attn_metadata is not None
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs
num_queries = model_input.num_queries
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
attn_metadata = frozen_model_input.attn_metadata
assert attn_metadata is not None
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
model_input.num_single_step_prefills != 0
attn_metadata.advance_step(
frozen_model_input,
sampled_token_ids,
self.block_size,
num_seqs,
num_queries,
turn_prefills_into_decodes=turn_prefills_into_decodes)
return model_input
def load_model(self) -> None:
self._base_model_runner.load_model()
self.model_memory_usage = self._base_model_runner.model_memory_usage
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
return self._base_model_runner.save_sharded_state(
path, pattern, max_size)
def save_tensorized_model(self,
tensorizer_config: TensorizerConfig) -> None:
return self._base_model_runner.save_tensorized_model(tensorizer_config)
def profile_run(self) -> None:
return self._base_model_runner.profile_run()
def remove_all_loras(self):
return self._base_model_runner.remove_all_loras()
def capture_model(self, kv_caches: List[List]) -> None:
return self._base_model_runner.capture_model(kv_caches)
@property
def vocab_size(self) -> int:
return self._base_model_runner.vocab_size
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
Optional[List[SampleLogprobs]]]
def deferred_pythonize_logprobs(
output: SamplerOutput,
sampling_metadata: SamplingMetadata,
logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
"""Perform deferred logprob Pythonization.
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
    Args:
        output: sampler output (under deferred Pythonization)
        sampling_metadata: sampling metadata for the scheduled sequence groups
        logprobs_tensor: GPU-side tensor of logprobs computed during sampling
    Returns:
        prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result = get_pythonized_sample_results(
output.deferred_sample_results_args)
# - Erase the GPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output.deferred_sample_results_args = None
# - Deferred pythonization of logprobs
(
prompt_logprobs,
sample_logprobs,
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready.
See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput].
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput])
for each [`SequenceGroup`][vllm.sequence.SequenceGroup].
    Args:
        model_input
        output: sampler output
        pinned_sampled_token_buffer: CPU-side pinned memory
            (receives a copy of the
            GPU-side token buffer.)
        sampled_token_ids: GPU-side token buffer
        logprobs_tensor: GPU-side tensor containing
            logprobs computed during sampling
        cache: optional PythonizationCache used to reuse output objects
    """
assert model_input.frozen_model_input is not None
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input.sampling_metadata is not None
sampling_metadata = frozen_model_input.sampling_metadata
    # sample generation should have been skipped
assert not output.outputs
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
# We guarantee output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However we should check whether logprobs pythonization may
# be skipped entirely, i.e. because no logprobs were requested
# or pythonization was not deferred. To that end,
#
# * `prompt_logprobs_are_requested_for_prefill` signals that
# there are *any* prefill-phase requests which specify that
# prompt logprobs should be returned.
#
# * `any_logprobs_are_requested` signals that there are any
# requests which (1) specify that sample logprobs should be
# returned, or (2) are in the prefill phase AND specify that
# prompt logprobs should be returned.
#
# Later on, these flags cause adjustments to the pythonization
# process to accommodate logprobs.
seq_groups = sampling_metadata.seq_groups
prompt_logprobs_are_requested_for_prefill = any([
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
for sg in seq_groups
])
any_logprobs_are_requested = (
prompt_logprobs_are_requested_for_prefill
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
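    # Illustrative flag settings (comments only):
    #   prefill requesting prompt_logprobs + decode requesting logprobs
    #       -> both flags True (gather only sampled rows, pythonize logprobs)
    #   decodes only, some requesting logprobs
    #       -> prompt_logprobs_are_requested_for_prefill is False,
    #          any_logprobs_are_requested is True
    #   nothing requested
    #       -> both False (fast path; a placeholder Logprob is used below)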
if prompt_logprobs_are_requested_for_prefill:
# CPU GPU sync, after gathering *only* sampled tokens (since
# requesting prompt logprobs leads `sampled_token_ids` to
# include prompt token ids in addition to sampled token ids.)
sample_idx_tensor = torch.tensor(
[sdx for sg in seq_groups for sdx in sg.sample_indices])
pinned_buffer = pinned_buffer.copy_(
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
else:
# CPU GPU sync
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
non_blocking=False)
# this will not block as the tensors are already on CPU
samples_list = pinned_buffer.tolist()
skip_sampler_cpu_output = (
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
# *Don't* skip logprobs pythonization *if*:
# * Any requests require logprobs to be returned in this
# iteration AND
# * These requests are being scheduled in a fashion which
# defers pythonization (i.e. multi-step scheduling.)
do_pythonize_logprobs = (skip_sampler_cpu_output
and any_logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
        # Reminder: Please update docs/features/compatibility_matrix.md
        # if the feature combo becomes valid
        # (check for Guided Decoding)
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
if do_pythonize_logprobs:
assert prompt_logprobs is not None
assert sample_logprobs is not None
(
group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif any_logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput]
if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs = completion_seq_group_output.samples
else:
seq_outputs = []
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
if cache is not None:
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if any_logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
seq_output.logprobs.clear()
logprobs.logprob = float('inf')
logprobs.rank = None
logprobs.decoded_token = None
seq_output.logprobs[next_token_id] = logprobs
seq_outputs.append(seq_output)
else:
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if any_logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if any_logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if any_logprobs_are_requested else None)))
assert len(output.outputs) > 0
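    # Note (added for clarity): when no logprobs were requested, each
    # SequenceOutput above carries a single placeholder
    # Logprob(logprob=float('inf'), rank=None, decoded_token=None) keyed by
    # the sampled token id, so downstream consumers always see a non-empty
    # logprobs dict.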


@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from typing import List, Optional
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
NeuronModelRunner)
class MultiStepNeuronModelRunner(NeuronModelRunner):
"""A model runner for multi step decoding using the transformers_neuronx
framework"""
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
self.speculation_config = self.speculative_config
from transformers_neuronx.config import GenerationConfig
self.speculation_config.draft_model_config.neuron_sampling_params = (
GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K
))
def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
from vllm.model_executor.model_loader.neuron import (
get_neuron_eagle_speculation_model,
get_neuron_speculation_model)
if self.speculation_config.speculative_token_tree is not None:
self.model = get_neuron_eagle_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculation_config)
else:
self.model = get_neuron_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculation_config)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
logits = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return output


@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
"""A model runner for multi-step decoding using the
neuronx-distributed-inference framework"""
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
def load_model(self) -> None:
from vllm.model_executor.model_loader.neuronx_distributed import (
get_neuron_speculation_model)
self.model = get_neuron_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculative_config)
@torch.inference_mode()
def execute_model(
self,
model_input,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
sampling_params = torch.tensor([[
seq_group.sampling_params.top_k,
seq_group.sampling_params.top_p,
seq_group.sampling_params.temperature,
] for seq_group in model_input.sampling_metadata.seq_groups])
logits = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return output


@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepTPUWorker(TPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForTPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForTPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
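            # A 2-entry dict is the lightweight control message broadcast for
            # non-first steps (only "is_first_multi_step" and "is_last_step");
            # anything larger is a full first-step model/worker input.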
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}


@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput)
from vllm.worker.worker import Worker, WorkerInput
@dataclass
class MultiStepState:
worker_input: WorkerInput
model_input: StatefulModelInput
class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
base_model_runner = self.model_runner
# for multi-step model, wrap the model runner with MultiStepModelRunner
self.model_runner = MultiStepModelRunner(
base_model_runner,
vllm_config=base_model_runner.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=base_model_runner.is_driver_worker,
)
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
self.multi_step_states: List[
Optional[MultiStepState]] = [None] * pipeline_parallel_size
self.temp_output = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
"""
Get the driver input and broadcast it to other workers.
"""
assert self.is_driver_worker
virtual_engine = execute_model_req.virtual_engine
is_first_multi_step = execute_model_req.is_first_multi_step
if is_first_multi_step:
            # on the first step we prepare the worker input and model input
            # normally
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: StatefulModelInput = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=execute_model_req.async_callback)
else:
# on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine]
worker_input = multi_step_state.worker_input
model_input = multi_step_state.model_input
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
# clear the cached metadata so that it can be recomputed on
# the workers.
frozen_model_input.attn_metadata._cached_prefill_metadata = None
frozen_model_input.attn_metadata._cached_decode_metadata = None
model_input.is_first_multi_step = is_first_multi_step
model_input.is_last_step = execute_model_req.is_last_step
if not is_first_multi_step:
# we broadcast the last sampled token ids to all TP workers so they
# can update their model input metadata in-place.
self._prepare_last_sampled_token_ids_for_tp_workers(
execute_model_req=execute_model_req, model_input=model_input)
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def _prepare_last_sampled_token_ids_for_tp_workers(
self,
execute_model_req: ExecuteModelRequest,
model_input: StatefulModelInput,
) -> None:
"""
Prepare the last sampled token ids for TP workers. If it's the last
PP rank, then the last sampled token ids are already in the model_input.
        If it is NOT the last PP rank, then we need to get the last sampled
        token ids that are cached in the execute_model_req.
"""
if get_pp_group().is_last_rank:
assert model_input.cached_outputs[
-1].sampler_output.sampled_token_ids is None
assert model_input.cached_outputs[-1].sampled_token_ids is not None
model_input.last_sampled_token_ids = model_input.cached_outputs[
-1].sampled_token_ids
            # free sampled token ids from the previous step if they have been
            # pythonized. Cannot free the last sampled token ids because
            # we need them for the GPU advance_step.
for output in model_input.cached_outputs[:-1]:
if output.pythonized:
output.sampled_token_ids = None
else:
# otherwise we need to get the cached sampled token ids from the
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
# free sampled token ids from the previous step.
# TODO(will) we could reuse the sampled token ids tensor from
# the previous step instead.
for output in model_input.cached_outputs[:-1]:
output.sampled_token_ids = None
assert model_input.cached_outputs[-1].sampled_token_ids is not None
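        # Illustrative data flow (comments only), e.g. with PP=2 and TP=2:
        #   last PP rank  : the sampled ids already live in
        #                   cached_outputs[-1]; they are exposed as
        #                   last_sampled_token_ids so the broadcast to TP
        #                   peers can carry them.
        #   other PP ranks: the ids arrive echoed back in
        #                   execute_model_req.last_sampled_token_ids; they
        #                   are moved to the GPU and paired with an empty
        #                   SamplerOutput so step bookkeeping matches.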
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
torch.Tensor]]]:
"""
        Depending on the current state of the request and the multi-step
        worker, this method may skip the normal _prepare_model_input and
        _prepare_worker_input methods and instead use cached values.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
virtual_engine = execute_model_req.virtual_engine
(model_input, worker_input,
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
assert isinstance(model_input, StatefulModelInput)
if execute_model_req.is_first_multi_step:
# cache the worker input and model input for the next steps
self.multi_step_states[virtual_engine] = MultiStepState(
worker_input=worker_input, model_input=model_input)
        # non-driver (TP) workers
else:
broadcast_data = self._get_worker_input_from_broadcast()
# if the driver has sent an empty input, we should stop the worker
# loop
if broadcast_data is None:
return None
model_input, worker_input, kwargs = broadcast_data
assert isinstance(model_input, StatefulModelInput)
virtual_engine = worker_input.virtual_engine
if model_input.is_first_multi_step:
pass
# TODO(will) Can cache the worker input and model input for the
# next steps. See below for details
else:
                # TODO(will) possible to also cache and reuse the cached worker
                # input and model input. The idea is essentially the delta
                # optimization for model_inputs: the TP workers can cache
                # the model input state and we only broadcast the delta needed
                # for the next step (sampled_token_ids from the previous step).
assert isinstance(model_input, StatefulModelInput)
# we need to update the last sampled token ids in the model
# input for the workers so that they can run inplace
# advance_step
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
assert model_input is not None
assert worker_input is not None
return model_input, worker_input, kwargs


@@ -0,0 +1,460 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.config import DeviceConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional[SamplingMetadata] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
adapter_ids: Optional[str] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
return {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"input_block_ids": self.input_block_ids,
"sampling_metadata": self.sampling_metadata,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
return ModelInputForNeuron(
input_tokens=tensor_dict["input_tokens"],
input_positions=tensor_dict["input_positions"],
input_block_ids=tensor_dict["input_block_ids"],
sampling_metadata=tensor_dict["sampling_metadata"],
multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
"""A model runner for AWS Neuron hardware"""
# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K = 256
def __init__(
self,
vllm_config: VllmConfig,
):
ModelRunnerBase.__init__(self, vllm_config)
if (self.model_config is not None
and self.model_config.get_sliding_window()):
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.lora_config = vllm_config.lora_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
# Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
# turn off on-device sampling.
self._on_device_sampling_disabled = int(
os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
# NEURON needs to update sampling parameters when request IDs change
# across batches. This variable stores the previous batch's request IDs
# to determine if an update is needed.
self._previous_batch_request_ids: List[str] = []
if not self._on_device_sampling_disabled:
self._init_neuron_sampling()
def _init_neuron_sampling(self) -> None:
if current_platform.use_transformers_neuronx():
from transformers_neuronx.config import GenerationConfig
else:
from transformers import GenerationConfig
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
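    # Illustrative default (comments only): with max_num_seqs == 2 the config
    # above starts as top_k=[256, 256], top_p=[1.0, 1.0],
    # temperature=[1.0, 1.0]; per-sequence values are overwritten later by
    # _update_neuron_sampling_params() when requests with different sampling
    # parameters arrive.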
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def get_model(self) -> nn.Module:
return self.model
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
mm_kwargs = seq_group_metadata.multi_modal_data
if mm_kwargs:
mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
multi_modal_kwargs_list.append(mm_kwargs)
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=max_seq_len,
dtype=torch.long,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
context_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
pad=0,
max_len=1,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
input_block_ids = torch.tensor(input_block_ids,
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are either all
        # prompts or all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = None
if not self._on_device_sampling_disabled:
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
top_k, top_p, temperature = (
self._convert_to_neuron_sampling_params(sampling_params))
sampling_params.top_k = top_k
sampling_params.top_p = top_p
sampling_params.temperature = temperature
# we need multi_modal_data for later tokens as well
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
multi_modal_kwargs_list.append(mm_data)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since the neuron worker doesn't support chunked
            # prefill, just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
if current_platform.use_transformers_neuronx(
) and not self._on_device_sampling_disabled:
            # If the request IDs have changed in the current iteration, update
            # the on-device sampling parameters.
current_batch_request_ids = [
seq_group_meta_data.request_id
for seq_group_meta_data in seq_group_metadata_list
]
if current_batch_request_ids != self._previous_batch_request_ids:
self._update_neuron_sampling_params(seq_group_metadata_list)
self._previous_batch_request_ids = current_batch_request_ids
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)
def _update_neuron_sampling_params(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params = self.model_config.neuron_sampling_params
assert current_sampling_params is not None, (
f"Failed to update sampling_params, "
f"current sampling params is {current_sampling_params}")
is_update_needed = False
top_k = current_sampling_params.top_k
top_p = current_sampling_params.top_p
temperature = current_sampling_params.temperature
# The index of a sequence's sampling parameters in neuron is equal to
# its index in `input_block_ids`.
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_group_top_k = sampling_params.top_k
seq_group_top_p = sampling_params.top_p
seq_group_temperature = sampling_params.temperature
for seq_id in seq_ids:
index = seq_group_metadata.block_tables[seq_id][0]
if (top_k[index] != seq_group_top_k
or top_p[index] != seq_group_top_p
or temperature[index] != seq_group_temperature):
is_update_needed = True
top_k[index] = seq_group_top_k
top_p[index] = seq_group_top_p
temperature[index] = seq_group_temperature
# update_generation_config is only available in transformers-neuronx
if is_update_needed and current_platform.use_transformers_neuronx():
self.model.model.update_generation_config(current_sampling_params)
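    # Illustrative update (comments only): if the sequence whose block id is 3
    # arrives with top_k=20, top_p=0.9, temperature=0.7, then index 3 of each
    # list above is overwritten and update_generation_config() is called once
    # for the whole batch.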
def _convert_to_neuron_sampling_params(
self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
# Returns the top_k, top_p and temperature parameters for neuron.
top_k = sampling_params.top_k
top_p = sampling_params.top_p
temperature = sampling_params.temperature
if temperature == 0.0:
# Enable greedy sampling on zero temperature
return (1, 1.0, 1.0)
if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
top_k = self._MAX_NEURON_SAMPLING_TOP_K
return (top_k, top_p, temperature)
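    # Illustrative conversions (comments only), given
    # _MAX_NEURON_SAMPLING_TOP_K == 256:
    #   temperature == 0.0               -> (1, 1.0, 1.0)      # greedy
    #   top_k == -1 (disabled)           -> (256, top_p, temperature)
    #   top_k == 1000 (above the limit)  -> (256, top_p, temperature)
    #   top_k == 40                      -> (40, top_p, temperature)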
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
# extract top_k, top_p and temperature from model_input for neuron
# forward call
sampling_params = (torch.tensor([[
seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
seq_group.sampling_params.temperature
] for seq_group in model_input.sampling_metadata.seq_groups]))
if current_platform.use_neuronx_distributed():
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
adapter_ids=model_input.adapter_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
elif current_platform.use_transformers_neuronx():
            # [TODO] validate on-device sampling
            # The model signature may need to change for on-device sampling
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
if self._on_device_sampling_disabled:
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
else:
logits = hidden_states
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def process_multi_modal_data_neuron(self, mm_data):
# this is a no-op for NeuronModelRunner
return mm_data
def remove_all_loras(self):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def add_lora(self, lora_request: LoRARequest):
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")
def list_loras(self) -> Set[int]:
raise NotImplementedError(
"LoRAs are not supported for Transformers NeuronX framework")


@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import os
from typing import List, Optional, Set, Tuple
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class NeuronWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
model_runner: NeuronModelRunner
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
self.lora_config = vllm_config.lora_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
neuron_framework = current_platform.get_neuron_framework_to_use()
if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
self.model_runner = self.get_tnx_model_runner(vllm_config)
elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
self.model_runner = self.get_neuronx_distributed_model_runner(
vllm_config)
else:
raise NotImplementedError(
"Specified framework" +
f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
" is either not installed or not supported." +
" Supported frameworks: " +
"[transformers-neuronx, neuronx-distributed-inference]")
def get_tnx_model_runner(self, vllm_config):
assert (self.lora_config
is None), ("LoRA is not supported for TransformersNeuronX "
"framework.")
from vllm.worker.multi_step_neuron_model_runner import (
MultiStepNeuronModelRunner)
if self.speculative_config is not None:
return MultiStepNeuronModelRunner(vllm_config=vllm_config)
else:
return NeuronModelRunner(vllm_config=vllm_config)
def get_neuronx_distributed_model_runner(self, vllm_config):
from vllm.worker.multi_step_neuronx_distributed_model_runner import (
MultiStepNeuronxDistributedModelRunner)
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
if self.speculative_config is not None:
assert (self.lora_config
is None), "LoRA is not supported for Speculative Decoding"
return MultiStepNeuronxDistributedModelRunner(
vllm_config=vllm_config)
else:
return NeuronxDistributedModelRunner(vllm_config=vllm_config)
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
        We configure num_gpu_blocks to be equal to max_num_seqs + 1.
"""
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
num_gpu_blocks = self.scheduler_config.max_num_seqs + 1
# Swap not yet supported with Neuron backend.
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
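    # Illustrative (comments only): with max_num_seqs == 8 this reports
    # (9, 0), i.e. one "block" per schedulable sequence plus one, and no CPU
    # swap space.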
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache.
"""
# Different values are not tested.
assert num_cpu_blocks == 0
assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@property
def do_metadata_broadcast(self) -> bool:
return False
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )
def execute_worker(self, worker_input: WorkerInput) -> None:
pass
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
vLLM still needs the environment initialized when TP/PP > 1
"""
init_distributed_environment(
world_size=1,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
1,
1,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
if current_platform.use_transformers_neuronx():
raise NotImplementedError(
f"{type(self)} does not support LoRA with Neuron Framework "
f"Transformers NeuronX")
return self.model_runner.list_loras()


@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Set
import torch
from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
get_all_supported_aspect_ratios)
from neuronx_distributed_inference.modules.generation.sampling import (
prepare_sampling_params)
from neuronx_distributed_inference.modules.lora_serving import (
LoraCheckpoint, LoraServingConfig)
from vllm.config import VllmConfig
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuronx_distributed import (
_get_model_architecture, get_neuron_model)
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
NeuronModelRunner)
logger = init_logger(__name__)
class NeuronxDistributedModelRunner(NeuronModelRunner):
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
self.lora_checkpoint = None
self.model = None
self.lora_serving_config = None
@staticmethod
def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
if not lora_modules:
return None
return {_.get("name"): _.get("path") for _ in lora_modules}
def _get_nxdi_lora_config(self):
override_neuron_config = self.model_config.override_neuron_config
lora_modules = override_neuron_config.pop("lora_modules", None)
target_modules = override_neuron_config.pop("target_modules", None)
lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)
if self.lora_config.max_loras < len(lora_ckpt_paths):
raise ValueError(
f"Number of LoRAs ({len(lora_ckpt_paths)}) exceeds maximum "
f"allowed ({self.lora_config.max_loras})")
return LoraServingConfig(
max_loras=self.lora_config.max_loras,
max_lora_rank=self.lora_config.max_lora_rank,
target_modules=target_modules,
lora_ckpt_paths=lora_ckpt_paths,
)
def load_model(self) -> None:
# Update LoRA config
if self.lora_config is not None:
self.lora_serving_config = self._get_nxdi_lora_config()
self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
self.model = get_neuron_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
lora_serving_config=self.lora_serving_config)
def get_nxd_sampling_params(self, sampling_metadata):
if self.model.config.neuron_config.on_device_sampling_config:
max_topk = (self.model.config.neuron_config.
on_device_sampling_config.global_topk)
else:
max_topk = self.model.config.vocab_size
top_k = [1] * self.scheduler_config.max_num_seqs
top_p = [1.0] * self.scheduler_config.max_num_seqs
temperature = [1.0] * self.scheduler_config.max_num_seqs
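# These defaults (top_k=1, top_p=1.0, temperature=1.0) amount to greedy
# sampling for padded batch slots that have no active sequence group.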
for index, sequenceGroupToSample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = (sequenceGroupToSample.sampling_params.top_k
if sequenceGroupToSample.sampling_params.top_k > 0
else max_topk)
top_p[index] = sequenceGroupToSample.sampling_params.top_p
temperature[index] = (
sequenceGroupToSample.sampling_params.temperature)
sampling_params = prepare_sampling_params(
batch_size=self.scheduler_config.max_num_seqs,
top_k=top_k,
top_p=top_p,
temperature=temperature)
return sampling_params
def get_multi_modal_data_neuron(self, input_images):
raise NotImplementedError("need to restore multi-modal support")
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
if _get_model_architecture(
self.model.config) != "MllamaForConditionalGeneration":
return super().execute_model(model_input, kv_caches,
intermediate_tensors, num_steps)
sampling_params = self.get_nxd_sampling_params(
model_input.sampling_metadata)
if model_input.multi_modal_kwargs.get('pixel_values') is not None:
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=model_input.multi_modal_kwargs.get(
'pixel_values'),
aspect_ratios=model_input.multi_modal_kwargs.get(
'aspect_ratios'),
sampling_params=sampling_params,
num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
has_image=model_input.multi_modal_kwargs.get(
'has_image').squeeze(1),
)
else:
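# Text-only request: feed zero-filled image tensors with has_image=0
# so the model can be called with the same signature as the image path.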
bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
is not None) else 1
empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
dtype=torch.bfloat16)
empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
has_image = torch.zeros([bs], dtype=torch.int32)
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=empty_pixel_values,
aspect_ratios=empty_aspect_ratios,
sampling_params=sampling_params,
num_chunks=num_chunks,
has_image=has_image,
)
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
def process_multi_modal_data_neuron(self, mm_data):
# Neuron uses aspect_ratios instead of aspect_ratio_ids
all_supported_aspect_ratios = get_all_supported_aspect_ratios(
self.model.config.vision_config.max_num_tiles)
aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
mm_data["aspect_ratios"] = torch.tensor(
all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)
# Neuron's num_chunks is HF's num_tiles
mm_data["num_chunks"] = mm_data.get("num_tiles")
# Input has an image if it has pixel_values
bs = mm_data["num_chunks"].shape[0]
pixel_values = mm_data.get("pixel_values")
if pixel_values is not None and not torch.all(pixel_values == 0):
mm_data["has_image"] = torch.ones(bs)
else:
mm_data["has_image"] = torch.zeros(bs)
return mm_data
def _get_lora_adapter_ids(self, seq_group_metadata_list):
# set LoRA adapter IDs for multi-lora serving
batch_size = len(seq_group_metadata_list)
if self.lora_checkpoint is not None:
# "0" indicates NxDI to use the base model for inference
adapter_ids = ["0"] * batch_size
for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.lora_request is not None:
adapter_ids[
idx] = seq_group_metadata.lora_request.lora_name
# convert adapter_ids from strings to integers
adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices(
adapter_ids, batch_size)
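# Illustrative example: ["0", "sql_lora", "0"] -> tensor([0, 1, 0]);
# the index assigned to each adapter name is determined by the LoRA
# checkpoint (the name "sql_lora" is hypothetical).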
else:
adapter_ids = torch.zeros((batch_size), dtype=torch.int32)
return adapter_ids
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = None
if not self._on_device_sampling_disabled:
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
top_k, top_p, temperature = (
self._convert_to_neuron_sampling_params(sampling_params))
sampling_params.top_k = top_k
sampling_params.top_p = top_p
sampling_params.temperature = temperature
# we need multi_modal_data for later tokens as well
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
multi_modal_kwargs_list.append(mm_data)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since the neuron worker doesn't support chunked
# prefill, just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs,
adapter_ids=lora_adapter_ids)
def remove_all_loras(self):
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def add_lora(self, lora_request: LoRARequest):
logger.warning(
"Adding LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config. If you supplied "
"the parameter, you can ignore this warning. Ignoring "
"lora request: %s", lora_request)
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")
def list_loras(self) -> Set[int]:
raise NotImplementedError(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config")

211
vllm/worker/pooling_model_runner.py Normal file
View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
ModelInputForGPUBuilder)
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the PoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class PoolingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
):
super().__init__(vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"PoolingModelRunner does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
# Pooling models are also (ab-)used to integrate non-text models that
# are not autoregressive (e.g. PrithviGeoSpatialMAE).
# These models might not use attention and do not really have a prefill
# and decode phase. The model input is processed in one shot and both
# decode_metadata and prefill_metadata would be None for such models.
# See the PlaceholderAttentionMetadata class.
# TODO: Figure out if cuda_graph is of any use for these models and
# explore how to leverage it.
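# A captured CUDA graph is looked up by (batch size, whether
# inputs_embeds are used); if no graph applies, fall back to eager
# execution of self.model.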
if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph):
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
cross_enc_kwargs = {}
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config,
virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
device=self.device,
),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Only perform pooling in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
return [
self.model.pooler(hidden_states=hidden_or_intermediate_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForGPUWithPoolingMetadata:
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata

909
vllm/worker/tpu_model_runner.py Normal file
View File

@@ -0,0 +1,909 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
# Here we rely on the behavior that out-of-bound indices are ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
n: List[int]
seq_groups: List[List[int]]
is_first_multi_step: bool = True
is_last_step: bool = True
virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
"n": self.n,
"seq_groups": self.seq_groups,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
self.block_size)
self.block_tables = np.zeros(
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
False,
)
self.cached_step_outputs: List[torch.Tensor] = []
smem_size = 512 * 1024
block_table_size = 4 * self.block_tables.size
if block_table_size >= smem_size:
logger.warning(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value, like %d.",
self.model_config.max_model_len,
self.model_config.max_model_len /
(block_table_size / smem_size))
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def get_model(self) -> nn.Module:
return self.model.model
def _dummy_run(
self,
batch_size: int,
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
exec_mode: ExecutionMode,
) -> None:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
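# Pad seq_len up to the next multiple of 16 (the Pallas FlashAttention
# kernel requires it), e.g. 100 -> 112.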
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(batch_size, self.max_num_blocks_per_seq),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
# Prefill
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(input_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
kv_caches)
def warmup_model(
self,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
# Prefill
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefill done in %.2f s.", end - start)
# Prefix prefill
if self.cache_config.enable_prefix_caching:
logger.info("Compiling the model with different input shapes for "
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefix prefill done in %.2f s.",
end - start)
# Decode
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
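# Compiled decode batch sizes: 8, 16, 32, 48, ... (doubled until 16,
# then stepped by 16), mirroring the padding in _get_padded_batch_size().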
while True:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info("Compilation for decode done in %.2f s.", end - start)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
slot_mapping: List[int] = []
for batch_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
# Could include output tokens when a request is preempted.
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
num_computed_tokens = num_computed_blocks * self.block_size
if num_computed_tokens > 0:
prompt_tokens = prompt_tokens[num_computed_tokens:]
context_lens.append(seq_len)
else:
context_lens.append(0)
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
input_positions.extend(range(num_computed_tokens, seq_len))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table
# Pad EACH prompt to the smallest power of 2 that is greater than
# or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:num_prefills],
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
return input_tokens, input_positions, attn_metadata, input_lens
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
n = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues.")
p.append(sampling_params.top_p)
if sampling_params.top_k > 0:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.n > _MAX_NUM_SAMPLES:
raise NotImplementedError(
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
n.append(sampling_params.n)
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
if sampling_params.prompt_logprobs is not None:
raise NotImplementedError(
"prompt_logprobs is not currently supported by the TPU "
"backend.")
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
n += [n[-1]] * (num_seqs - 1)
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, n
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, n = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, n, seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
return []
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=i == 0)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
return sampler_outputs
is_prompt = model_input.attn_metadata.num_prefills > 0
if is_prompt:
assert num_steps == 1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping = model_input.attn_metadata.slot_mapping
orig_block_tables = model_input.attn_metadata.block_tables
orig_context_lens = model_input.attn_metadata.context_lens
orig_effective_query_lens = \
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
next_token_ids = []
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
self.device)
position_ids = model_input.position_ids[None,
start_idx:end_idx].to(
self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
attn_metadata.block_tables = orig_block_tables[
i].unsqueeze(0).to(self.device)
attn_metadata.effective_query_lens = \
orig_effective_query_lens[i:i + 1].to(self.device)
else:
attn_metadata.context_lens = None
attn_metadata.block_tables = None
attn_metadata.effective_query_lens = None
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids = [
output_token_ids.cpu().tolist()
for output_token_ids in next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob = Logprob(0.0)
sampler_outputs = []
for i, seq_group in enumerate(model_input.seq_groups):
seq_ids = seq_group
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_outputs = []
for j in range(model_input.n[i]):
next_token_id = next_token_ids[i][j]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return [SamplerOutput(sampler_outputs)]
else:
token_ids = model_input.token_ids.to(self.device)
position_ids = model_input.position_ids.to(self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
self.device)
attn_metadata.block_tables = attn_metadata.block_tables.to(
self.device)
attn_metadata.context_lens = attn_metadata.context_lens.to(
self.device)
t = model_input.t.to(self.device)
p = model_input.p.to(self.device)
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)
if i < num_steps - 1:
# Prepare the inputs for the next step.
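# Recompute each slot from the advanced position:
# block = block_tables[pos // block_size],
# slot = block * block_size + pos % block_size.
# Padded entries keep _PAD_SLOT_ID so their writes stay out of bounds.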
token_ids = output_token_ids.unsqueeze(dim=1).int()
position_ids = position_ids + 1
attn_metadata.context_lens = attn_metadata.context_lens + 1
block_tables = attn_metadata.block_tables
block_number = block_tables.gather(
1,
position_ids.long() // self.block_size)
block_offset = position_ids % self.block_size
is_padding = slot_mapping == _PAD_SLOT_ID
slot_mapping = block_number * self.block_size + block_offset
slot_mapping = slot_mapping.long()
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
slot_mapping)
attn_metadata.slot_mapping = slot_mapping
if model_input.async_callback is not None:
model_input.async_callback()
if num_steps > 1:
return []
# Retrieve the outputs to CPU.
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
return [sampler_output]
class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
start_indices = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = start_indices + input_lens - 1
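# e.g. batch_size=2, seq_len=8, input_lens=[5, 3] ->
# start_indices=[0, 8], logits_indices=[4, 10]: the last real token of
# each sequence in the flattened [batch_size * seq_len] layout.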
attn_metadata = get_forward_context().attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata = SamplingMetadata(
seq_groups=[],
selected_token_indices=logits_indices,
categorized_sample_indices={},
num_prompts=attn_metadata.num_prefills,
)
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
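# Illustrative example: num_kv_heads=2, num_blocks=4, block_size=16 ->
# the per-head stride is 64, so a flat slot s expands to [s, s + 64]
# (one entry per KV head).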
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indices = torch.arange(0,
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indices *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indices.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
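# e.g. 1..16 -> 16, 17 -> 32, 100 -> 128.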
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
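# e.g. 3 -> 8, 9 -> 16, 20 -> 32.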
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
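# Keep the smallest set of tokens whose cumulative probability reaches p:
# 1) sort the logits in descending order and take the cumulative softmax,
# 2) locate the cutoff logit where the cumulative probability first
#    reaches p,
# 3) mask every logit strictly below the cutoff to -inf (ties at the
#    cutoff are kept).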
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
return logits
def _make_decode_output(
next_token_ids: List[int],
seq_groups: List[List[int]],
) -> SamplerOutput:
zero_logprob = Logprob(0.0)
sampler_outputs = []
batch_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group
seq_outputs = []
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(CompletionSequenceGroupOutput(
seq_outputs, None))
return SamplerOutput(sampler_outputs)

341
vllm/worker/tpu_worker.py Normal file
View File

@@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
assert self.device_config.device_type == "tpu"
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner: TPUModelRunner = TPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V0 TPU backend doesn't support LoRA serving")
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended. We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
self.profiler = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have one active profiler session per profiler
# server, so we only profile on rank 0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
def start_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.start_trace(self.profile_dir)
def stop_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.stop_trace()
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_layers = self.model_config.get_num_layers(self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
bind_kv_cache(self.compilation_config.static_forward_context,
[kv_caches])
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
kv_caches=kv_caches,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_bytes = get_dtype_size(self.cache_dtype)
block_size_bytes = (dtype_bytes * self.cache_config.block_size *
num_layers * 2 * head_size * num_kv_heads)
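# Illustrative sizing (hypothetical values): bf16 (2 B), block_size=16,
# 32 layers, 8 KV heads, head_size=128 ->
# 2 * 16 * 32 * 2 * 128 * 8 = 2 MiB per block.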
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to a multiple of 8.
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to a multiple of 8.
return num_tpu_blocks, num_cpu_blocks
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.block_size = self.cache_config.block_size
dtype = self.cache_dtype
num_layers = self.model_config.get_num_layers(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros(cpu_cache_shape,
dtype=dtype,
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
bind_kv_cache(self.compilation_config.static_forward_context,
[self.tpu_cache])
self._warmup_model()
def _warmup_model(self) -> None:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if not self.model_config.enforce_eager:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self.model_runner.warmup_model(self.tpu_cache)
def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size()
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
num_layers = self.model_config.get_num_layers(self.parallel_config)
key_cache_block = self.cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
def prepare_worker_input(
self,
execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
blocks_to_swap_in = _make_src_to_dst(
execute_model_req.blocks_to_swap_in, "cpu", self.device)
blocks_to_swap_out = _make_src_to_dst(
execute_model_req.blocks_to_swap_out, self.device, "cpu")
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
self.device, self.device)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
assert virtual_engine == 0
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
# Issue cache operations.
if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_in
if src_indices.numel() > 0:
# Swap from CPU to TPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_out
if src_indices.numel() > 0:
# Swap from TPU to CPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if worker_input.blocks_to_copy is not None:
src_indices, dst_indices = worker_input.blocks_to_copy
if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
(src_indices, dst_indices))
def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
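# Convert a list of (src_block, dst_block) pairs into a pair of index
# tensors on their respective devices,
# e.g. [(0, 3), (5, 7)] -> ([0, 5], [3, 7]).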
if not mapping:
return None
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
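    # e.g. mapping [(0, 3), (5, 7)] yields src_indices [0, 5] on src_device
    # and dst_indices [3, 7] on dst_device.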
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
@torch.compile(backend="openxla")
def _insert_kv(
k: torch.Tensor,
v: torch.Tensor,
indices: torch.Tensor,
tpu_k_cache: torch.Tensor,
tpu_v_cache: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
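    # Donating the TPU cache buffers lets the XLA-compiled update reuse their
    # storage for the result instead of allocating new cache tensors on every
    # swap-in.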
tpu_k_cache[:, indices] = k
tpu_v_cache[:, indices] = v

53
vllm/worker/utils.py Normal file
View File

@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''
Worker-related helper functions.
'''
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
from vllm.worker.model_runner import GPUModelRunnerBase
def assert_enc_dec_mr_supported_scenario(
enc_dec_mr: GPUModelRunnerBase) -> None:
'''
    Assert that the provided encoder/decoder model runner instance reflects
a supported scenario.
'''
# Reminder: Please update docs/features/compatibility_matrix.md
    # if the feature combo becomes valid
if enc_dec_mr.cache_config.enable_prefix_caching:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
if enc_dec_mr.sliding_window is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
None) is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
)
if enc_dec_mr.lora_config is not None:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
if enc_dec_mr.prompt_adapter_config is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])

588
vllm/worker/worker.py Normal file
View File

@@ -0,0 +1,588 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
memory_profiling)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ("medusa",
"mlp_speculator",
"eagle",
"deepseek_mtp",
"glm4_moe_mtp",
"mimo_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.runner_type == "pooling":
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
vllm_config=self.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as pooling models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
# Buffers saved before sleep
self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
print(
self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# Save the buffers before level 2 sleep
if level == 2:
model = self.model_runner.model
self._sleep_saved_buffers = {
name: buffer.cpu().clone()
for name, buffer in model.named_buffers()
}
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
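        # Level 1 offloads only the 'weights' pool to CPU and discards the rest
        # (e.g. the KV cache); level 2 discards the weights as well, which is
        # why the named buffers were snapshotted above.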
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags=tags)
# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.model_runner.model
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.baseline_snapshot = MemorySnapshot()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
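            # Loading the weights inside the 'weights' memory pool is what lets
            # sleep(level=1) later offload just the weights to CPU.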
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self.model_runner.load_model()
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_runner.save_sharded_state(
path,
pattern=pattern,
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.baseline_snapshot,
weights_memory=self.model_runner.model_memory_usage) as result:
self.model_runner.profile_run()
self._assert_memory_footprint_increased_during_profiling()
memory_for_current_instance = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
available_kv_cache_memory = (memory_for_current_instance -
result.non_kv_cache_memory)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
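        # e.g. on an 80 GiB GPU with gpu_memory_utilization=0.9 and roughly
        # 20 GiB of non-KV-cache memory, about 52 GiB remains for the KV cache,
        # and num_gpu_blocks is that amount divided by cache_block_size.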
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
"the current vLLM instance can use "
"total_gpu_memory "
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
logger.info(msg)
# Final cleanup
gc.collect()
return num_gpu_blocks, num_cpu_blocks
def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, total = torch.cuda.mem_get_info()
cuda_memory = total - free_gpu_memory
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
"Error in memory profiling. "
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
f"currently used memory {cuda_memory}. "
f"This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(
num_gpu_blocks, self.cache_config.block_size,
self.cache_config.is_attention_free,
self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
CacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.gpu_cache)
def _warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
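            # capture_model records CUDA graphs for the configured decode batch
            # sizes against the freshly allocated KV cache, so later decode
            # steps can be launched as a single graph replay.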
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return self.cache_engine
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
kvcache_slot_to_be_moved=execute_model_req.kvcache_slot_to_be_moved
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
        # Tree-style generation needs to move the KV cache to the correct positions.
if worker_input.kvcache_slot_to_be_moved is not None:
self.cache_engine[virtual_engine].move_caches(self.kv_cache[virtual_engine],
worker_input.kvcache_slot_to_be_moved)
def _get_cached_seq_group_metadata(
self,
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]],
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
        It is used because the scheduler only sends deltas to workers to reduce
        the data payload size. The function also cleans up the cache based on
        a given `finished_request_ids`.
"""
new_seq_group_metadata_list = []
for metadata_or_delta in seq_group_metadata_list:
request_id = metadata_or_delta.request_id
if request_id not in self._seq_group_metadata_cache:
# The first prefill.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[request_id] = metadata_or_delta
else:
# The first prefill is already cached.
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
self._seq_group_metadata_cache[request_id].apply_delta(
metadata_or_delta)
else:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
self._seq_group_metadata_cache[
request_id] = metadata_or_delta
new_seq_group_metadata_list.append(
self._seq_group_metadata_cache[request_id])
# Clean up finished ids
for finished_id in finished_request_ids:
del self._seq_group_metadata_cache[finished_id]
return new_seq_group_metadata_list
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Optional[List[SamplerOutput]]:
if execute_model_req is not None:
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
execute_model_req.seq_group_metadata_list,
execute_model_req.finished_requests_ids)
execute_model_req.seq_group_metadata_list = (
new_seq_group_metadata_list)
output = super()._execute_model_spmd(execute_model_req,
intermediate_tensors)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_runner.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.model_runner.list_prompt_adapters()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return CacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
max_model_len, pipeline_parallel_size) -> None:
if is_attention_free and num_gpu_blocks != 0:
raise ValueError("No memory should be allocated for the cache blocks "
f"for an attention-free model, but {num_gpu_blocks} "
"blocks are allocated.")
if not is_attention_free and num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
if not is_attention_free and max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")

706
vllm/worker/worker_base.py Normal file
View File

@@ -0,0 +1,706 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import numa
import time
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
import torch.nn as nn
from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config)
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
update_environment_variables,
warn_for_unimplemented_methods)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
logger = init_logger(__name__)
# Bind the current process to a NUMA node.
def bind_to_numa(local_rank):
env_str = f"VLLM_RANK{local_rank}_NUMA"
node_count = numa.get_max_node() + 1
numa_node = int(os.getenv(env_str, -1))
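    # e.g. (deployment assumption, not part of this file): exporting
    # VLLM_RANK0_NUMA=0 pins the rank-0 worker to NUMA node 0; an unset or
    # negative value skips binding.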
    # Skip binding if the env var is unset or misconfigured.
    # TODO: bind automatically based on the hardware topology.
if numa_node < 0:
logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %d", env_str, env_str, numa_node)
return
if numa_node > numa.get_max_node():
raise ValueError(f"NUMA node {numa_node} is not available.")
numa.bind([numa_node])
@warn_for_unimplemented_methods
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
model_input: Optional[ModelRunnerInputBase] = None
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
def get_model(self) -> nn.Module:
raise NotImplementedError
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> Set[int]:
raise NotImplementedError
# @property
# @abstractmethod
# def cache_engines(self) -> Optional[List[CacheEngine]]:
# raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
class DelegateWorkerBase(WorkerBase):
"""
A class that delegates all methods to another WorkerBase instance. This is
useful for creating a WorkerBase that wraps another WorkerBase instance,
e.g. speculative decoding.
"""
worker: WorkerBase
def __init__(
self,
*args,
**kwargs,
) -> None:
vllm_config: VllmConfig = kwargs.get("vllm_config")
cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
self.worker = cls(*args, **kwargs)
def init_device(self) -> None:
self.worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def load_model(self) -> None:
"""Load model onto target device."""
self.worker.load_model()
def get_model(self) -> nn.Module:
return self.worker.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
return self.worker.execute_model(execute_model_req)
def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def __getattr__(self, attr):
return getattr(self.worker, attr)
class LoRANotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
"""
def add_lora(self, lora_request: LoRARequest) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def pin_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
@property
def cache_engines(self) -> Optional[List[CacheEngine]]:
return None
@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
fields should be broadcastable to other workers.
"""
num_seq_groups: Optional[int] = None
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
num_steps: int = 1
    # Optional slot mapping of KV cache blocks pending to be moved, generated
    # by the draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
tensor_dict: Dict[str, Any],
) -> "WorkerInput":
"""
Pop fields from the given tensor_dict and populate a new instance of
WorkerInput.
"""
return cls(
num_seq_groups=tensor_dict.pop("num_seq_groups"),
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
kvcache_slot_to_be_moved=tensor_dict.pop("kvcache_slot_to_be_moved"),
)
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
"""
Extract broadcastable fields.
"""
tensor_dict = {
"num_seq_groups": self.num_seq_groups,
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
"kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved
}
return tensor_dict
class LocalOrDistributedWorkerBase(WorkerBase):
"""
Partial implementation of WorkerBase that has a default `execute_model`
definition to perform metadata transfer between workers when in distributed
mode. Subclasses of this interface should use model runners that inherit
from ModelRunnerBase, and should only need to implement worker-local logic.
If custom control plane logic is needed to transfer metadata, or if the
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
observability_config: Optional[ObservabilityConfig] = None
@property
@abstractmethod
def do_metadata_broadcast(self) -> bool:
"""
Used by the default `execute_model` to check whether broadcast is
needed to transfer request inputs from the driver worker to other
workers in the TP group. If WorkerBase subclass only supports
single-worker execution, then this method should return False.
"""
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
@abstractmethod
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
"""
Prepare the inputs to WorkerBase.execute_worker from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def execute_worker(self, worker_input: WorkerInput) -> None:
"""
Process an execution request.
"""
raise NotImplementedError
def _get_worker_input_from_broadcast(
self
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
model_input = (
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
kwargs = extract_previous_hidden_states(broadcast_data)
return model_input, worker_input, kwargs
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
""" Get the driver input and broadcast it to other workers. """
assert self.is_driver_worker
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
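        # With tree decoding enabled, the draft model supplies per-node position
        # ids and an attention mask over the token tree; these override the flat
        # positions and attention metadata prepared above.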
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None:
if hasattr(model_input, "input_positions") and \
hasattr(model_input, "attn_metadata") and \
hasattr(model_input.attn_metadata, "tree_attention_masks_tensor"):
attn_metadata = model_input.attn_metadata
attn_metadata.tree_attention_masks_tensor = execute_model_req.tree_attn_masks.contiguous()
model_input = dataclasses.replace(model_input,
input_positions=execute_model_req.tree_position_ids.contiguous(),
attn_metadata=attn_metadata)
kwargs = extract_previous_hidden_states(execute_model_req)
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)
if execute_model_req.async_callback:
model_input = dataclasses.replace( # type: ignore
model_input,
async_callback=execute_model_req.async_callback)
return model_input, worker_input, kwargs
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
return self._get_driver_input_and_broadcast(execute_model_req)
else:
return self._get_worker_input_from_broadcast()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
if execute_model_req is not None and execute_model_req.spec_step_idx:
kwargs["spec_step_idx"] = execute_model_req.spec_step_idx
self.model_input = model_input
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
intermediate_tensors = None
orig_model_execute_time = 0.0
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time
and output is not None):
for o in output:
o.model_execute_time = (orig_model_execute_time +
model_execute_time)
# output is List[SamplerOutput]
return output
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, (
"_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
kwargs = extract_previous_hidden_states(execute_model_req)
return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables` to configure
    the process environment, and the real initialization happens in
    `init_worker`.
"""
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self.rpc_rank = rpc_rank
self.worker: Optional[WorkerBase] = None
self.vllm_config: Optional[VllmConfig] = None
# do not store this `vllm_config`, `init_worker` will set the final
# one. TODO: investigate if we can remove this field in
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
# unnecessary now.
if vllm_config.model_config is not None:
# it can be None in tests
trust_remote_code = vllm_config.model_config.trust_remote_code
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
"""
Adjust the rpc_rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rpc_rank of workers after we create all workers.
"""
if self.rpc_rank in rank_mapping:
self.rpc_rank = rank_mapping[self.rpc_rank]
def update_environment_variables(self, envs_list: List[Dict[str,
str]]) -> None:
envs = envs_list[self.rpc_rank]
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config", None)
assert self.vllm_config is not None, (
"vllm_config is required to initialize the worker")
enable_trace_function_call_for_thread(self.vllm_config)
from vllm.plugins import load_general_plugins
load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
logger.warning(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
if self.vllm_config.parallel_config.worker_extension_cls:
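            # Sketch (hypothetical qualified name): setting worker_extension_cls
            # to "my_pkg.ext.WorkerExt" mixes WorkerExt into the worker's bases
            # so its public methods become reachable through collective_rpc.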
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
VLLM_NUMA_BIND = int(os.getenv("VLLM_NUMA_BIND", 1))
if VLLM_NUMA_BIND > 0:
            # Bind the current process to the configured NUMA node.
bind_to_numa(kwargs['local_rank'])
pid = os.getpid()
logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(kwargs['local_rank']), str(os.sched_getaffinity(pid)))
logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(kwargs['local_rank']), str(numa.get_membind()))
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def __getattr__(self, attr):
return getattr(self.worker, attr)
def extract_previous_hidden_states(
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
Dict[str, torch.Tensor]:
"""If data contains previous_hidden_states, extract it. This returns a dict
which can be used directly as additional kwargs in any following
execute_model calls. This is used in draft models like EAGLE."""
output = {}
# When called from non-driver worker, data is dict but when called from
# driver worker, data is ExecuteModelRequest.
if isinstance(data, dict):
if "previous_hidden_states" in data:
output["previous_hidden_states"] = data["previous_hidden_states"]
elif data.previous_hidden_states is not None:
output["previous_hidden_states"] = data.previous_hidden_states\
.hidden_states
return output

View File

@@ -0,0 +1,606 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
import torch
import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
    Used by the XPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForXPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForXPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
"""
    Used by the XPUModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
def __init__(self,
runner: "XPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def build(self) -> ModelInputForXPU:
is_prompt = self.seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) = self._prepare_prompt(
self.seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = None
multi_modal_kwargs = None
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
seq_lens=seq_lens,
query_lens=seq_lens,
)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_kwargs only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens)
tmp = [0]
tmp.extend(seq_lens)
seqlen = torch.tensor(tmp)
seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
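        # seqlen_q holds the cumulative start offset of each sequence's tokens,
        # e.g. seq_lens [3, 5] -> [0, 3, 8], the varlen (cu_seqlens-style)
        # layout consumed by the prompt attention kernel.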
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
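                # e.g. with block_size=16, the token at position 35 lives in
                # block_table[2] at offset 3, so slot = block_table[2] * 16 + 3.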
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
block_tables = make_tensor_with_pad(
block_tables,
pad=0,
dtype=torch.int,
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
block_tables=block_tables,
)
return (
input_tokens,
input_positions,
attn_metadata,
)
class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
ModelInputForXPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
# Lazy initialization.
        self.model: nn.Module  # Set after load_model
self.sampler = get_sampler()
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GiB",
self.model_memory_usage / GiB_bytes)
def get_model(self) -> nn.Module:
return self.model
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
        # Profile memory usage with max_num_seqs sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
        # the vLLM block manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
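            # Spread max_num_batched_tokens as evenly as possible across the
            # profiled sequences; the first (max_num_batched_tokens %
            # max_num_seqs) sequences get one extra token.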
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=None,
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq)
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, None, intermediate_tensors)
torch.xpu.synchronize()
return
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForXPUWithSamplingMetadata:
return (
ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder = self.builder
builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForXPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
model_executable = self.model
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()
# Compute the logits.
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_time = (model_forward_end_time -
model_forward_start_time)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = model_forward_time
return [output]

186
vllm/worker/xpu_worker.py Normal file
View File

@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(LoRANotSupportedWorkerBase, Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
self.model_runner = XPUModelRunner( # type: ignore
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None:
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
self.init_worker_distributed_environment()
# Initialize the model.
set_random_seed(self.model_config.seed)
    # This method is overridden so it can use the XPU-specific
    # `empty_cache` and `synchronize` APIs.
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
        The engine first profiles the existing memory usage, then calculates
        the maximum number of GPU and CPU blocks that can be allocated with
        the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.xpu.synchronize()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
free_gpu_memory = total_gpu_memory - used_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
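        # Illustrative arithmetic (hypothetical numbers): with 16 GiB total
        # memory, gpu_memory_utilization=0.9, a 4 GiB profiled peak and 2 MiB
        # per cache block, (16 * 0.9 - 4) GiB // 2 MiB ~= 5324 KV cache blocks.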
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
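        # CPU blocks back the configured swap space, used when preempted
        # sequences have their KV cache swapped out of device memory.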
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
gc.collect()
torch.xpu.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _warm_up_model(self) -> None:
        # IPEX does not support graph capture yet.
pass
def init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
            # Use sockets as the default Level Zero IPC exchange backend.
            # oneCCL otherwise defaults to the `drmfd` mechanism, which needs
            # extra dependencies (libdrm and DRM headers) on the system.
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(parallel_config.world_size))
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=self.local_rank,
backend="ccl")
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
if parallel_config.pipeline_parallel_size > 1:
# Add pp group init to avoid
# p2p communication as the first call
get_pp_group().all_reduce(torch.zeros(1).xpu())