forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
vllm/worker/__init__.py (new file, 0 lines)
vllm/worker/cache_engine.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CacheEngine class for managing the KV cache."""
from typing import List

import torch

from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                        get_dtype_size, is_pin_memory_available)
from vllm.attention.backends.tree_decoding_utils import move_cache

logger = init_logger(__name__)


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba, have mixed typed layers, E.g Mamba
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free,
                                             use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device."""
        kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
        try:
            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
            )
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))

        # The allocation respects the backend-defined stride order to ensure
        # the semantic remains consistent for each backend. We first obtain the
        # generic kv cache shape and then permute it according to the stride
        # order which could result in a non-contiguous tensor.
        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
                                          for i in kv_cache_stride_order)

        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            layer_kv_cache = torch.zeros(
                kv_cache_allocation_shape,
                dtype=self.dtype,
                pin_memory=pin_memory,
                device=device).permute(*kv_cache_stride_order)

            # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
            # when entry_shape is higher than 1D
            kv_cache.append(layer_kv_cache)
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        for i in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                          src_to_dst)

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        for i in range(self.num_attention_layers):
            self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                          src_to_dst)

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    def move_caches(self, kv_caches: List[torch.Tensor],
                    src_to_dsts: torch.Tensor) -> None:
        move_cache(self.attn_backend,
                   kv_caches,
                   src_to_dsts,
                   self.cache_config.cache_dtype,
                   self.num_kv_heads,
                   self.head_size)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        key_cache_entry = num_heads * head_size

        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
        total = num_attention_layers * cache_config.block_size * \
            (key_cache_entry + value_cache_entry)

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total
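Editor's note: a minimal sketch of the per-block byte count that `get_cache_block_size` computes, spelled out for a hypothetical Llama-style configuration. The numbers below (block size, KV heads, head size, layer count, dtype size) are assumptions for illustration, not values from this commit.

# Per-KV-block size, following the arithmetic in CacheEngine.get_cache_block_size.
block_size = 16               # tokens per KV block (assumed)
num_kv_heads = 8              # KV heads per tensor-parallel rank (assumed)
head_size = 128               # assumed
num_attention_layers = 32     # assumed
dtype_size = 2                # bytes per element for float16/bfloat16

key_cache_entry = num_kv_heads * head_size     # elements per token for K
value_cache_entry = key_cache_entry            # would be 0 if use_mla were True
total_elements = num_attention_layers * block_size * (key_cache_entry +
                                                      value_cache_entry)
block_bytes = dtype_size * total_elements
print(block_bytes)  # 2 * 32 * 16 * (1024 + 1024) = 2,097,152 bytes per block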
vllm/worker/cpu_enc_dec_model_runner.py (new file, 326 lines)
@@ -0,0 +1,326 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast

import torch

from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
                                          ModelInputForCPUBuilder,
                                          ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
    """
    Used by the EncoderDecoderModelRunner.
    """
    encoder_input_tokens: Optional[torch.Tensor] = None
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "encoder_input_tokens": self.encoder_input_tokens,
            "encoder_input_positions": self.encoder_input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInputForCPU":
        return cast(
            EncoderDecoderModelInputForCPU,
            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))


class CPUEncoderDecoderModelRunner(
        CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
    _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
        EncoderDecoderModelInputForCPU)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        return self._list_to_long_tensor([])

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str,
                                    Any]) -> EncoderDecoderModelInputForCPU:
        return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInputForCPU:
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                      model_input)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)
        return dataclasses.replace(
            model_input,
            sampling_metadata=sampling_metadata,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
            virtual_engine=virtual_engine,
        )

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInputForCPU,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following attn_metadata fields:
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Constructs a new model inputs data structure, based on
        (1) the existing fields in the `model_inputs` argument,
        and (2) the following additional fields which are
        computed (or in the case of `attn_metadata`, updated)
        by this function:
        * attn_metadata
        * encoder_input_tokens
        * encoder_input_positions

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
          compute inputs
        * model_inputs: model inputs data structure with decoder-oriented
          fields already computed.

        Return:

        * Updated model inputs data structure
        """

        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            (
                encoder_input_tokens,
                encoder_input_positions,
                cross_slot_mapping,
            ) = (
                [],
                [],
                [],
            )
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                for i in range(0, seq_len):
                    block_number = seq_group_metadata.cross_block_table[
                        i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(list(range(0, seq_len)))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()
            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                for _ in range(len(seq_group_metadata.seq_data)):
                    encoder_seq_lens.append(
                        seq_group_metadata.encoder_seq_data.get_len())
                cross_block_table = seq_group_metadata.cross_block_table
                cross_block_tables.append([] if (
                    cross_block_table is None) else cross_block_table)

            max_len_of_block_table = max(
                len(block_table) for block_table in cross_block_tables)

            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths & encoder
        # sequence starting offset tensors
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
                                            1,
                                            dtype=torch.int32,
                                            device=self.device)
        torch.cumsum(encoder_seq_lens_tensor,
                     dim=0,
                     dtype=encoder_seq_start_loc.dtype,
                     out=encoder_seq_start_loc[1:])

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInputForCPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "encoder_input_ids":
            model_input.encoder_input_tokens,
            "encoder_positions":
            model_input.encoder_input_positions,
            **MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ),
            "intermediate_tensors":
            intermediate_tensors,
        }

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]
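Editor's note: a minimal sketch of the slot-mapping arithmetic used in the prefill branch above, which maps an encoder token index to a slot in the paged cross-attention KV cache. The block size and block table values are assumptions for illustration, not taken from this commit.

block_size = 16
cross_block_table = [7, 3, 9]  # hypothetical physical block ids for one sequence

def token_index_to_slot(i: int) -> int:
    # Same formula as the loop in _prepare_encoder_model_input_tensors.
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    return block_number * block_size + block_offset

print(token_index_to_slot(0))   # 7 * 16 + 0 = 112
print(token_index_to_slot(17))  # 3 * 16 + 1 = 49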
vllm/worker/cpu_model_runner.py (new file, 671 lines)
@@ -0,0 +1,671 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
                    TypeVar, Union)

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
                             MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1


@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class contains metadata needed for the base model forward pass on CPU
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForCPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):

    class ModelInputData:

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.input_positions: List[int] = []
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            self.input_mrope_positions: List[List[int]] = [[]
                                                           for _ in range(3)]

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.runner = runner
        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.enable_lora = self.runner.lora_config is not None
        if self.runner.attn_backend is not None:
            # spec decode (e.g. Medusa) does not have atten backend
            attn_backend = self.runner.attn_backend
            self.att_metadata_builder = attn_backend.get_builder_cls()(self)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        self.att_metadata_builder.prepare()

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        input_positions = torch.tensor(
            input_data.input_positions
            if not any(input_data.input_mrope_positions) else
            input_data.input_mrope_positions,
            dtype=torch.long,
            device="cpu")
        token_type_ids = torch.tensor(input_data.token_type_ids,
                                      dtype=torch.long,
                                      device="cpu") \
            if input_data.token_type_ids else None

        # For multi-modal models
        multi_modal_kwargs = None
        if len(input_data.multi_modal_inputs_list) != 0:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        is_prompt = (self.seq_group_metadata_list[0].is_prompt
                     if self.seq_group_metadata_list else None)
        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(seq.lora_request
                                for seq in self.seq_group_metadata_list
                                if seq.lora_request is not None)

            lora_mapping = self._prepare_lora_input(
                self.seq_group_metadata_list, is_prompt)

        return self.model_input_cls(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    token_type_ids=token_type_ids,
                                    seq_lens=input_data.seq_lens,
                                    query_lens=input_data.query_lens,
                                    attn_metadata=attn_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs,
                                    lora_mapping=lora_mapping,
                                    lora_requests=lora_requests)

    def _build_input_data(self):
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For paged_attention kernel
        if self.runner.sliding_window:
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # For MRotaryEmbedding
        if seq_data.mrope_position_delta is not None:
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.
        """
        token_chunk_size = seq_group_metadata.token_chunk_size
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_len, context_len + token_chunk_size)

        # For prefix caching
        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
        if prefix_cache_block_num > 0:
            prefix_cache_len = (prefix_cache_block_num *
                                self.runner.block_size)
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
                token_chunk_size = seq_len - context_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1
                token_chunk_size = 1

        tokens = seq_data.get_token_ids()
        tokens = tokens[context_len:seq_len]
        token_positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot = block_number * block_size + block_offset
                slot_mapping[i] = slot
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        data.input_positions.extend(token_positions)

        if data.token_type_ids is not None:
            data.token_type_ids.extend(token_types if token_types else [])

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_kwargs only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_kwargs:
            return

        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
                                                  None)
            assert (
                image_grid_thw is not None or video_grid_thw is not None
                or audio_feature_lengths is not None), (
                    "mrope embedding type requires multi-modal input mapper "
                    "returns 'image_grid_thw' or 'video_grid_thw' or "
                    "'audio_feature_lengths'.")

            second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
            use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    hf_config=hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    context_len=computed_len,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )
            seq_data.mrope_position_delta = mrope_position_delta

            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)

    def _prepare_lora_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            is_prefill: bool) -> LoRAMapping:
        index_mapping = []
        prompt_mapping = []
        for seq in seq_group_metadata_list:
            lora_id = seq.lora_int_id
            query_len = seq.token_chunk_size

            index_mapping += [lora_id] * query_len
            prompt_mapping += [lora_id] * (
                query_len if seq.sampling_params
                and seq.sampling_params.prompt_logprobs is not None else 1)

        return LoRAMapping(index_mapping=tuple(index_mapping),
                           prompt_mapping=tuple(prompt_mapping),
                           is_prefill=is_prefill)


class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
    Helper class for shared methods between CPU model runners.
    """
    _model_input_cls: Type[TModelInputForCPU]
    _builder_cls: Type[ModelInputForCPUBuilder]
    builder: ModelInputForCPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        *args,
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = False

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.sampler = get_sampler()

        if hasattr(self, "_builder_cls"):
            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            assert supports_lora(
                self.model
            ), f"{self.model.__class__.__name__} does not support LoRA yet."

            if supports_multimodal(self.model):
                logger.warning("Regarding multimodal models, vLLM currently "
                               "only supports adding LoRA to language model.")

            # Use get_text_config() in case of multimodal models
            text_config = self.model_config.hf_config.get_text_config()

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=text_config.max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

    def get_model(self) -> nn.Module:
        return self.model

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForCPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        """
        self.builder.prepare(finished_requests_ids)
        self.builder.set_seq_group_list(seq_group_metadata_list)

        return self.builder.build()  # type: ignore

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    def remove_all_loras(self):
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> Set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()


class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine,
                                   is_prompt=is_prompt)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        model_executable = self.model

        multimodal_kwargs = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs,
                device=self.device,
            )
        execute_model_kwargs = {}
        if previous_hidden_states is not None:
            execute_model_kwargs.update(
                {"previous_hidden_states": previous_hidden_states})

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **execute_model_kwargs,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            if model_input.is_prompt:
                output.prefill_hidden_states = hidden_states
            output.hidden_states = hidden_states
        return [output]

    def generate_proposals(self, *args, **kwargs):
        return self.model.generate_proposals(*args, **kwargs)
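Editor's note: a minimal standalone sketch of the three prefix-cache cases handled in `_compute_prompt_input_tokens` above. The helper name and the numbers are assumptions for illustration (the real method mutates `context_len` and `token_chunk_size` in place rather than returning them).

def adjust_for_prefix_cache(context_len: int, seq_len: int,
                            prefix_cache_len: int) -> tuple:
    """Return (new_context_len, tokens_to_compute) for one prompt chunk."""
    if prefix_cache_len <= context_len:
        # Already past the cached region: compute the chunk as scheduled.
        return context_len, seq_len - context_len
    if context_len < prefix_cache_len < seq_len:
        # Partial hit: skip ahead to the end of the cached prefix.
        return prefix_cache_len, seq_len - prefix_cache_len
    # Full hit (seq_len <= prefix_cache_len): recompute only the last token.
    return seq_len - 1, 1

print(adjust_for_prefix_cache(context_len=0, seq_len=100, prefix_cache_len=64))  # (64, 36)
print(adjust_for_prefix_cache(context_len=0, seq_len=48, prefix_cache_len=64))   # (47, 1)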
vllm/worker/cpu_pooling_model_runner.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
                                          ModelInputForCPUBuilder)


@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
    """
    Used by the CPUPoolingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


class CPUPoolingModelRunner(
        CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
        ModelInputForCPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        cross_enc_kwargs = {}
        if model_input.token_type_ids is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            **MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ),
            **cross_enc_kwargs,
            "intermediate_tensors":
            intermediate_tensors,
        }

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForCPUWithPoolingMetadata:
        return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   virtual_engine=virtual_engine,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata
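Editor's note: a minimal sketch of the shape of the three inputs `_prepare_pooling` gathers per scheduled batch. Plain containers stand in for vLLM's SequenceData and PoolingParams objects; the sequence ids and lengths are assumptions for illustration.

# One (seq_ids, pooling_params) entry per sequence group in the batch.
seq_groups = [([0, 1], {"pooling_type": "MEAN"}),
              ([2], {"pooling_type": "MEAN"})]
# Flat map from sequence id to its sequence data, merged across all groups.
seq_data = {0: "SequenceData(...)", 1: "SequenceData(...)", 2: "SequenceData(...)"}
# One prompt length per sequence, in batch order (taken from model_input.seq_lens).
prompt_lens = [12, 12, 7]

assert len(prompt_lens) == len(seq_data)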
vllm/worker/cpu_worker.py (new file, 457 lines)
@@ -0,0 +1,457 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import os
from importlib import util
from typing import List, Optional, Set, Tuple, Type

import torch
import torch.distributed

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.worker.cache_engine import CacheEngine
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)


class CPUCacheEngine:
    """Manages the KV cache for CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
        # for CPU backend, because we want to reuse KV cache management
        # in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
                                                       model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            kv_cache.append(
                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_kv_cache_dtype(cache_config: CacheConfig,
                           model_config: ModelConfig):
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
            return torch.float8_e5m2
        else:
            raise NotImplementedError(f"Unsupported KV cache type "
                                      f"{cache_config.cache_dtype}.")

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block if not model_config.use_mla else 0
        total = num_layers * (key_cache_block + value_cache_block)
        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total


class CPUWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        vllm_config.parallel_config.rank = rank

        self.distributed_init_method = distributed_init_method

        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
            )
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

        # Return hidden states from target model if the draft model is an
        # mlp_speculator
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {} if speculative_config is None \
            or (speculative_config.draft_model_config.model ==
                model_config.model) \
            or (speculative_config.draft_model_config.hf_config.model_type
                not in ["medusa", "mlp_speculator", "eagle"]) \
            else {"return_hidden_states": True}
        ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
        if self.model_config.runner_type == "pooling":
            ModelRunnerClass = CPUPoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = CPUEncoderDecoderModelRunner
        self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # Initialize cpu_cache as pooling models don't initialize kv_caches
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        self.device = torch.device("cpu")
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache. Currently, swappable CPU memory is not
|
||||
supported.
|
||||
|
||||
Since this worker does not support GPUs, we use the num_gpu_blocks to
|
||||
determine how many non-swappable CPU blocks to allocate.
|
||||
"""
|
||||
assert (num_cpu_blocks == 0
|
||||
), f"{type(self)} does not support swappable cache"
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
num_cpu_blocks = num_gpu_blocks
|
||||
|
||||
self._validate_num_cpu_blocks(num_cpu_blocks)
|
||||
self.cache_config.num_gpu_blocks = num_cpu_blocks
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
|
||||
# Initialize the cache.
|
||||
self._init_cache_engine()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
|
||||
"""Raise errors if the num_cpu_blocks is invalid.
|
||||
"""
|
||||
if num_cpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
|
||||
"initializing the engine.")
|
||||
|
||||
max_seq_len = self.cache_config.block_size * num_cpu_blocks
|
||||
if self.model_config.max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
|
||||
def _init_cache_engine(self) -> None:
|
||||
self.cache_engine = [
|
||||
CPUCacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.cpu_cache = [
|
||||
self.cache_engine[ve].cpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
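        # Make the allocated caches visible to the attention layers through
        # the static forward context.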
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.cpu_cache)
|
||||
self.model_runner.block_size = self.cache_engine[0].block_size
|
||||
|
||||
assert all(
|
||||
self.cpu_cache[ve] is not None
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size))
|
||||
|
||||
# Populate the cache to warmup the memory
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size):
|
||||
for layer_cache in self.cpu_cache[ve]:
|
||||
layer_cache.fill_(0)
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.cpu_cache
|
||||
|
||||
@property
|
||||
def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_runner.vocab_size
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
||||
def execute_worker(
|
||||
self,
|
||||
worker_input: WorkerInput,
|
||||
) -> None:
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine[worker_input.virtual_engine].copy(
|
||||
worker_input.blocks_to_copy)
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
assert execute_model_req is not None
|
||||
virtual_engine: int = execute_model_req.virtual_engine
|
||||
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
assert len(execute_model_req.blocks_to_swap_in) == 0
|
||||
assert len(execute_model_req.blocks_to_swap_out) == 0
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
def init_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
||||
parallel_config = self.parallel_config
|
||||
rank = self.rank
|
||||
distributed_init_method = self.distributed_init_method
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cpu())
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Return the size in bytes of a single KV cache block.
|
||||
"""
|
||||
return CPUCacheEngine.get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
|
||||
    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return the CPU id binding based on NUMA nodes.
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # Check which NUMA nodes have CPUs in the allowed affinity list.
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(list(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed because the world size (%d) "
                    "is larger than the number of allowed NUMA nodes (%d). "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of the numa and psutil packages; "
                "falling back to no thread-binding. To get better "
                "performance, please try to bind threads manually.")
        return rank_to_cpus
|
||||
555
vllm/worker/enc_dec_model_runner.py
Normal file
555
vllm/worker/enc_dec_model_runner.py
Normal file
@@ -0,0 +1,555 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.attention.selector import (get_env_variable_attn_backend,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict)
|
||||
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
|
||||
|
||||
logger = init_logger(__name__)
|
||||
LORA_WARMUP_RANK = 8
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
||||
"""
|
||||
Used by the EncoderDecoderModelRunner.
|
||||
"""
|
||||
encoder_input_tokens: Optional[torch.Tensor] = None
|
||||
encoder_input_positions: Optional[torch.Tensor] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"encoder_input_tokens": self.encoder_input_tokens,
|
||||
"encoder_input_positions": self.encoder_input_positions,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||
"finished_requests_ids": self.finished_requests_ids,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "EncoderDecoderModelInput":
|
||||
return cast(
|
||||
EncoderDecoderModelInput,
|
||||
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
|
||||
|
||||
|
||||
class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
_model_input_cls: Type[EncoderDecoderModelInput] = (
|
||||
EncoderDecoderModelInput)
|
||||
_builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
'''
|
||||
EncoderDecoderModelRunner constructor.
|
||||
|
||||
`lora_config` and `prompt_adapter_config` are
|
||||
unused (since these features are not yet supported for encoder/decoder
|
||||
models) but these arguments are present here for compatibility with
|
||||
the base-class constructor.
|
||||
'''
|
||||
self._maybe_force_supported_attention_backend()
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
input_registry=input_registry,
|
||||
mm_registry=mm_registry,
|
||||
)
|
||||
|
||||
        # Crash for unsupported encoder/decoder scenarios
|
||||
assert_enc_dec_mr_supported_scenario(self)
|
||||
|
||||
def _maybe_force_supported_attention_backend(self):
|
||||
        '''
        Force vLLM to use an attention backend that supports
        encoder/decoder models (currently XFormers or FlashAttention).
        '''
|
||||
|
||||
def raise_backend_err():
|
||||
# The user has specified an attention backend override
|
||||
# which is invalid for encoder/decoder models
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
|
||||
|
||||
maybe_env_var_forced_backend = get_env_variable_attn_backend()
|
||||
maybe_global_forced_backend = get_global_forced_attn_backend()
|
||||
is_forced_by_global = maybe_global_forced_backend is not None
|
||||
is_forced_by_env_var = maybe_env_var_forced_backend is not None
|
||||
if is_forced_by_global: # noqa: SIM102
|
||||
# Backend override enforced by global variable takes
|
||||
# precedence over vLLM backend environment variable.
|
||||
if maybe_global_forced_backend not in\
|
||||
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
|
||||
raise_backend_err()
|
||||
elif is_forced_by_env_var: # noqa: SIM102
|
||||
# Backend override enforced by vLLM backend
|
||||
# environment variable
|
||||
if maybe_env_var_forced_backend not in\
|
||||
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
|
||||
raise_backend_err()
|
||||
|
||||
def _list_to_int32_tensor(
|
||||
self,
|
||||
_list: List[int],
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(_list, dtype=torch.int32, device=self.device)
|
||||
|
||||
def _list_to_long_tensor(
|
||||
self,
|
||||
_list: List[int],
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(_list, dtype=torch.long, device=self.device)
|
||||
|
||||
def _empty_int32_tensor(self) -> torch.Tensor:
|
||||
return self._list_to_int32_tensor([])
|
||||
|
||||
def _empty_long_tensor(self) -> torch.Tensor:
|
||||
return self._list_to_long_tensor([])
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: EncoderDecoderModelInput,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[PoolerOutput]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError("num_steps > 1 is not supported in "
|
||||
"EncoderDecoderModelRunner")
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.prefill_metadata is None
|
||||
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
|
||||
if model_input.inputs_embeds is None:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
device=self.device,
|
||||
),
|
||||
**seqlen_agnostic_kwargs,
|
||||
)
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
|
||||
return [output]
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
|
||||
return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> EncoderDecoderModelInput:
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
|
||||
Since chunked prefill is not supported for encoder/decoder models,
|
||||
`input_tokens` is assumed to be either entirely prefill tokens or
|
||||
entirely decode tokens.
|
||||
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
(
|
||||
attn_metadata,
|
||||
encoder_input_tokens_tensor,
|
||||
encoder_input_positions_tensor,
|
||||
) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
|
||||
model_input))
|
||||
# Inject attn_metadata encoder/cross-attention fields &
|
||||
# encoder input tokens/positions into model_input.
|
||||
# Frozen dataclass fields cannot be modified, so use
|
||||
# dataclasses.replace to construct a new model input
|
||||
# instance.
|
||||
model_input = dataclasses.replace(
|
||||
model_input,
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_input_tokens=encoder_input_tokens_tensor,
|
||||
encoder_input_positions=encoder_input_positions_tensor,
|
||||
)
|
||||
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
generators=generators)
|
||||
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||
if seq_group_metadata_list else None)
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prompt=is_prompt,
|
||||
virtual_engine=virtual_engine)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
# Enable top-k sampling to reflect the accurate memory usage.
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
# This represents the maximum number of different requests
|
||||
# that will have unique loras, and therefore the max amount of
|
||||
# memory consumption. Create dummy lora request copies from the
|
||||
# lora request passed in, which contains a lora from the lora
|
||||
# warmup path.
|
||||
dummy_lora_requests: List[LoRARequest] = []
|
||||
dummy_lora_requests_per_seq: List[LoRARequest] = []
|
||||
if self.lora_config:
|
||||
dummy_lora_requests = self._add_dummy_loras(
|
||||
self.lora_config.max_loras)
|
||||
assert len(dummy_lora_requests) == self.lora_config.max_loras
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the total
|
||||
# number of tokens equal to max_num_batched_tokens.
|
||||
seqs: List[SequenceGroupMetadata] = []
|
||||
|
||||
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
|
||||
self.model_config)
|
||||
if max_mm_tokens > 0:
|
||||
logger.info("Starting profile run for multi-modal models.")
|
||||
|
||||
batch_size = 0
|
||||
for group_id in range(max_num_seqs):
|
||||
seq_len = (max_num_batched_tokens // max_num_seqs +
|
||||
(group_id < max_num_batched_tokens % max_num_seqs))
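            # Spread max_num_batched_tokens as evenly as possible over the
            # dummy sequences; the first (tokens % seqs) groups get one extra
            # token.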
|
||||
batch_size += seq_len
|
||||
|
||||
decoder_dummy_data = self.input_registry \
|
||||
.dummy_data_for_profiling(self.model_config,
|
||||
seq_len,
|
||||
self.mm_registry,
|
||||
is_encoder_data=False)
|
||||
encoder_dummy_data = self.input_registry \
|
||||
.dummy_data_for_profiling(self.model_config,
|
||||
seq_len,
|
||||
self.mm_registry,
|
||||
is_encoder_data=True)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
assert len(
|
||||
decoder_dummy_data.seq_data.prompt_token_ids
|
||||
) >= seq_len, (
|
||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||
f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
|
||||
)
|
||||
|
||||
assert decoder_dummy_data.multi_modal_data is None or \
|
||||
encoder_dummy_data.multi_modal_data is None, (
|
||||
"Multi-modal data can't be provided in both encoder and decoder"
|
||||
)
|
||||
|
||||
seq = SequenceGroupMetadata(
|
||||
request_id=str(group_id),
|
||||
is_prompt=True,
|
||||
seq_data={group_id: decoder_dummy_data.seq_data},
|
||||
sampling_params=sampling_params,
|
||||
block_tables=None,
|
||||
encoder_seq_data=encoder_dummy_data.seq_data,
|
||||
cross_block_table=None,
|
||||
lora_request=dummy_lora_requests_per_seq[group_id]
|
||||
if dummy_lora_requests_per_seq else None,
|
||||
multi_modal_data=decoder_dummy_data.multi_modal_data
|
||||
or encoder_dummy_data.multi_modal_data,
|
||||
multi_modal_placeholders=decoder_dummy_data.
|
||||
multi_modal_placeholders
|
||||
or encoder_dummy_data.multi_modal_placeholders)
|
||||
seqs.append(seq)
|
||||
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
intermediate_tensors = None
|
||||
self.execute_model(model_input, None, intermediate_tensors)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
def _prepare_encoder_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
model_input: EncoderDecoderModelInput,
|
||||
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""Helper method to prepare the encoder- and cross-attn-related
|
||||
model inputs based on a given sequence group. These additional inputs
|
||||
are used to augment an already-computed `EncoderDecoderModelInput`
|
||||
data structure which already has decoder-related model inputs
|
||||
populated.
|
||||
|
||||
Sets the following attn_metadata fields:
|
||||
* `num_encoder_tokens`
|
||||
* `encoder_seq_lens`
|
||||
* `encoder_seq_lens_tensor`
|
||||
* `max_encoder_seq_len`
|
||||
* `cross_slot_mapping`
|
||||
* `cross_block_tables`
|
||||
|
||||
Constructs a new model inputs data structure, based on
|
||||
(1) the existing fields in the `model_inputs` argument,
|
||||
and (2) the following additional fields which are
|
||||
computed (or in the case of `attn_metadata`, updated)
|
||||
by this function:
|
||||
* attn_metadata
|
||||
* encoder_input_tokens
|
||||
* encoder_input_positions
|
||||
|
||||
Arguments:
|
||||
|
||||
* seq_group_metadata_list: list of sequence groups for which to
|
||||
compute inputs
|
||||
* model_inputs: model inputs data structure with decoder-oriented
|
||||
fields already computed.
|
||||
|
||||
Return:
|
||||
|
||||
* Updated model inputs data structure
|
||||
"""
|
||||
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
return (model_input.attn_metadata, None, None)
|
||||
|
||||
# Since we are not supporting chunked prefill either the entire
|
||||
# batch is prefill or it is decode
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
|
||||
# Build encoder inputs
|
||||
encoder_seq_lens: List[int] = []
|
||||
if is_prompt:
|
||||
# Prefill phase.
|
||||
cross_block_tables = self._empty_int32_tensor().view(
|
||||
len(seq_group_metadata_list), -1)
|
||||
|
||||
# Extract input tokens/positions, cross-attention slot-mapping,
|
||||
# & seq len from each sequence group metadata
|
||||
(
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
cross_slot_mapping,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Build seq lens
|
||||
seq_len = seq_group_metadata.encoder_seq_data.get_len()
|
||||
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
|
||||
encoder_seq_lens.append(seq_len)
|
||||
|
||||
# Build slot mapping
|
||||
is_profile_run = (seq_group_metadata.block_tables is None)
|
||||
if is_profile_run:
|
||||
# During memory profiling, the block tables are not
|
||||
# initialized yet. In this case, we just use a dummy
|
||||
# slot mapping.
|
||||
# In embeddings, the block tables are {seq_id: None}.
|
||||
cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
|
||||
else:
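                    # Map each encoder token to its physical KV cache slot:
                    # slot = block_number * block_size + offset_in_block.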
|
||||
for i in range(0, seq_len):
|
||||
block_number = seq_group_metadata.cross_block_table[
|
||||
i // self.block_size]
|
||||
block_offset = i % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
cross_slot_mapping.append(slot)
|
||||
|
||||
# Build encoder input tokens
|
||||
encoder_input_tokens.extend(token_ids)
|
||||
encoder_input_positions.extend(list(range(0, seq_len)))
|
||||
|
||||
# Convert tokens/positions & cross-attention
|
||||
# slot-mapping to encoder input tensors
|
||||
encoder_input_tokens_tensor = self._list_to_long_tensor(
|
||||
encoder_input_tokens)
|
||||
encoder_input_positions_tensor = self._list_to_long_tensor(
|
||||
encoder_input_positions)
|
||||
cross_slot_mapping_tensor = self._list_to_long_tensor(
|
||||
cross_slot_mapping)
|
||||
|
||||
else:
|
||||
# Decode phase.
|
||||
encoder_input_tokens_tensor = self._empty_long_tensor()
|
||||
encoder_input_positions_tensor = self._empty_long_tensor()
|
||||
cross_slot_mapping_tensor = self._empty_long_tensor()
|
||||
# Extract cross-attention block tables &
|
||||
# seq len from each sequence group metadata.
|
||||
# Cross-attention block tables are empty
|
||||
# during vLLM memory profiling.
|
||||
cross_block_tables = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for _ in range(len(seq_group_metadata.seq_data)):
|
||||
encoder_seq_lens.append(
|
||||
seq_group_metadata.encoder_seq_data.get_len())
|
||||
cross_block_table = seq_group_metadata.cross_block_table
|
||||
cross_block_tables.append([] if (
|
||||
cross_block_table is None) else cross_block_table)
|
||||
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.use_cuda_graph):
|
||||
# We will be using CUDA graph replay for this decode.
|
||||
max_len_of_block_table = self.get_max_block_per_batch()
|
||||
batch_size = len(encoder_seq_lens)
|
||||
graph_batch_size = self.vllm_config.pad_for_cudagraph(
|
||||
batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
cuda_graph_pad_size = graph_batch_size - batch_size
|
||||
# extend the cross_block_tables and encoder_seq_lens to match
|
||||
# the graph_batch_size.
|
||||
cross_block_tables.extend([[]
|
||||
for _ in range(cuda_graph_pad_size)
|
||||
])
|
||||
encoder_seq_lens.extend(
|
||||
itertools.repeat(1, cuda_graph_pad_size))
|
||||
|
||||
else:
|
||||
max_len_of_block_table = max(
|
||||
len(block_table) for block_table in cross_block_tables)
|
||||
|
||||
cross_block_tables = make_tensor_with_pad(
|
||||
cross_block_tables,
|
||||
max_len=max_len_of_block_table,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Compute encoder sequence lengths & encoder
|
||||
# sequence starting offset tensors
|
||||
max_encoder_seq_len = max(encoder_seq_lens, default=0)
|
||||
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
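        # encoder_seq_start_loc[i] holds the start offset of sequence i in the
        # packed encoder token tensor (exclusive prefix sum of the lengths).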
|
||||
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
|
||||
1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
torch.cumsum(encoder_seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=encoder_seq_start_loc.dtype,
|
||||
out=encoder_seq_start_loc[1:])
|
||||
|
||||
# Update attention metadata with encoder-oriented attributes
|
||||
attn_metadata = model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
(
|
||||
attn_metadata.num_encoder_tokens,
|
||||
attn_metadata.encoder_seq_lens,
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.cross_slot_mapping,
|
||||
attn_metadata.cross_block_tables,
|
||||
) = (
|
||||
sum(encoder_seq_lens),
|
||||
encoder_seq_lens,
|
||||
encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len,
|
||||
encoder_seq_start_loc,
|
||||
cross_slot_mapping_tensor,
|
||||
cross_block_tables,
|
||||
)
|
||||
|
||||
return (attn_metadata, encoder_input_tokens_tensor,
|
||||
encoder_input_positions_tensor)
|
||||
2320
vllm/worker/hpu_model_runner.py
Normal file
2320
vllm/worker/hpu_model_runner.py
Normal file
File diff suppressed because it is too large
484
vllm/worker/hpu_worker.py
Normal file
484
vllm/worker/hpu_worker.py
Normal file
@@ -0,0 +1,484 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
###############################################################################
|
||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
|
||||
###############################################################################
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
|
||||
import habana_frameworks.torch as htorch # noqa:F401
|
||||
import torch
|
||||
import torch.distributed
|
||||
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import bind_kv_cache
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.hpu_model_runner import HPUModelRunner
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a HPU.
|
||||
|
||||
Each worker is associated with a single HPU. The worker is responsible for
|
||||
maintaining the KV cache and executing the model on the HPU. In case of
|
||||
distributed inference, each worker is assigned a partition of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.is_driver_worker:
|
||||
assert self.rank == 0, "The driver worker must have rank 0."
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
self.model_runner: HPUModelRunner = HPUModelRunner(
|
||||
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: List[HPUCacheEngine]
|
||||
# Initialize gpu_cache as pooling models don't initialize kv_caches
|
||||
self.hpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.HPU,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
|
||||
def _set_env_vars(self):
|
||||
local_rank = self.local_rank
|
||||
if self.parallel_config.world_size == 1:
|
||||
local_rank = -1
|
||||
import os
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["ID"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size)
|
||||
os.environ["RANK"] = str(self.rank)
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "hpu":
|
||||
self.device = torch.device("hpu")
|
||||
torch.hpu.set_device(self.device)
|
||||
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
if self.model_config.quantization == 'inc':
|
||||
self._set_env_vars()
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
|
||||
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
|
||||
log_graph_compilation_all = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
|
||||
log_graph_compilation = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
|
||||
'0') != '0' or log_graph_compilation_all
|
||||
log_cpu_fallbacks_all = os.environ.get(
|
||||
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
|
||||
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
|
||||
'0') != '0' or log_cpu_fallbacks_all
|
||||
if (log_graph_compilation or log_cpu_fallbacks) and \
|
||||
execute_model_req is not None:
|
||||
from habana_frameworks.torch.hpu.metrics import metric_localcontext
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
is_prompt = any([
|
||||
seq_group_metadata.is_prompt
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
])
|
||||
            max_context_len = max([
                max([
                    len(v.prompt_token_ids) + len(v.output_token_ids)
                    for v in seq_group_metadata.seq_data.values()
                ]) for seq_group_metadata in seq_group_metadata_list
            ])  # longest context (prompt + generated tokens) in the batch
|
||||
max_num_blocks = (
|
||||
(max_context_len - 1) // self.cache_config.block_size) + 1
|
||||
input_stats = (f'is_prompt: {is_prompt}, '
|
||||
f'num_seqs: {len(seq_group_metadata_list)}, '
|
||||
f'max_context_len: {max_context_len}, '
|
||||
f'max_num_blocks {max_num_blocks}')
|
||||
gc_ctx = metric_localcontext(
|
||||
"graph_compilation"
|
||||
) if log_graph_compilation else contextlib.nullcontext()
|
||||
cpu_fallback_ctx = metric_localcontext(
|
||||
"cpu_fallback"
|
||||
) if log_cpu_fallbacks else contextlib.nullcontext()
|
||||
with gc_ctx as gc_local_metric, \
|
||||
cpu_fallback_ctx as cpu_fallback_local_metric:
|
||||
output = LocalOrDistributedWorkerBase.execute_model(
|
||||
self, execute_model_req)
|
||||
if (log_graph_compilation and gc_local_metric.stats()[0][1]
|
||||
> 0) or log_graph_compilation_all:
|
||||
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
|
||||
f"{gc_local_metric.stats()}, {input_stats}")
|
||||
logger.warning(msg)
|
||||
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
|
||||
> 0) or log_cpu_fallbacks_all:
|
||||
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
|
||||
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
|
||||
logger.warning(msg)
|
||||
|
||||
return output
|
||||
|
||||
output = LocalOrDistributedWorkerBase.execute_model(
|
||||
self, execute_model_req)
|
||||
return output
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
|
||||
        The engine first profiles the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.
|
||||
|
||||
Tip:
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with HabanaMemoryProfiler() as m:
|
||||
self.model_runner.profile_run()
|
||||
torch.hpu.synchronize()
|
||||
msg = ("Model profiling run "
|
||||
f"took {m.get_summary_string()}")
|
||||
logger.info(msg)
|
||||
        # At this point we should've allocated the maximum workspace for all
        # recipes; the remaining free memory can be used for graphs and blocks.
|
||||
free_hpu_memory = torch.hpu.mem_get_info()[0]
|
||||
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
graph_reserved_mem = (float(
|
||||
os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1'))
|
||||
if not self.model_config.enforce_eager else 0)
|
||||
graph_headroom = 1 - graph_reserved_mem
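        # Usable memory is split between HPU graphs and the KV cache:
        # graph_reserved_mem is kept for graph capture and the remaining
        # graph_headroom fraction goes to the KV cache below.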
|
||||
available_hpu_memory = free_hpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
hpu_memory_margin = free_hpu_memory * (
|
||||
1 - self.cache_config.gpu_memory_utilization)
|
||||
self.model_runner.mem_margin = hpu_memory_margin
|
||||
cache_size_bytes = available_hpu_memory * graph_headroom
|
||||
graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
|
||||
msg = (
|
||||
f"Free device memory: {format_bytes(free_hpu_memory)}, "
|
||||
f"{format_bytes(available_hpu_memory)} usable "
|
||||
f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
|
||||
f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
|
||||
f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
|
||||
f"{format_bytes(cache_size_bytes)} reserved for KV cache")
|
||||
logger.info(msg)
|
||||
num_hpu_blocks = int(cache_size_bytes // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_hpu_blocks = max(num_hpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
|
||||
|
||||
if self.model_runner.lora_manager:
|
||||
self.model_runner.remove_all_loras()
|
||||
|
||||
gc.collect()
|
||||
return num_hpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
||||
|
||||
This also warms up the model, which may record CUDA graphs.
|
||||
"""
|
||||
raise_if_cache_size_invalid(
|
||||
num_gpu_blocks, self.cache_config.block_size,
|
||||
self.model_config.max_model_len,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
with HabanaMemoryProfiler() as m:
|
||||
self._init_cache_engine()
|
||||
torch.hpu.synchronize()
|
||||
msg = ("Initializing cache engine "
|
||||
f"took {m.get_summary_string()}")
|
||||
logger.info(msg)
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
assert self.cache_config.num_gpu_blocks is not None
|
||||
self.cache_engine = [
|
||||
HPUCacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.hpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.hpu_cache)
|
||||
|
||||
def _warm_up_model(self) -> None:
|
||||
# NOTE(kzawora): We should use virtual engine index here
|
||||
# for pipeline parallelism. Using 0 for now.
|
||||
assert self.hpu_cache is not None
|
||||
self.model_runner.warmup_model(self.hpu_cache[0])
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def finish_measurements(self):
|
||||
self.model_runner.finish_measurements()
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.hpu_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||
# they contain parameters to launch cudamemcpyasync.
|
||||
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
# `blocks_to_copy` is a gpu tensor. The src and tgt of
|
||||
# blocks to copy are in the same device, and `blocks_to_copy`
|
||||
# can be used directly within cuda kernels.
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device=self.device,
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
and worker_input.blocks_to_swap_in.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_in(
|
||||
worker_input.blocks_to_swap_in)
|
||||
if (worker_input.blocks_to_swap_out is not None
|
||||
and worker_input.blocks_to_swap_out.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_out(
|
||||
worker_input.blocks_to_swap_out)
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for HPU backend.")
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for HPU backend.")
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for HPU backend.")
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"Prompt Adapter is not implemented for HPU backend.")
|
||||
|
||||
def shutdown_inc(self):
|
||||
self.model_runner.shutdown_inc()
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_runner.vocab_size
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Get the size of the KV cache block size in bytes.
|
||||
"""
|
||||
return HPUCacheEngine.get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
init_distributed_environment(parallel_config.world_size,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
local_rank,
|
||||
backend='hccl')
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch_world_size = torch.distributed.get_world_size()
|
||||
if torch_world_size != parallel_config.world_size:
|
||||
raise RuntimeError(
|
||||
"torch.distributed is already initialized but the torch world "
|
||||
"size does not match parallel_config.world_size "
|
||||
f"({torch_world_size} vs. {parallel_config.world_size}).")
|
||||
elif not distributed_init_method:
|
||||
raise ValueError(
|
||||
"distributed_init_method must be set if torch.distributed "
|
||||
"is not already initialized")
|
||||
else:
|
||||
torch.distributed.init_process_group(
|
||||
backend="hccl",
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup & checking conformance.
|
||||
dummy_tensor_hpu = torch.ones(1).to('hpu')
|
||||
torch.distributed.all_reduce(dummy_tensor_hpu)
|
||||
assert dummy_tensor_hpu.item() == parallel_config.world_size
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
|
||||
pipeline_parallel_size) -> None:
|
||||
if num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
|
||||
if max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
|
||||
|
||||
class HPUCacheEngine(CacheEngine):
|
||||
|
||||
def _allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
device: str,
|
||||
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Allocates KV cache on the specified device."""
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
for _ in range(self.num_attention_layers):
|
||||
key_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
value_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
kv_layer = (key_cache, value_cache)
|
||||
kv_cache.append(kv_layer)
|
||||
return kv_cache
|
||||
2215
vllm/worker/model_runner.py
Normal file
2215
vllm/worker/model_runner.py
Normal file
File diff suppressed because it is too large
282
vllm/worker/model_runner_base.py
Normal file
282
vllm/worker/model_runner_base.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar('T', bound="BroadcastableModelInput")
|
||||
|
||||
|
||||
def _add_attn_metadata_broadcastable_dict(
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_metadata: Optional["AttentionMetadata"]) -> None:
|
||||
"""
|
||||
Helper method to update tensor_dict with broadcastable
|
||||
AttentionMetadata fields.
|
||||
"""
|
||||
if attn_metadata is not None:
|
||||
tensor_dict.update(attn_metadata.asdict_zerocopy())
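        # asdict_zerocopy() returns the metadata fields as a dict without
        # copying the underlying tensors, so they can be broadcast cheaply.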
|
||||
|
||||
|
||||
def _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend: "AttentionBackend",
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to initialize AttentionMetadata based on an
|
||||
AttentionBackend and broadcastable AttentionMetadata fields.
|
||||
"""
|
||||
# Extract the fields used to create AttentionMetadata.
|
||||
valid_attn_kwargs = {}
|
||||
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
|
||||
if field.name in tensor_dict:
|
||||
if field.name == "input_positions":
|
||||
valid_attn_kwargs[field.name] = tensor_dict[field.name]
|
||||
else:
|
||||
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
|
||||
|
||||
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
|
||||
tensor_dict["attn_metadata"] = attn_metadata
|
||||
return tensor_dict
|
||||
|
||||
|
||||
def _init_sampling_metadata_from_tensor_dict( # type: ignore
|
||||
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to initialize SamplingMetadata based on broadcastable
|
||||
SamplingMetadata fields.
|
||||
"""
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
|
||||
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
|
||||
# An empty SamplingMetadata to signal that the worker should skip
|
||||
# sampling.
|
||||
if selected_token_indices is not None:
|
||||
tensor_dict["sampling_metadata"] = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
num_prompts=0,
|
||||
)
|
||||
return tensor_dict
|
||||
|
||||
|
||||
def _add_sampling_metadata_broadcastable_dict(
|
||||
tensor_dict: Dict[str, Any],
|
||||
sampling_metadata: Optional["SamplingMetadata"]) -> None:
|
||||
"""
|
||||
Helper method to update tensor_dict with broadcastable
|
||||
SamplingMetadata fields.
|
||||
"""
|
||||
if sampling_metadata is not None:
|
||||
tensor_dict["selected_token_indices"] = (
|
||||
sampling_metadata.selected_token_indices)
|
||||
|
||||
|
||||
def _init_frozen_model_input_from_tensor_dict(
|
||||
frozen_model_input_cls: Type["ModelRunnerInputBase"],
|
||||
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to initialize a frozen ModelInput based on broadcastable fields.
|
||||
"""
|
||||
valid_tensor_kwargs = {}
|
||||
for field in dataclasses.fields(frozen_model_input_cls):
|
||||
val = tensor_dict.pop(field.name, None)
|
||||
if val is not None:
|
||||
valid_tensor_kwargs[field.name] = val
|
||||
|
||||
frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
|
||||
tensor_dict["frozen_model_input"] = frozen_model_input
|
||||
return tensor_dict
|
||||
|
||||
|
||||
class BroadcastableModelInput(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract broadcastable fields. Override for fields that require some
|
||||
custom deserialization.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type[T],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> T:
|
||||
"""
|
||||
Pop fields from the given tensor_dict and populate a new instance of
|
||||
BroadcastableModelInput.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ModelRunnerInputBase(BroadcastableModelInput):
|
||||
"""Local inputs to each worker's model runner. May contain
|
||||
device-specific data. Different worker backends may have different methods
|
||||
of converting from the global ExecuteModelRequest produced by the LLM
|
||||
engine to the worker-local ModelRunnerInputBase objects.
|
||||
|
||||
Model runners that support multi-GPU execution should define a
|
||||
ModelRunnerInputBase subclass, add their required fields, and specify how to
|
||||
serialize/deserialize a ModelInput for broadcast between workers.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
|
||||
"""A builder to create ModelRunnerInputBase objects.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self,
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_seq_group(self, seq_group_metadata):
|
||||
"""TBA"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def build(self, *args, **kwargs) -> T:
|
||||
"""Build metadata with on-device tensors."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ModelRunnerBase(ABC, Generic[T]):
|
||||
"""
|
||||
Model runner interface that abstracts a particular hardware and/or type of
|
||||
model. Model execution may communicate data with model runners in other
|
||||
processes, but it should not include control plane metadata communication.
|
||||
|
||||
Each ModelRunnerBase subclass should define a corresponding
|
||||
ModelRunnerInputBase subclass.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
# Map of request_id -> generator used for seeded random sampling
|
||||
generators: Dict[str, torch.Generator] = {}
|
||||
|
||||
@abstractmethod
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> T:
|
||||
"""
|
||||
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
|
||||
dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
) -> T:
|
||||
"""
|
||||
Prepare the inputs to ModelRunnerBase.execute_model from an execution
|
||||
request. This method may move data to the worker's local device. It is
|
||||
not allowed to communicate with other workers or devices.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: T,
|
||||
kv_caches: Optional[List[torch.Tensor]],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
**kwargs,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""
|
||||
Execute the model on the given input.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
|
||||
"""
|
||||
Return dict of per-request generators used for random sampling.
|
||||
"""
|
||||
|
||||
# Clean up generators from completed requests
|
||||
if finished_request_ids:
|
||||
for request_id in finished_request_ids:
|
||||
self.generators.pop(request_id, None)
|
||||
|
||||
return self.generators
|
||||
|
||||
|
||||
class ModelRunnerWrapperBase:
|
||||
"""
|
||||
Delegates attribute access to the wrapped model_runner so that the
underlying model runner can be initialized lazily.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunnerBase,
|
||||
) -> None:
|
||||
self.model_runner: ModelRunnerBase = model_runner
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.model_runner, attr)
|
||||
|
||||
|
||||
class InputProcessingError(Exception):
|
||||
"""This exception is raised when an error occurs preparing the inputs for
|
||||
a single sequence group.
|
||||
This allows the engine to gracefully handle errors with a single sequence
|
||||
group without having to fail the entire batch.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id, message):
|
||||
"""request_id is the id of the offending sequence group"""
|
||||
self.request_id = request_id
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
def __str__(self):
|
||||
return "Failed to prepare inputs for sequence group with request id: " \
|
||||
f"{self.request_id}, Error: {self.message}"
|
||||
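# --- Illustrative sketch (not part of the original module) ---------------
# A hypothetical ModelRunnerInputBase subclass showing how the broadcast
# helpers above are meant to be used: as_broadcastable_tensor_dict() flattens
# the input into plain tensors/values, and from_broadcasted_tensor_dict()
# rebuilds it on the receiving worker. The class and field names below are
# assumptions for illustration only.
@dataclasses.dataclass(frozen=True)
class _ExampleModelInput(ModelRunnerInputBase):
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict: Dict[str, Any] = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        # Adds "selected_token_indices" when sampling metadata is present.
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "_ExampleModelInput":
        # Rebuild SamplingMetadata (and AttentionMetadata, if a backend is
        # given) from the broadcast fields, then construct the dataclass.
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)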
123
vllm/worker/multi_step_hpu_worker.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
###############################################################################
|
||||
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
|
||||
###############################################################################
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.hpu_model_runner import ModelInputForHPU
|
||||
from vllm.worker.hpu_worker import HPUWorker
|
||||
from vllm.worker.worker_base import WorkerInput
|
||||
|
||||
|
||||
class MultiStepHPUWorker(HPUWorker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cached_model_input: Optional[ModelInputForHPU] = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
assert self.is_driver_worker
|
||||
assert execute_model_req.virtual_engine == 0
|
||||
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
is_last_step = execute_model_req.is_last_step
|
||||
|
||||
if is_first_multi_step:
|
||||
# on first step we prepare the worker input and model input normally
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
worker_input = dataclasses.replace(
|
||||
worker_input,
|
||||
num_steps=execute_model_req.num_lookahead_slots + 1)
|
||||
model_input: ModelInputForHPU = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input = dataclasses.replace(
|
||||
model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
assert self.cached_model_input is not None
|
||||
model_input = self.cached_model_input
|
||||
worker_input = WorkerInput()
|
||||
|
||||
model_input = dataclasses.replace(
|
||||
model_input,
|
||||
is_first_multi_step=is_first_multi_step,
|
||||
is_last_step=is_last_step)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
if is_first_multi_step:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(
|
||||
model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
else:
|
||||
broadcast_data = {
|
||||
"is_first_multi_step": is_first_multi_step,
|
||||
"is_last_step": is_last_step,
|
||||
}
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
# Returning an empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there are no more requests to process for
|
||||
# now. All workers are running an infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
|
||||
execute_model_req)
|
||||
if model_input.is_first_multi_step:
|
||||
self.cached_model_input = model_input
|
||||
return model_input, worker_input, {}
|
||||
else:
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
|
||||
if len(broadcast_data) == 2:
|
||||
assert self.cached_model_input is not None
|
||||
self.cached_model_input = dataclasses.replace(
|
||||
self.cached_model_input,
|
||||
is_first_multi_step=broadcast_data["is_first_multi_step"],
|
||||
is_last_step=broadcast_data["is_last_step"])
|
||||
empty_worker_input = WorkerInput()
|
||||
return self.cached_model_input, empty_worker_input, {}
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(
|
||||
broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.
|
||||
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
|
||||
self.cached_model_input = model_input
|
||||
return model_input, worker_input, {}
|
||||
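# --- Illustrative sketch (not part of the original module) ---------------
# The driver above broadcasts one of two payload shapes: the full
# worker/model input on the first multi-step, and only the two step flags on
# later steps. This hypothetical helper shows the minimal follow-up payload,
# which is why non-driver workers test `len(broadcast_data) == 2`.
def _example_followup_broadcast_payload(
        is_last_step: bool) -> Dict[str, bool]:
    return {
        "is_first_multi_step": False,
        "is_last_step": is_last_step,
    }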
911
vllm/worker/multi_step_model_runner.py
Normal file
@@ -0,0 +1,911 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||
SamplerOutput,
|
||||
SamplingMetadata, get_logprobs,
|
||||
get_pythonized_sample_results)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
|
||||
_init_frozen_model_input_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
from ..model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MULTI_STEP_ATTENTION_BACKENDS = [
|
||||
"FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
|
||||
]
|
||||
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
|
||||
|
||||
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
|
||||
-> List[str]:
|
||||
if chunked_prefill_enabled:
|
||||
return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
|
||||
else:
|
||||
return MULTI_STEP_ATTENTION_BACKENDS
|
||||
|
||||
|
||||
def seq_output_builder():
|
||||
return SequenceOutput(
|
||||
0, 0,
|
||||
{0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
|
||||
|
||||
|
||||
def completion_seq_group_output_builder():
|
||||
return CompletionSequenceGroupOutput([], None)
|
||||
|
||||
|
||||
# Used by pythonization to reduce python object allocations
|
||||
class PythonizationCache:
|
||||
|
||||
def __init__(self):
|
||||
self.cached_seq_output = PyObjectCache(seq_output_builder)
|
||||
self.cached_completion_seq_group_output = PyObjectCache(
|
||||
completion_seq_group_output_builder)
|
||||
|
||||
def reset(self):
|
||||
self.cached_seq_output.reset()
|
||||
self.cached_completion_seq_group_output.reset()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
"""The output of a single model forward pass.
|
||||
|
||||
The sampler_output_ready_event is set when the tensors in
|
||||
sampler_output are ready (the model+sampler forward pass has
|
||||
completed). We use the event to synchronize the GPU->CPU transfer,
|
||||
which we want to only run when the data has been written to the
|
||||
GPU tensors. Until the event is ready, the tensors in sampler_output
|
||||
will have garbage data.
|
||||
|
||||
There are two scenarios:
|
||||
1. The output tensors are ready and we can pythonize them immediately.
|
||||
2. The output tensors are not ready and we need to wait for the event to be
|
||||
ready.
|
||||
"""
|
||||
sampler_output: SamplerOutput
|
||||
sampler_output_ready_event: torch.cuda.Event
|
||||
sampled_token_ids: Optional[torch.Tensor] = None
|
||||
pythonized: bool = False
|
||||
# On-device tensor containing the logprobs of each token.
|
||||
logprobs: Optional["torch.Tensor"] = None
|
||||
pythonization_cache: Optional[PythonizationCache] = None
|
||||
|
||||
def pythonize(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor) -> None:
|
||||
"""Pythonize the output. Blocking."""
|
||||
if not self.pythonized:
|
||||
self._pythonize_sampler_output(input_metadata, copy_stream,
|
||||
pinned_sampled_token_buffer, True)
|
||||
self.pythonized = True
|
||||
|
||||
def maybe_pythonize(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor) -> None:
|
||||
"""Pythonize the output if ready, else return None. Non-blocking."""
|
||||
if not self.pythonized:
|
||||
self.pythonized = self._pythonize_sampler_output(
|
||||
input_metadata, copy_stream, pinned_sampled_token_buffer,
|
||||
False)
|
||||
|
||||
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
blocking: bool) -> bool:
|
||||
"""
|
||||
If blocking is set, will block until the forward pass for the output is
|
||||
ready and pythonize the output. Upon completing Pythonization, erases
|
||||
self.logprobs (note that a non-blocking call performed while
|
||||
the sampler output is not yet ready will not erase self.logprobs.)
|
||||
"""
|
||||
assert self.sampled_token_ids is not None
|
||||
if not blocking and not self.sampler_output_ready_event.query():
|
||||
return False
|
||||
|
||||
if blocking:
|
||||
self.sampler_output_ready_event.synchronize()
|
||||
with torch.cuda.stream(copy_stream):
|
||||
_pythonize_sampler_output(input_metadata, self.sampler_output,
|
||||
pinned_sampled_token_buffer,
|
||||
self.sampled_token_ids, self.logprobs,
|
||||
self.pythonization_cache)
|
||||
|
||||
# Erase the logprobs GPU-side tensor.
|
||||
# Note that although _pythonize_sampler_output() runs in its
|
||||
# own CUDA stream, nonetheless _pythonize_sampler_output()
|
||||
# cannot return until Pythonization is complete; therefore
|
||||
# we know that by the time the CPU reaches this point,
|
||||
# `self.logprobs` is no longer needed.
|
||||
self.logprobs = None
|
||||
return True
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class StatefulModelInput(BroadcastableModelInput):
|
||||
# actual frozen model input dataclass passed to _base_model_runner
|
||||
frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
|
||||
|
||||
# list of model outputs for each step, may not be all pythonized
|
||||
cached_outputs: List[ModelOutput] = field(default_factory=list)
|
||||
|
||||
# used to pass sampled token ids from the last step to the current step for
|
||||
# TP workers. Used to append to end of outputs and used by advance_step
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
current_step: int = 0
|
||||
is_multi_step: bool = True
|
||||
is_last_step: bool = False
|
||||
is_first_multi_step: bool = False
|
||||
base_output_proc_callback: Optional[Callable] = None
|
||||
# ping-pong data structures for multi-step to wait on the previous step
|
||||
step_cuda_events: List[current_platform.Event] = field(
|
||||
default_factory=lambda: [current_platform.Event(blocking=True)] * 2)
|
||||
num_seqs: int = -1
|
||||
num_queries: int = -1
|
||||
num_single_step_prefills: int = 0
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
assert self.frozen_model_input is not None
|
||||
tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
|
||||
new_tensor_dict = {
|
||||
'last_sampled_token_ids': self.last_sampled_token_ids,
|
||||
'current_step': self.current_step,
|
||||
'is_multi_step': self.is_multi_step,
|
||||
'is_last_step': self.is_last_step,
|
||||
'is_first_multi_step': self.is_first_multi_step,
|
||||
'num_seqs': self.num_seqs,
|
||||
'num_queries': self.num_queries,
|
||||
'num_single_step_prefills': self.num_single_step_prefills,
|
||||
}
|
||||
tensor_dict.update(new_tensor_dict)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "StatefulModelInput":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
tensor_dict = _init_frozen_model_input_from_tensor_dict(
|
||||
ModelInputForGPUWithSamplingMetadata, tensor_dict)
|
||||
|
||||
return cls(**tensor_dict)
|
||||
|
||||
def record_step_event(self, current_stream: torch.cuda.Stream):
|
||||
# record the event for the current step so that the next step can sync
|
||||
# on it. We modulo by 2 to keep the events in a circular buffer and
|
||||
# support any attn backends that may be added in the future, e.g.
|
||||
# Flashinfer would want two DecodeWrappers to overlap the CPU and GPU.
|
||||
self.step_cuda_events[self.current_step & 1] = \
|
||||
torch.cuda.Event(blocking=True)
|
||||
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
||||
|
||||
def wait_previous_step(self):
|
||||
# These cuda events are an explicit synchronization to ensure that
|
||||
# advance_step() (for other attn backends that may be supported in the
|
||||
# future) do not clobber any data structures that are also used by any
|
||||
# enqueued forwards steps. For distributed case, only a single event is
|
||||
# needed, but for single GPU case, since we can let the CPU run much
|
||||
# further ahead, two events allow us to overlap the advance_step with
|
||||
# the previous forward (ie using two DecodeWrappers for flashinfer
|
||||
# backend)
|
||||
self.step_cuda_events[(self.current_step + 1) & 1].wait()
|
||||
|
||||
def add_sampler_output(self,
|
||||
sampler_output: SamplerOutput,
|
||||
sampled_token_ids: Optional[torch.Tensor] = None):
|
||||
self.cached_outputs.append(
|
||||
ModelOutput(sampler_output=sampler_output,
|
||||
sampler_output_ready_event=None,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
pythonized=False))
|
||||
|
||||
def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool):
|
||||
"""
|
||||
sampling_metadata.selected_token_indices is constructed for the
|
||||
first-step in Multi-Step. However, when chunked-prefill is enabled with
|
||||
multi-step, the scheduled prompts are fully processed in the
|
||||
first-step and are processed as decodes in the rest of the steps.
|
||||
This function updates the sampling_metadata.selected_token_indices
|
||||
to account for this conversion.
|
||||
|
||||
Example:
|
||||
Let 2 prompts and 2 decodes be scheduled together. Let the
|
||||
num-tokens to process for the 2 prompts be 5 and 8 respectively.
|
||||
|
||||
In that case, sampling_metadata.selected_token_indices will be
|
||||
[4, 12, 13, 14] as it is constructed for the first-step in
|
||||
multi-step.
|
||||
However, the prompts turn into decodes after the first-step
|
||||
and the num-tokens for the previously-prompt sequences will
|
||||
be 1 and 1 as they are decodes now. The selected_token_indices
|
||||
must be updated to [0, 1, 2, 3]; a worked sketch follows this class.
|
||||
"""
|
||||
assert self.current_step == 1 and self.num_single_step_prefills > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
return
|
||||
|
||||
assert self.frozen_model_input is not None
|
||||
assert self.frozen_model_input.sampling_metadata is not None
|
||||
self.frozen_model_input.sampling_metadata.selected_token_indices = \
|
||||
async_tensor_h2d(list(range(self.num_queries)),
|
||||
dtype=torch.long,
|
||||
target_device=device,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
|
||||
"""
|
||||
Advancing the datastructures of StatefulModelInput::frozen_model_input
|
||||
is only required when prefills are scheduled with decodes to run in
|
||||
multi-step. This advancement/correction is required to account for
|
||||
the conversion of Prefills to Decodes after the first multi-step.
|
||||
"""
|
||||
if self.current_step != 1 or self.num_single_step_prefills == 0:
|
||||
return
|
||||
|
||||
assert self.frozen_model_input is not None
|
||||
fmi = self.frozen_model_input
|
||||
|
||||
# Truncate input_tokens
|
||||
assert fmi.input_tokens is not None
|
||||
assert fmi.input_tokens.shape[0] >= self.num_seqs
|
||||
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
|
||||
|
||||
# Update frozen_model_input::input_positions.
|
||||
assert fmi.input_positions is not None
|
||||
assert fmi.input_positions.shape[0] >= self.num_seqs
|
||||
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
|
||||
num_seqs]
|
||||
|
||||
# Assert unsupported
|
||||
assert fmi.lora_mapping is None
|
||||
assert fmi.lora_requests is not None
|
||||
assert len(fmi.lora_requests) == 0
|
||||
assert fmi.attn_metadata is not None
|
||||
assert fmi.prompt_adapter_mapping is None
|
||||
assert fmi.prompt_adapter_requests is not None
|
||||
assert len(fmi.prompt_adapter_requests) == 0
|
||||
assert fmi.multi_modal_kwargs is not None
|
||||
assert len(fmi.multi_modal_kwargs) == 0
|
||||
|
||||
self.frozen_model_input = dataclasses.replace(
|
||||
self.frozen_model_input,
|
||||
input_tokens=fmi_new_input_tokens,
|
||||
input_positions=fmi_new_input_positions)
|
||||
|
||||
self.maybe_advance_sampling_metadata(device, pin_memory)
|
||||
|
||||
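# --- Illustrative sketch (not part of the original module) ---------------
# Worked version of the example in maybe_advance_sampling_metadata() above:
# two prompts of 5 and 8 scheduled tokens plus two decodes sample at the last
# token of each sequence on the first step ([4, 12, 13, 14]); once the
# prompts have turned into decodes, every sequence contributes one token and
# the indices collapse to range(num_queries), i.e. [0, 1, 2, 3].
def _example_selected_token_indices(tokens_per_seq=(5, 8, 1, 1)):
    first_step = []
    offset = 0
    for num_tokens in tokens_per_seq:
        offset += num_tokens
        first_step.append(offset - 1)   # last token index of this sequence
    later_steps = list(range(len(tokens_per_seq)))
    return first_step, later_steps


assert _example_selected_token_indices() == ([4, 12, 13, 14], [0, 1, 2, 3])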
|
||||
# StatefulModelInput is not a subclass of
|
||||
# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
|
||||
# metadata
|
||||
# mypy: disable-error-code=type-var
|
||||
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
# mypy: enable-error-code=type-var
|
||||
|
||||
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Check attention backend support.
|
||||
supported_attention_backends: List[str] = \
|
||||
_get_supported_attention_backends(
|
||||
self.scheduler_config.chunked_prefill_enabled)
|
||||
if self.attn_backend.get_name() not in supported_attention_backends:
|
||||
ms_config_str: str = "Multi-Step + Chunked-Prefill" \
|
||||
if self.scheduler_config.chunked_prefill_enabled \
|
||||
else "Multi-Step"
|
||||
raise ValueError(
|
||||
f"{ms_config_str} not supported for attention backend: "
|
||||
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
|
||||
f"to a value from {supported_attention_backends}.")
|
||||
|
||||
# uses the base model runner to execute the model and wraps it with
|
||||
# multi-step logic
|
||||
self._base_model_runner: GPUModelRunnerBase = base_model_runner
|
||||
|
||||
self.is_multi_step = self.scheduler_config.is_multi_step
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceOutput and CompletionSequenceGroupOutput objects.
|
||||
# When cache-reset happens at the last step of a multi-step
|
||||
# execution, there may be other on-going single-step/multi-step
|
||||
# executions. The current caching implementation does not check
|
||||
# for this.
|
||||
self.pythonization_cache = PythonizationCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
@functools.cached_property
|
||||
def _copy_stream(self):
|
||||
# used to copy tensors from GPU to CPU asynchronously
|
||||
return torch.cuda.Stream()
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
|
||||
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
return model_input
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> StatefulModelInput:
|
||||
frozen_model_input: ModelInputForGPUWithSamplingMetadata = \
|
||||
self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list,
|
||||
virtual_engine,
|
||||
finished_requests_ids)
|
||||
|
||||
assert frozen_model_input.query_lens is not None
|
||||
assert frozen_model_input.seq_lens is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
num_queries = len(frozen_model_input.query_lens)
|
||||
num_seqs = len(frozen_model_input.seq_lens)
|
||||
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
||||
|
||||
model_input = StatefulModelInput(
|
||||
frozen_model_input=frozen_model_input,
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
num_single_step_prefills=num_single_step_prefills)
|
||||
|
||||
return model_input
|
||||
|
||||
def _async_process_outputs(self, model_input: StatefulModelInput,
|
||||
output_proc_callback: Callable):
|
||||
# Proceed with pythonization and output_proc in order.
|
||||
# Stop on the first one that fails to pythonize
|
||||
output_proc_callback()
|
||||
|
||||
cont = True
|
||||
for step_num, model_output in enumerate(model_input.cached_outputs):
|
||||
if not model_output.pythonized:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
if model_output.pythonized:
|
||||
ctx = output_proc_callback.keywords["ctx"]
|
||||
ctx.append_output(
|
||||
outputs=[model_output.sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
|
||||
output_proc_callback()
|
||||
else:
|
||||
cont = False
|
||||
|
||||
if not cont:
|
||||
break
|
||||
|
||||
def _final_process_outputs(
|
||||
self, model_input: StatefulModelInput,
|
||||
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
has_async_callback = output_proc_callback is not None
|
||||
|
||||
outputs = []
|
||||
for step_num, output in enumerate(model_input.cached_outputs):
|
||||
is_last_step = step_num == len(model_input.cached_outputs) - 1
|
||||
|
||||
# For non-async case:
|
||||
# -- We simply add the outputs
|
||||
# For async case:
|
||||
# -- Invoke callback, pythonize, add to callback queue and repeat
|
||||
# -- For last output, just add to callback queue
|
||||
if has_async_callback:
|
||||
assert output_proc_callback is not None
|
||||
|
||||
# Invoke callback before pythonize (to overlap with GPU)
|
||||
output_proc_callback()
|
||||
|
||||
# Pythonize
|
||||
if not output.pythonized:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
# For non last step, add to callback queue to chain
|
||||
# callbacks=>pythonize pairs (for GPU overlap)
|
||||
if not is_last_step:
|
||||
ctx = output_proc_callback.keywords[ # type: ignore
|
||||
"ctx"] # type: ignore
|
||||
ctx.append_output(
|
||||
outputs=[output.sampler_output],
|
||||
seq_group_metadata_list=ctx.
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
outputs.append(output.sampler_output)
|
||||
|
||||
return outputs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: StatefulModelInput,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
"""
|
||||
Execute the model for a single step and update multi-step
|
||||
metadata
|
||||
"""
|
||||
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# path for warm up runs
|
||||
if not model_input.is_multi_step:
|
||||
return self._base_model_runner.execute_model(
|
||||
frozen_model_input, None, intermediate_tensors, num_steps)
|
||||
|
||||
# make sure we skip the sampler on the last rank and only pythonize
|
||||
# if CPU is ahead.
|
||||
if self.is_driver_worker and get_pp_group().is_last_rank:
|
||||
if self.pinned_sampled_token_ids is None:
|
||||
self.pinned_sampled_token_ids = torch.zeros(
|
||||
(self.scheduler_config.max_num_seqs, 1),
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
self._base_model_runner.sampler.include_gpu_probs_tensor = True
|
||||
if frozen_model_input.sampling_metadata:
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
True)
|
||||
|
||||
# some pre-execute model logic for multi-step:
|
||||
# - if it's the first step, we need to reset the sampling tensors
|
||||
# - if it's not the first step, we need to advance the step using the
|
||||
# appended sampler output from last iteration
|
||||
# - also maybe pythonize if CPU is ahead of GPU
|
||||
|
||||
stream = current_stream()
|
||||
if not model_input.is_first_multi_step:
|
||||
# Explicitly block on the previous step's forward to make sure we
|
||||
# don't clobber any GPU tensors still in use.
|
||||
# This is not needed for flashattn backend, but for other attn
|
||||
# backends such as flashinfer that performs extra CPU operations on
|
||||
# input metadata we may need to synchronize any CPU operations that
|
||||
# might clobber enqueued forwards. (prevents CPU from running too
|
||||
# far ahead if needed)
|
||||
model_input.wait_previous_step()
|
||||
model_input = self._advance_step(
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
# frozen_model_input may have been updated
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
if model_input.base_output_proc_callback is None:
|
||||
assert frozen_model_input is not None
|
||||
model_input.base_output_proc_callback = \
|
||||
frozen_model_input.async_callback
|
||||
|
||||
if frozen_model_input.async_callback is not None:
|
||||
assert model_input.base_output_proc_callback is not None
|
||||
async_callback = functools.partial(
|
||||
self._async_process_outputs,
|
||||
model_input=model_input,
|
||||
output_proc_callback=model_input.base_output_proc_callback)
|
||||
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=async_callback)
|
||||
# Update the local instance
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# Execute the model
|
||||
output = self._base_model_runner.execute_model(frozen_model_input,
|
||||
None,
|
||||
intermediate_tensors,
|
||||
num_steps=1)
|
||||
|
||||
# record the event for the current step so that the next step can sync
|
||||
model_input.record_step_event(stream)
|
||||
|
||||
if get_pp_group().is_last_rank and self.is_driver_worker:
|
||||
assert isinstance(output, list)
|
||||
assert len(
|
||||
output
|
||||
) == 1, "MultiStepModelRunner requires single-step base_models"
|
||||
|
||||
# event for the pythonization so that we only pythonize if the
|
||||
# tensors are ready. May be able to be combined with the step event
|
||||
output_ready_event = torch.cuda.Event()
|
||||
output_ready_event.record(stream)
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
output[0].sampled_token_ids_cpu = output[
|
||||
0].sampled_token_ids.cpu()
|
||||
model_input.cached_outputs.append(
|
||||
ModelOutput(output[0], output_ready_event,
|
||||
output[0].sampled_token_ids, False,
|
||||
output[0].logprobs, self.pythonization_cache))
|
||||
|
||||
# These GPU tensors are not required by multi-step;
|
||||
# erase them to ensure they are not pythonized or
|
||||
# transferred to CPU
|
||||
output[0].sampled_token_ids = None
|
||||
output[0].sampled_token_probs = None
|
||||
output[0].logprobs = None
|
||||
|
||||
# Pythonize the output if CPU is ahead and the previous step is
|
||||
# ready.
|
||||
if frozen_model_input.async_callback is None:
|
||||
for model_output in model_input.cached_outputs:
|
||||
model_output.maybe_pythonize(model_input,
|
||||
self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
model_input.current_step += 1
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Should be IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
return output
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Pythonize the output and block if needed since it is the last step
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(
|
||||
model_input, model_input.base_output_proc_callback)
|
||||
if self.pythonization_cache:
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
return output
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
|
||||
num_seqs: Optional[int], num_queries: int):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
assert seq_group.seq_len is None # Decode
|
||||
assert seq_group.query_len is None # Decode
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInput,
|
||||
out: SamplerOutput) -> StatefulModelInput:
|
||||
|
||||
model_input.maybe_advance_frozen_model_input(self.device,
|
||||
self.pin_memory)
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.input_tokens is not None
|
||||
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
|
||||
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
|
||||
num_seqs = model_input.num_seqs
|
||||
num_queries = model_input.num_queries
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
|
||||
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
|
||||
model_input.num_single_step_prefills != 0
|
||||
attn_metadata.advance_step(
|
||||
frozen_model_input,
|
||||
sampled_token_ids,
|
||||
self.block_size,
|
||||
num_seqs,
|
||||
num_queries,
|
||||
turn_prefills_into_decodes=turn_prefills_into_decodes)
|
||||
|
||||
return model_input
|
||||
|
||||
def load_model(self) -> None:
|
||||
self._base_model_runner.load_model()
|
||||
self.model_memory_usage = self._base_model_runner.model_memory_usage
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
return self._base_model_runner.save_sharded_state(
|
||||
path, pattern, max_size)
|
||||
|
||||
def save_tensorized_model(self,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
return self._base_model_runner.save_tensorized_model(tensorizer_config)
|
||||
|
||||
def profile_run(self) -> None:
|
||||
return self._base_model_runner.profile_run()
|
||||
|
||||
def remove_all_loras(self):
|
||||
return self._base_model_runner.remove_all_loras()
|
||||
|
||||
def capture_model(self, kv_caches: List[List]) -> None:
|
||||
return self._base_model_runner.capture_model(kv_caches)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._base_model_runner.vocab_size
|
||||
|
||||
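# --- Illustrative sketch (not part of the original module) ---------------
# The two-slot "ping-pong" event indexing used by StatefulModelInput above:
# step N records its event into slot N & 1 and waits on slot (N + 1) & 1,
# i.e. the event recorded by step N - 1, so recording never overwrites the
# event that the current step still has to wait on.
def _example_pingpong_slots(num_steps: int = 4) -> None:
    for step in range(num_steps):
        wait_slot = (step + 1) & 1      # event recorded by the previous step
        record_slot = step & 1          # event for the next step to wait on
        print(f"step {step}: wait on slot {wait_slot}, "
              f"record into slot {record_slot}")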
|
||||
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
|
||||
Optional[List[SampleLogprobs]]]
|
||||
|
||||
|
||||
def deferred_pythonize_logprobs(
|
||||
output: SamplerOutput,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
) -> DeferredLogprobsReturnType:
|
||||
"""Perform deferred logprob Pythonization.
|
||||
|
||||
1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
|
||||
2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
|
||||
utilizing the Pythonized sampler result computed in step 1.
|
||||
|
||||
These deferred computations are not required for single-step scheduling
|
||||
or the `profile_run()` phase of multi-step scheduling.
|
||||
|
||||
Args:
|
||||
output: sampler output (under deferred Pythonization)
|
||||
sampling_metadata
|
||||
|
||||
Returns:
|
||||
prompt_logprobs (CPU), sample_logprobs (CPU)
|
||||
"""
|
||||
|
||||
# - Deferred pythonization of sample result
|
||||
sampler_result = get_pythonized_sample_results(
|
||||
output.deferred_sample_results_args)
|
||||
|
||||
# - Erase the GPU-side deferred sample_result
|
||||
# computation args to ensure it is never
|
||||
# pythonized or transferred to CPU
|
||||
output.deferred_sample_results_args = None
|
||||
|
||||
# - Deferred pythonization of logprobs
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
|
||||
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
|
||||
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
|
||||
|
||||
return prompt_logprobs, sample_logprobs
|
||||
|
||||
|
||||
def _pythonize_sampler_output(
|
||||
model_input: StatefulModelInput,
|
||||
output: SamplerOutput,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
cache: Optional[PythonizationCache],
|
||||
) -> None:
|
||||
""" This function is only called when the output tensors are ready.
|
||||
See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput].
|
||||
|
||||
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
|
||||
adding a Pythonized output data structure
|
||||
([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput])
|
||||
for each [`SequenceGroup`][vllm.sequence.SequenceGroup].
|
||||
|
||||
Args:
|
||||
model_input
|
||||
output: sampler output
|
||||
pinned_sampled_token_buffer: CPU-side pinned memory
|
||||
(receives copy of
|
||||
GPU-side token buffer.)
|
||||
sampled_token_ids: GPU-side token buffer
|
||||
logprobs_tensor: GPU-side tensor containing
|
||||
logprobs computed during sampling
|
||||
"""
|
||||
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input.sampling_metadata is not None
|
||||
sampling_metadata = frozen_model_input.sampling_metadata
|
||||
# samples generation should have been skipped
|
||||
assert not output.outputs
|
||||
|
||||
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
|
||||
|
||||
# We guarantee output tensors are ready, so it is safe to
|
||||
# pythonize the sampler output & obtain CPU-side logprobs.
|
||||
#
|
||||
# However we should check whether logprobs pythonization may
|
||||
# be skipped entirely, i.e. because no logprobs were requested
|
||||
# or pythonization was not deferred. To that end,
|
||||
#
|
||||
# * `prompt_logprobs_are_requested_for_prefill` signals that
|
||||
# there are *any* prefill-phase requests which specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# * `any_logprobs_are_requested` signals that there are any
|
||||
# requests which (1) specify that sample logprobs should be
|
||||
# returned, or (2) are in the prefill phase AND specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# Later on, these flags cause adjustments to the pythonization
|
||||
# process to accommodate logprobs.
|
||||
|
||||
seq_groups = sampling_metadata.seq_groups
|
||||
prompt_logprobs_are_requested_for_prefill = any([
|
||||
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
|
||||
for sg in seq_groups
|
||||
])
|
||||
any_logprobs_are_requested = (
|
||||
prompt_logprobs_are_requested_for_prefill
|
||||
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
|
||||
|
||||
if prompt_logprobs_are_requested_for_prefill:
|
||||
# CPU GPU sync, after gathering *only* sampled tokens (since
|
||||
# requesting prompt logprobs leads `sampled_token_ids` to
|
||||
# include prompt token ids in addition to sampled token ids.)
|
||||
sample_idx_tensor = torch.tensor(
|
||||
[sdx for sg in seq_groups for sdx in sg.sample_indices])
|
||||
pinned_buffer = pinned_buffer.copy_(
|
||||
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
|
||||
else:
|
||||
# CPU GPU sync
|
||||
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
|
||||
non_blocking=False)
|
||||
|
||||
# this will not block as the tensors are already on CPU
|
||||
samples_list = pinned_buffer.tolist()
|
||||
|
||||
skip_sampler_cpu_output = (
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
|
||||
|
||||
# *Don't* skip logprobs pythonization *if*:
|
||||
# * Any requests require logprobs to be returned in this
|
||||
# iteration AND
|
||||
# * These requests are being scheduled in a fashion which
|
||||
# defers pythonization (i.e. multi-step scheduling.)
|
||||
do_pythonize_logprobs = (skip_sampler_cpu_output
|
||||
and any_logprobs_are_requested)
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = (deferred_pythonize_logprobs(output, sampling_metadata,
|
||||
logprobs_tensor)
|
||||
if do_pythonize_logprobs else (None, None))
|
||||
|
||||
for sgdx, (seq_group,
|
||||
sample_result) in enumerate(zip(seq_groups, samples_list)):
|
||||
# Reminder: Please update docs/features/compatibility_matrix.md
|
||||
# if the feature combo becomes valid
|
||||
# (Check for Guided Decoding)
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
|
||||
if do_pythonize_logprobs:
|
||||
assert prompt_logprobs is not None
|
||||
assert sample_logprobs is not None
|
||||
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = ( # Utilize deferred pythonization results
|
||||
prompt_logprobs[sgdx],
|
||||
sample_logprobs[sgdx],
|
||||
)
|
||||
elif any_logprobs_are_requested:
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = (
|
||||
# profile_run: use already-computed logprobs
|
||||
output.outputs[sgdx].prompt_logprobs,
|
||||
[sample.logprobs for sample in output.outputs[sgdx].samples])
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids = sample_result
|
||||
parent_ids = [0]
|
||||
seq_outputs: List[SequenceOutput]
|
||||
|
||||
if cache is not None:
|
||||
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
||||
cache.cached_completion_seq_group_output.get_object()
|
||||
completion_seq_group_output.samples.clear()
|
||||
seq_outputs = completion_seq_group_output.samples
|
||||
else:
|
||||
seq_outputs = []
|
||||
|
||||
for tdx, (parent_id,
|
||||
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
|
||||
if cache is not None:
|
||||
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
|
||||
)
|
||||
seq_output.parent_seq_id = seq_ids[parent_id]
|
||||
seq_output.output_token = next_token_id
|
||||
|
||||
if any_logprobs_are_requested:
|
||||
seq_output.logprobs = group_sample_logprobs[tdx]
|
||||
else:
|
||||
logprobs = next(iter(seq_output.logprobs.values()))
|
||||
seq_output.logprobs.clear()
|
||||
|
||||
logprobs.logprob = float('inf')
|
||||
logprobs.rank = None
|
||||
logprobs.decoded_token = None
|
||||
|
||||
seq_output.logprobs[next_token_id] = logprobs
|
||||
|
||||
seq_outputs.append(seq_output)
|
||||
|
||||
else:
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
(group_sample_logprobs[tdx]
|
||||
if any_logprobs_are_requested else {
|
||||
next_token_id:
|
||||
Logprob(logprob=float('inf'),
|
||||
rank=None,
|
||||
decoded_token=None)
|
||||
})))
|
||||
if cache is not None:
|
||||
completion_seq_group_output.prompt_logprobs = \
|
||||
group_prompt_logprobs if any_logprobs_are_requested else None
|
||||
output.outputs.append(completion_seq_group_output)
|
||||
else:
|
||||
output.outputs.append(
|
||||
CompletionSequenceGroupOutput(
|
||||
seq_outputs, (group_prompt_logprobs
|
||||
if any_logprobs_are_requested else None)))
|
||||
|
||||
assert len(output.outputs) > 0
|
||||
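# --- Illustrative sketch (not part of the original module) ---------------
# The pinned-buffer pattern used in _pythonize_sampler_output() above:
# sampled token ids are copied from the GPU into pinned CPU memory and only
# then converted to Python lists, so the device synchronization happens once
# per pythonization rather than per element.
def _example_pinned_copy(sampled_token_ids: torch.Tensor) -> list:
    pinned = torch.empty(sampled_token_ids.shape,
                         dtype=sampled_token_ids.dtype,
                         device="cpu",
                         pin_memory=torch.cuda.is_available())
    pinned.copy_(sampled_token_ids, non_blocking=False)  # CPU/GPU sync point
    return pinned.tolist()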
84
vllm/worker/multi_step_neuron_model_runner.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from importlib.util import find_spec
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
|
||||
NeuronModelRunner)
|
||||
|
||||
|
||||
class MultiStepNeuronModelRunner(NeuronModelRunner):
|
||||
"""A model runner for multi step decoding using the transformers_neuronx
|
||||
framework"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
super().__init__(vllm_config)
|
||||
self.speculation_config = self.speculative_config
|
||||
from transformers_neuronx.config import GenerationConfig
|
||||
self.speculation_config.draft_model_config.neuron_sampling_params = (
|
||||
GenerationConfig(
|
||||
max_length=self.scheduler_config.max_model_len,
|
||||
do_sample=True,
|
||||
per_batch_line=True,
|
||||
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
|
||||
* self.scheduler_config.max_num_seqs,
|
||||
top_p=[1.0] * self.scheduler_config.max_num_seqs,
|
||||
temperature=[1.0] * self.scheduler_config.max_num_seqs,
|
||||
dynamic=True,
|
||||
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K
|
||||
))
|
||||
|
||||
def load_model(self) -> None:
|
||||
if find_spec("transformers_neuronx") is not None:
|
||||
from vllm.model_executor.model_loader.neuron import (
|
||||
get_neuron_eagle_speculation_model,
|
||||
get_neuron_speculation_model)
|
||||
if self.speculation_config.speculative_token_tree is not None:
|
||||
self.model = get_neuron_eagle_speculation_model(
|
||||
self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
speculation_config=self.speculation_config)
|
||||
else:
|
||||
self.model = get_neuron_speculation_model(
|
||||
self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
speculation_config=self.speculation_config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Supports only Transformer-NeuronX based models.")
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForNeuron,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
logits = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
return output
|
||||
63
vllm/worker/multi_step_neuronx_distributed_model_runner.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.neuronx_distributed_model_runner import (
|
||||
NeuronxDistributedModelRunner)
|
||||
|
||||
|
||||
class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
||||
"""A model runner for multi-step decoding using the
|
||||
neuronx-distributed-inference framework"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
super().__init__(vllm_config)
|
||||
|
||||
def load_model(self) -> None:
|
||||
from vllm.model_executor.model_loader.neuronx_distributed import (
|
||||
get_neuron_speculation_model)
|
||||
self.model = get_neuron_speculation_model(
|
||||
self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
speculation_config=self.speculative_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
sampling_params = torch.tensor([[
|
||||
seq_group.sampling_params.top_k,
|
||||
seq_group.sampling_params.top_p,
|
||||
seq_group.sampling_params.temperature,
|
||||
] for seq_group in model_input.sampling_metadata.seq_groups])
|
||||
|
||||
logits = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
return output
|
||||
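# --- Illustrative sketch (not part of the original module) ---------------
# Shape of the per-sequence sampling_params tensor built in execute_model()
# above: one row per sequence group with columns [top_k, top_p, temperature].
# The values below are made-up examples.
_example_sampling_params = torch.tensor([
    [50, 0.9, 0.7],   # sequence group 0: top-k 50, top-p 0.9, temperature 0.7
    [1, 1.0, 1.0],    # sequence group 1: effectively greedy
])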
108
vllm/worker/multi_step_tpu_worker.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.tpu_model_runner import ModelInputForTPU
|
||||
from vllm.worker.tpu_worker import TPUWorker
|
||||
from vllm.worker.worker_base import WorkerInput
|
||||
|
||||
|
||||
class MultiStepTPUWorker(TPUWorker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cached_model_input: Optional[ModelInputForTPU] = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
assert self.is_driver_worker
|
||||
assert execute_model_req.virtual_engine == 0
|
||||
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
is_last_step = execute_model_req.is_last_step
|
||||
if is_first_multi_step:
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
worker_input = dataclasses.replace(
|
||||
worker_input,
|
||||
num_steps=execute_model_req.num_lookahead_slots + 1)
|
||||
model_input: ModelInputForTPU = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input = dataclasses.replace(
|
||||
model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
assert self.cached_model_input is not None
|
||||
model_input = self.cached_model_input
|
||||
worker_input = WorkerInput()
|
||||
model_input = dataclasses.replace(
|
||||
model_input,
|
||||
is_first_multi_step=is_first_multi_step,
|
||||
is_last_step=is_last_step)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
if is_first_multi_step:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(
|
||||
model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
else:
|
||||
broadcast_data = {
|
||||
"is_first_multi_step": is_first_multi_step,
|
||||
"is_last_step": is_last_step,
|
||||
}
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
# Returning an empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
|
||||
execute_model_req)
|
||||
if model_input.is_first_multi_step:
|
||||
self.cached_model_input = model_input
|
||||
return model_input, worker_input, {}
|
||||
else:
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
|
||||
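# A two-entry dict is the lightweight flags-only broadcast sent on non-first
# steps; everything else was cached from the first step.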
if len(broadcast_data) == 2:
|
||||
assert self.cached_model_input is not None
|
||||
self.cached_model_input = dataclasses.replace(
|
||||
self.cached_model_input,
|
||||
is_first_multi_step=broadcast_data["is_first_multi_step"],
|
||||
is_last_step=broadcast_data["is_last_step"])
|
||||
empty_worker_input = WorkerInput()
|
||||
return self.cached_model_input, empty_worker_input, {}
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(
|
||||
broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.
|
||||
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
|
||||
self.cached_model_input = model_input
|
||||
return model_input, worker_input, {}
|
||||
197
vllm/worker/multi_step_worker.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.model_runner_base import BroadcastableModelInput
|
||||
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
|
||||
StatefulModelInput)
|
||||
from vllm.worker.worker import Worker, WorkerInput
|
||||
|
||||
|
||||
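# One MultiStepState is cached per virtual engine (pipeline stage) so that
# subsequent steps of the same batch can reuse the prepared inputs.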
@dataclass
|
||||
class MultiStepState:
|
||||
worker_input: WorkerInput
|
||||
model_input: StatefulModelInput
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
base_model_runner = self.model_runner
|
||||
# for multi-step model, wrap the model runner with MultiStepModelRunner
|
||||
self.model_runner = MultiStepModelRunner(
|
||||
base_model_runner,
|
||||
vllm_config=base_model_runner.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=base_model_runner.is_driver_worker,
|
||||
)
|
||||
|
||||
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
self.multi_step_states: List[
|
||||
Optional[MultiStepState]] = [None] * pipeline_parallel_size
|
||||
self.temp_output = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
assert self.is_driver_worker
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
if is_first_multi_step:
|
||||
# on first step we prepare the worker input and model input normally
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: StatefulModelInput = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
multi_step_state = self.multi_step_states[virtual_engine]
|
||||
worker_input = multi_step_state.worker_input
|
||||
model_input = multi_step_state.model_input
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
# clear the cached metadata so that it can be recomputed on
|
||||
# the workers.
|
||||
frozen_model_input.attn_metadata._cached_prefill_metadata = None
|
||||
frozen_model_input.attn_metadata._cached_decode_metadata = None
|
||||
|
||||
model_input.is_first_multi_step = is_first_multi_step
|
||||
model_input.is_last_step = execute_model_req.is_last_step
|
||||
|
||||
if not is_first_multi_step:
|
||||
# we broadcast the last sampled token ids to all TP workers so they
|
||||
# can update their model input metadata in-place.
|
||||
self._prepare_last_sampled_token_ids_for_tp_workers(
|
||||
execute_model_req=execute_model_req, model_input=model_input)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
# Returning an empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
model_input: StatefulModelInput,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare the last sampled token ids for TP workers. If it's the last
|
||||
PP rank, then the last sampled token ids are already in the model_input.
|
||||
If it is NOT the last PP rank, then we need to get the last sampled
|
||||
token ids that are cached in the execute_model_req.
|
||||
"""
|
||||
if get_pp_group().is_last_rank:
|
||||
assert model_input.cached_outputs[
|
||||
-1].sampler_output.sampled_token_ids is None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = model_input.cached_outputs[
|
||||
-1].sampled_token_ids
|
||||
# Free the sampled token ids from the previous steps if they have been
# pythonized. We cannot free the last sampled token ids because they are
# needed for GPU advance_step.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
if output.pythonized:
|
||||
output.sampled_token_ids = None
|
||||
else:
|
||||
# otherwise we need to get the cached sampled token ids from the
|
||||
# execute_model_req
|
||||
assert execute_model_req.last_sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = (
|
||||
execute_model_req.last_sampled_token_ids.cuda())
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
# free sampled token ids from the previous step.
|
||||
# TODO(will) we could reuse the sampled token ids tensor from
|
||||
# the previous step instead.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
output.sampled_token_ids = None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
"""
|
||||
Depending on the current state of the request and multi step worker,
|
||||
this method may skip the normal _prepare_model_input and
|
||||
_prepare_worker_input methods and instead use cached values.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there are no more requests to process for
# now. All workers run an infinite loop around
# broadcast_tensor_dict and stop the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
(model_input, worker_input,
|
||||
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
if execute_model_req.is_first_multi_step:
|
||||
# cache the worker input and model input for the next steps
|
||||
self.multi_step_states[virtual_engine] = MultiStepState(
|
||||
worker_input=worker_input, model_input=model_input)
|
||||
# Non-driver (TP) workers receive their input via broadcast:
|
||||
else:
|
||||
broadcast_data = self._get_worker_input_from_broadcast()
|
||||
# if the driver has sent an empty input, we should stop the worker
|
||||
# loop
|
||||
if broadcast_data is None:
|
||||
return None
|
||||
model_input, worker_input, kwargs = broadcast_data
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
if model_input.is_first_multi_step:
|
||||
pass
|
||||
# TODO(will) Can cache the worker input and model input for the
|
||||
# next steps. See below for details
|
||||
else:
|
||||
# TODO(will) It is possible to also cache and reuse the worker
# input and model input here. The idea is essentially the delta
# optimization for model inputs: the TP workers cache the
# model input state and we only broadcast the delta needed
# for the next step (sampled_token_ids from the previous step).
|
||||
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
# we need to update the last sampled token ids in the model
|
||||
# input for the workers so that they can run inplace
|
||||
# advance_step
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
assert model_input is not None
|
||||
assert worker_input is not None
|
||||
return model_input, worker_input, kwargs
|
||||
460
vllm/worker/neuron_model_runner.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import DeviceConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.neuron import get_neuron_model
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForNeuron(ModelRunnerInputBase):
|
||||
"""
|
||||
Used by the NeuronModelRunner.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
input_block_ids: Optional[torch.Tensor] = None
|
||||
sampling_metadata: SamplingMetadata = None
|
||||
multi_modal_kwargs: BatchedTensorInputs = None
|
||||
adapter_ids: Optional[str] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
return {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"input_block_ids": self.input_block_ids,
|
||||
"sampling_metadata": self.sampling_metadata,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForNeuron":
|
||||
return ModelInputForNeuron(
|
||||
input_tokens=tensor_dict["input_tokens"],
|
||||
input_positions=tensor_dict["input_positions"],
|
||||
input_block_ids=tensor_dict["input_block_ids"],
|
||||
sampling_metadata=tensor_dict["sampling_metadata"],
|
||||
multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
|
||||
)
|
||||
|
||||
|
||||
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
"""A model runner for AWS Neuron hardware"""
|
||||
|
||||
# NEURON has an upper limit on the top_k
|
||||
_MAX_NEURON_SAMPLING_TOP_K = 256
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
ModelRunnerBase.__init__(self, vllm_config)
|
||||
|
||||
if (self.model_config is not None
|
||||
and self.model_config.get_sliding_window()):
|
||||
logger.warning("Sliding window is not supported on Neuron. "
|
||||
"The model will run without sliding window.")
|
||||
self.device_config = (self.device_config if self.device_config
|
||||
is not None else DeviceConfig())
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.device = self.device_config.device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
# Multi-modal data support
|
||||
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
|
||||
.create_input_mapper(self.model_config)
|
||||
|
||||
# Lazy initialization.
|
||||
self.model: nn.Module # initialize after load_model.
|
||||
|
||||
# Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
|
||||
# turn off on-device sampling.
|
||||
self._on_device_sampling_disabled = int(
|
||||
os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
|
||||
|
||||
# NEURON needs to update sampling parameters when request IDs change
|
||||
# across batches. This variable stores the previous batch's request IDs
|
||||
# to determine if an update is needed.
|
||||
self._previous_batch_request_ids: List[str] = []
|
||||
|
||||
if not self._on_device_sampling_disabled:
|
||||
self._init_neuron_sampling()
|
||||
|
||||
def _init_neuron_sampling(self) -> None:
|
||||
if current_platform.use_transformers_neuronx():
|
||||
from transformers_neuronx.config import GenerationConfig
|
||||
else:
|
||||
from transformers import GenerationConfig
|
||||
logger.warning(
|
||||
"On-device sampling is turned on in Neuron by default, only "
|
||||
"top_k, top_p, and temperature are current supported sampling "
|
||||
"parameters. To turn off the on-device sampling, please set "
|
||||
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
|
||||
self.model_config.neuron_sampling_params = GenerationConfig(
|
||||
max_length=self.scheduler_config.max_model_len,
|
||||
do_sample=True,
|
||||
per_batch_line=True,
|
||||
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
|
||||
* self.scheduler_config.max_num_seqs,
|
||||
top_p=[1.0] * self.scheduler_config.max_num_seqs,
|
||||
temperature=[1.0] * self.scheduler_config.max_num_seqs,
|
||||
dynamic=True,
|
||||
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
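# The per-sequence lists above are sized to max_num_seqs; they are patched
# in place by _update_neuron_sampling_params when a batch's request IDs
# change.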
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_neuron_model(self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
|
||||
BatchedTensorInputs]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
input_block_ids: List[int] = []
|
||||
|
||||
seq_lens: List[int] = []
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
seq_len = len(prompt_tokens)
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
input_tokens.append(prompt_tokens)
|
||||
input_positions.append(list(range(seq_len)))
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
assert len(block_table) == 1
|
||||
input_block_ids.append(block_table[0])
|
||||
|
||||
mm_kwargs = seq_group_metadata.multi_modal_data
|
||||
if mm_kwargs:
|
||||
mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
|
||||
multi_modal_kwargs_list.append(mm_kwargs)
|
||||
|
||||
max_seq_len = max(seq_lens)
|
||||
assert max_seq_len > 0
|
||||
input_tokens = make_tensor_with_pad(input_tokens,
|
||||
pad=0,
|
||||
max_len=max_seq_len,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = make_tensor_with_pad(input_positions,
|
||||
pad=0,
|
||||
max_len=max_seq_len,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_block_ids = torch.tensor(input_block_ids,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return (input_tokens, input_positions, input_block_ids, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
input_block_ids: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
context_lens.append(seq_len)
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
assert len(block_table) == 1
|
||||
input_block_ids.append(block_table[0])
|
||||
|
||||
input_tokens = make_tensor_with_pad(input_tokens,
|
||||
pad=0,
|
||||
max_len=1,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = make_tensor_with_pad(input_positions,
|
||||
pad=0,
|
||||
max_len=1,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
input_block_ids = torch.tensor(input_block_ids,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
return input_tokens, input_positions, input_block_ids
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
|
||||
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForNeuron:
|
||||
multi_modal_kwargs = None
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, input_block_ids, seq_lens,
|
||||
multi_modal_kwargs
|
||||
) = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
|
||||
seq_lens = None
|
||||
|
||||
if not self._on_device_sampling_disabled:
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
top_k, top_p, temperature = (
|
||||
self._convert_to_neuron_sampling_params(sampling_params))
|
||||
sampling_params.top_k = top_k
|
||||
sampling_params.top_p = top_p
|
||||
sampling_params.temperature = temperature
|
||||
|
||||
# we need multi_modal_data for later tokens as well
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
mm_data = seq_group_metadata.multi_modal_data
|
||||
if mm_data:
|
||||
multi_modal_kwargs_list.append(mm_data)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
# query_lens is not needed if chunked prefill is not
|
||||
# supported. Since the neuron worker doesn't support chunked prefill,
|
||||
# just use seq_lens instead.
|
||||
seq_lens,
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
generators=self.get_generators(finished_requests_ids))
|
||||
|
||||
if current_platform.use_transformers_neuronx(
|
||||
) and not self._on_device_sampling_disabled:
|
||||
# Once the request IDs are changed in current iteration, we will
|
||||
# update the on-device sampling parameters.
|
||||
current_batch_request_ids = [
|
||||
seq_group_meta_data.request_id
|
||||
for seq_group_meta_data in seq_group_metadata_list
|
||||
]
|
||||
if current_batch_request_ids != self._previous_batch_request_ids:
|
||||
self._update_neuron_sampling_params(seq_group_metadata_list)
|
||||
self._previous_batch_request_ids = current_batch_request_ids
|
||||
|
||||
return ModelInputForNeuron(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
input_block_ids=input_block_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs)
|
||||
|
||||
def _update_neuron_sampling_params(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
|
||||
# Update Neuron sampling parameters (GenerationConfig in Neuron)
|
||||
current_sampling_params = self.model_config.neuron_sampling_params
|
||||
assert current_sampling_params is not None, (
|
||||
f"Failed to update sampling_params, "
|
||||
f"current sampling params is {current_sampling_params}")
|
||||
|
||||
is_update_needed = False
|
||||
|
||||
top_k = current_sampling_params.top_k
|
||||
top_p = current_sampling_params.top_p
|
||||
temperature = current_sampling_params.temperature
|
||||
|
||||
# The index of a sequence's sampling parameters in neuron is equal to
|
||||
# its index in `input_block_ids`.
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
|
||||
seq_group_top_k = sampling_params.top_k
|
||||
seq_group_top_p = sampling_params.top_p
|
||||
seq_group_temperature = sampling_params.temperature
|
||||
|
||||
for seq_id in seq_ids:
|
||||
index = seq_group_metadata.block_tables[seq_id][0]
|
||||
if (top_k[index] != seq_group_top_k
|
||||
or top_p[index] != seq_group_top_p
|
||||
or temperature[index] != seq_group_temperature):
|
||||
is_update_needed = True
|
||||
|
||||
top_k[index] = seq_group_top_k
|
||||
top_p[index] = seq_group_top_p
|
||||
temperature[index] = seq_group_temperature
|
||||
|
||||
# update_generation_config is only available in transformers-neuronx
|
||||
if is_update_needed and current_platform.use_transformers_neuronx():
|
||||
self.model.model.update_generation_config(current_sampling_params)
|
||||
|
||||
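# Illustrative mapping (based on the logic below): temperature == 0.0 maps to
# greedy sampling (1, 1.0, 1.0); top_k < 1 or top_k > 256 is clamped to
# _MAX_NEURON_SAMPLING_TOP_K; otherwise the values pass through unchanged.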
def _convert_to_neuron_sampling_params(
|
||||
self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
|
||||
# Returns the top_k, top_p and temperature parameters for neuron.
|
||||
top_k = sampling_params.top_k
|
||||
top_p = sampling_params.top_p
|
||||
temperature = sampling_params.temperature
|
||||
|
||||
if temperature == 0.0:
|
||||
# Enable greedy sampling on zero temperature
|
||||
return (1, 1.0, 1.0)
|
||||
if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
|
||||
top_k = self._MAX_NEURON_SAMPLING_TOP_K
|
||||
|
||||
return (top_k, top_p, temperature)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForNeuron,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"NeuronModelRunner does not support multi-step execution.")
|
||||
|
||||
# extract top_k, top_p and temperature from model_input for neuron
|
||||
# forward call
|
||||
sampling_params = (torch.tensor([[
|
||||
seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
|
||||
seq_group.sampling_params.temperature
|
||||
] for seq_group in model_input.sampling_metadata.seq_groups]))
|
||||
|
||||
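# Each row of sampling_params is [top_k, top_p, temperature] for one sequence
# group, in batch order; it is passed to the Neuron model for on-device
# sampling.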
if current_platform.use_neuronx_distributed():
|
||||
hidden_states = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
adapter_ids=model_input.adapter_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
elif current_platform.use_transformers_neuronx():
|
||||
# [TODO] validate on-device sampling
|
||||
# The model signature may need change for on-device sampling
|
||||
hidden_states = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
# Compute the logits only if the on-device sampling is turned off as
|
||||
# on-device sampling outputs the token ids.
|
||||
if self._on_device_sampling_disabled:
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
model_input.sampling_metadata)
|
||||
else:
|
||||
logits = hidden_states
|
||||
|
||||
# Sample the next token.
|
||||
output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
return [output]
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
def process_multi_modal_data_neuron(self, mm_data):
|
||||
# this is a no-op for NeuronModelRunner
|
||||
return mm_data
|
||||
|
||||
def remove_all_loras(self):
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest):
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"LoRAs are not supported for Transformers NeuronX framework")
|
||||
198
vllm/worker/neuron_worker.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A Neuron worker class."""
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.neuron import NeuronFramework
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.neuron_model_runner import NeuronModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NeuronWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes the model on a group of neuron cores.
|
||||
"""
|
||||
|
||||
model_runner: NeuronModelRunner
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.lora_config = vllm_config.lora_config
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
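# Framework selection is delegated to the platform layer; it typically honors
# the VLLM_NEURON_FRAMEWORK environment variable referenced in the error
# message below.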
neuron_framework = current_platform.get_neuron_framework_to_use()
|
||||
if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
|
||||
self.model_runner = self.get_tnx_model_runner(vllm_config)
|
||||
elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
|
||||
self.model_runner = self.get_neuronx_distributed_model_runner(
|
||||
vllm_config)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Specified framework" +
|
||||
f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
|
||||
" is either not installed or not supported." +
|
||||
" Supported frameworks: " +
|
||||
"[transformers-neuronx, neuronx-distributed-inference]")
|
||||
|
||||
def get_tnx_model_runner(self, vllm_config):
|
||||
assert (self.lora_config
|
||||
is None), ("LoRA is not supported for TransformersNeuronX "
|
||||
"framework.")
|
||||
from vllm.worker.multi_step_neuron_model_runner import (
|
||||
MultiStepNeuronModelRunner)
|
||||
if self.speculative_config is not None:
|
||||
return MultiStepNeuronModelRunner(vllm_config=vllm_config)
|
||||
else:
|
||||
return NeuronModelRunner(vllm_config=vllm_config)
|
||||
|
||||
def get_neuronx_distributed_model_runner(self, vllm_config):
|
||||
from vllm.worker.multi_step_neuronx_distributed_model_runner import (
|
||||
MultiStepNeuronxDistributedModelRunner)
|
||||
from vllm.worker.neuronx_distributed_model_runner import (
|
||||
NeuronxDistributedModelRunner)
|
||||
if self.speculative_config is not None:
|
||||
assert (self.lora_config
|
||||
is None), "LoRA is not supported for Speculative Decoding"
|
||||
return MultiStepNeuronxDistributedModelRunner(
|
||||
vllm_config=vllm_config)
|
||||
else:
|
||||
return NeuronxDistributedModelRunner(vllm_config=vllm_config)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.init_distributed_environment()
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks.
|
||||
|
||||
Swapping is not yet supported, so always return num_cpu_blocks=0.
|
||||
|
||||
We configure num_gpu_blocks to be equal to max_num_seqs.
|
||||
"""
|
||||
# Set the number of GPU blocks to be the same as the maximum number of
|
||||
# sequences that can be processed in a single batch. This is equivalent
|
||||
# to scheduling without PagedAttention.
|
||||
num_gpu_blocks = self.scheduler_config.max_num_seqs + 1
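# One block per schedulable sequence, plus one extra (presumably a
# padding/null block).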
|
||||
|
||||
# Swap not yet supported with Neuron backend.
|
||||
num_cpu_blocks = 0
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache.
|
||||
"""
|
||||
|
||||
# Different values are not tested.
|
||||
assert num_cpu_blocks == 0
|
||||
assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
return WorkerInput(num_seq_groups=len(
|
||||
execute_model_req.seq_group_metadata_list), )
|
||||
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
pass
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Determine the size in bytes of a cache block.
|
||||
|
||||
This is required for speculative decoding; it is not yet implemented.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def init_distributed_environment(self):
|
||||
"""Neuron uses transformers-neuronx for tensor parallelism.
|
||||
|
||||
vLLM still needs the environment initialized when TP/PP > 1
|
||||
"""
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
distributed_init_method=self.distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
1,
|
||||
1,
|
||||
)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if current_platform.use_transformers_neuronx():
|
||||
raise NotImplementedError(
|
||||
f"{type(self)} does not support LoRA with Neuron Framework "
|
||||
f"Transformers NeuronX")
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
if current_platform.use_transformers_neuronx():
|
||||
raise NotImplementedError(
|
||||
f"{type(self)} does not support LoRA with Neuron Framework "
|
||||
f"Transformers NeuronX")
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
if current_platform.use_transformers_neuronx():
|
||||
raise NotImplementedError(
|
||||
f"{type(self)} does not support LoRA with Neuron Framework "
|
||||
f"Transformers NeuronX")
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
if current_platform.use_transformers_neuronx():
|
||||
raise NotImplementedError(
|
||||
f"{type(self)} does not support LoRA with Neuron Framework "
|
||||
f"Transformers NeuronX")
|
||||
return self.model_runner.list_loras()
|
||||
294
vllm/worker/neuronx_distributed_model_runner.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import torch
|
||||
from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
|
||||
get_all_supported_aspect_ratios)
|
||||
from neuronx_distributed_inference.modules.generation.sampling import (
|
||||
prepare_sampling_params)
|
||||
from neuronx_distributed_inference.modules.lora_serving import (
|
||||
LoraCheckpoint, LoraServingConfig)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.neuronx_distributed import (
|
||||
_get_model_architecture, get_neuron_model)
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
|
||||
NeuronModelRunner)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NeuronxDistributedModelRunner(NeuronModelRunner):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
super().__init__(vllm_config)
|
||||
self.lora_checkpoint = None
|
||||
self.model = None
|
||||
self.lora_serving_config = None
|
||||
|
||||
@staticmethod
|
||||
def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
|
||||
if not lora_modules:
|
||||
return None
|
||||
return {_.get("name"): _.get("path") for _ in lora_modules}
|
||||
|
||||
def _get_nxdi_lora_config(self):
|
||||
override_neuron_config = self.model_config.override_neuron_config
|
||||
lora_modules = override_neuron_config.pop("lora_modules", None)
|
||||
target_modules = override_neuron_config.pop("target_modules", None)
|
||||
lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)
|
||||
if self.lora_config.max_loras < len(lora_ckpt_paths):
|
||||
raise ValueError(
f"Number of LoRAs ({len(lora_ckpt_paths)}) exceeds maximum "
f"allowed ({self.lora_config.max_loras})")
|
||||
|
||||
return LoraServingConfig(
|
||||
max_loras=self.lora_config.max_loras,
|
||||
max_lora_rank=self.lora_config.max_lora_rank,
|
||||
target_modules=target_modules,
|
||||
lora_ckpt_paths=lora_ckpt_paths,
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
# Update LoRA config
|
||||
if self.lora_config is not None:
|
||||
self.lora_serving_config = self._get_nxdi_lora_config()
|
||||
self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
|
||||
self.model = get_neuron_model(
|
||||
self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
lora_serving_config=self.lora_serving_config)
|
||||
|
||||
def get_nxd_sampling_params(self, sampling_metadata):
|
||||
if self.model.config.neuron_config.on_device_sampling_config:
|
||||
max_topk = (self.model.config.neuron_config.
|
||||
on_device_sampling_config.global_topk)
|
||||
else:
|
||||
max_topk = self.model.config.vocab_size
|
||||
|
||||
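# Pad the sampling parameters to the compiled batch size (max_num_seqs);
# slots without a live sequence keep the neutral defaults below
# (top_k=1, top_p=1.0, temperature=1.0).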
top_k = [1] * self.scheduler_config.max_num_seqs
|
||||
top_p = [1.0] * self.scheduler_config.max_num_seqs
|
||||
temperature = [1.0] * self.scheduler_config.max_num_seqs
|
||||
|
||||
for index, sequenceGroupToSample in enumerate(
|
||||
sampling_metadata.seq_groups):
|
||||
top_k[index] = (sequenceGroupToSample.sampling_params.top_k
|
||||
if sequenceGroupToSample.sampling_params.top_k > 0
|
||||
else max_topk)
|
||||
top_p[index] = sequenceGroupToSample.sampling_params.top_p
|
||||
temperature[index] = (
|
||||
sequenceGroupToSample.sampling_params.temperature)
|
||||
|
||||
sampling_params = prepare_sampling_params(
|
||||
batch_size=self.scheduler_config.max_num_seqs,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature)
|
||||
return sampling_params
|
||||
|
||||
def get_multi_modal_data_neuron(self, input_images):
|
||||
raise NotImplementedError("need to restore multi-modal support")
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForNeuron,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"NeuronModelRunner does not support multi-step execution.")
|
||||
|
||||
if _get_model_architecture(
|
||||
self.model.config) != "MllamaForConditionalGeneration":
|
||||
return super().execute_model(model_input, kv_caches,
|
||||
intermediate_tensors, num_steps)
|
||||
|
||||
sampling_params = self.get_nxd_sampling_params(
|
||||
model_input.sampling_metadata)
|
||||
|
||||
if model_input.multi_modal_kwargs.get('pixel_values') is not None:
|
||||
hidden_states = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
seq_ids=model_input.input_block_ids,
|
||||
pixel_values=model_input.multi_modal_kwargs.get(
|
||||
'pixel_values'),
|
||||
aspect_ratios=model_input.multi_modal_kwargs.get(
|
||||
'aspect_ratios'),
|
||||
sampling_params=sampling_params,
|
||||
num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
|
||||
has_image=model_input.multi_modal_kwargs.get(
|
||||
'has_image').squeeze(1),
|
||||
)
|
||||
else:
|
||||
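# Text-only batch: feed zero-filled dummy image tensors (and has_image=0),
# presumably because the compiled Mllama graph always expects vision inputs.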
bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
|
||||
is not None) else 1
|
||||
empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
|
||||
dtype=torch.bfloat16)
|
||||
empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
|
||||
num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
|
||||
has_image = torch.zeros([bs], dtype=torch.int32)
|
||||
hidden_states = self.model(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
seq_ids=model_input.input_block_ids,
|
||||
pixel_values=empty_pixel_values,
|
||||
aspect_ratios=empty_aspect_ratios,
|
||||
sampling_params=sampling_params,
|
||||
num_chunks=num_chunks,
|
||||
has_image=has_image,
|
||||
)
|
||||
|
||||
output = self.model.sample(
|
||||
hidden_states=hidden_states,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
|
||||
return [output]
|
||||
|
||||
def process_multi_modal_data_neuron(self, mm_data):
|
||||
# Neuron uses aspect_ratios instead of aspect_ratio_ids
|
||||
all_supported_aspect_ratios = get_all_supported_aspect_ratios(
|
||||
self.model.config.vision_config.max_num_tiles)
|
||||
aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
|
||||
mm_data["aspect_ratios"] = torch.tensor(
|
||||
all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)
|
||||
|
||||
# Neuron's num_chunks is HF's num_tiles
|
||||
mm_data["num_chunks"] = mm_data.get("num_tiles")
|
||||
|
||||
# Input has an image if it has pixel_values
|
||||
bs = mm_data["num_chunks"].shape[0]
|
||||
pixel_values = mm_data.get("pixel_values")
|
||||
if pixel_values is not None and not torch.all(pixel_values == 0):
|
||||
mm_data["has_image"] = torch.ones(bs)
|
||||
|
||||
else:
|
||||
mm_data["has_image"] = torch.zeros(bs)
|
||||
return mm_data
|
||||
|
||||
def _get_lora_adapter_ids(self, seq_group_metadata_list):
|
||||
# set LoRA adapter IDs for multi-lora serving
|
||||
batch_size = len(seq_group_metadata_list)
|
||||
if self.lora_checkpoint is not None:
|
||||
# "0" indicates NxDI to use the base model for inference
|
||||
adapter_ids = ["0"] * batch_size
|
||||
for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
if seq_group_metadata.lora_request is not None:
|
||||
adapter_ids[
|
||||
idx] = seq_group_metadata.lora_request.lora_name
|
||||
|
||||
# convert adapter_ids from strings to integers
|
||||
adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices(
|
||||
adapter_ids, batch_size)
|
||||
else:
|
||||
adapter_ids = torch.zeros((batch_size), dtype=torch.int32)
|
||||
|
||||
return adapter_ids
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForNeuron:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, input_block_ids, seq_lens,
|
||||
multi_modal_kwargs
|
||||
) = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
|
||||
seq_lens = None
|
||||
|
||||
if not self._on_device_sampling_disabled:
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
top_k, top_p, temperature = (
|
||||
self._convert_to_neuron_sampling_params(sampling_params))
|
||||
sampling_params.top_k = top_k
|
||||
sampling_params.top_p = top_p
|
||||
sampling_params.temperature = temperature
|
||||
|
||||
# we need multi_modal_data for later tokens as well
|
||||
multi_modal_kwargs_list: List[MultiModalKwargs] = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
mm_data = seq_group_metadata.multi_modal_data
|
||||
if mm_data:
|
||||
multi_modal_kwargs_list.append(mm_data)
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
# query_lens is not needed if chunked prefill is not
|
||||
# supported. Since the neuron worker doesn't support chunked prefill,
|
||||
# just use seq_lens instead.
|
||||
seq_lens,
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
generators=self.get_generators(finished_requests_ids))
|
||||
|
||||
return ModelInputForNeuron(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
input_block_ids=input_block_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
adapter_ids=lora_adapter_ids)
|
||||
|
||||
def remove_all_loras(self):
|
||||
raise NotImplementedError(
|
||||
"Managing LoRAs is only supported through the "
|
||||
"lora_modules parameter in override_neuron_config")
|
||||
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
raise NotImplementedError(
|
||||
"Managing LoRAs is only supported through the "
|
||||
"lora_modules parameter in override_neuron_config")
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest):
|
||||
logger.warning(
"Adding LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config. If you supplied "
"the parameter, you can ignore this warning. Ignoring "
"lora request: %s", lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Managing LoRAs is only supported through the "
|
||||
"lora_modules parameter in override_neuron_config")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Managing LoRAs is only supported through the "
|
||||
"lora_modules parameter in override_neuron_config")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError(
|
||||
"Managing LoRAs is only supported through the "
|
||||
"lora_modules parameter in override_neuron_config")
|
||||
211
vllm/worker/pooling_model_runner.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
|
||||
ModelInputForGPUBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
|
||||
"""
|
||||
Used by the PoolingModelRunner.
|
||||
"""
|
||||
pooling_metadata: Optional["PoolingMetadata"] = None
|
||||
|
||||
|
||||
class PoolingModelRunner(
|
||||
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
|
||||
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
|
||||
ModelInputForGPUWithPoolingMetadata)
|
||||
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForGPUWithPoolingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"PoolingModelRunner does not support multi-step execution.")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
assert model_input.attn_metadata is not None
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
virtual_engine = model_input.virtual_engine
|
||||
# Pooling models are also (ab)used to integrate non-text models that
# are not autoregressive (e.g. PrithviGeoSpatialMAE).
|
||||
# These model might not use attention and do not really have a prefill
|
||||
# and decode phase. The model input is processed in one shot and both
|
||||
# decode_metadata and prefill_metadata would be None for such models.
|
||||
# See the PlaceholderAttentionMetadata class.
|
||||
# TODO: Figure out if cuda_graph is of any use for these models and
|
||||
# explore how to leverage it.
|
||||
if (prefill_meta is None and decode_meta is not None
|
||||
and decode_meta.use_cuda_graph):
|
||||
if model_input.inputs_embeds is None:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
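# seqlen_agnostic_kwargs is only populated for models with inner state
# (e.g. recurrent/Mamba-style layers) that need to track finished requests
# across iterations.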
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch.cuda.Event(enable_timing=True)
|
||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||
model_forward_start.record()
|
||||
|
||||
cross_enc_kwargs = {}
|
||||
if model_input.token_types is not None:
|
||||
cross_enc_kwargs["token_type_ids"] = model_input.token_types
|
||||
|
||||
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||
virtual_engine):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
device=self.device,
|
||||
),
|
||||
**cross_enc_kwargs,
|
||||
**seqlen_agnostic_kwargs,
|
||||
)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Only perform pooling in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
# Only perform pooling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
return [
|
||||
self.model.pooler(hidden_states=hidden_or_intermediate_states,
|
||||
pooling_metadata=model_input.pooling_metadata)
|
||||
]
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str,
|
||||
Any]) -> ModelInputForGPUWithPoolingMetadata:
|
||||
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForGPUWithPoolingMetadata:
|
||||
assert seq_group_metadata_list is not None
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
# Prepare PoolingMetadata.
|
||||
assert model_input.seq_lens is not None
|
||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||
model_input.seq_lens)
|
||||
|
||||
return dataclasses.replace(model_input,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
def _prepare_pooling(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
) -> PoolingMetadata:
|
||||
"""Prepare PoolingMetadata for the sequence group metadata list."""
|
||||
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
pooling_params = seq_group_metadata.pooling_params
|
||||
seq_groups.append((seq_ids, pooling_params))
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
pooling_metadata = PoolingMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
|
||||
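# Illustrative note (hedged, hypothetical values): for two sequence groups
# whose seq ids are [0] and [1, 2], `seq_groups` above would look like
# [([0], params_a), ([1, 2], params_b)], and `prompt_lens` holds one entry
# per sequence in batch order.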
return pooling_metadata
|
||||
909
vllm/worker/tpu_model_runner.py
Normal file
@@ -0,0 +1,909 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Type, Union)
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Here we rely on the behavior that an out-of-bound index is ignored.
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||
_ENABLE_TOP_P = False
|
||||
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||
# This can significantly affect the performance if too large.
|
||||
_MAX_NUM_SAMPLES = 128
|
||||
|
||||
|
||||
class ExecutionMode(enum.Enum):
|
||||
PREFILL = enum.auto()
|
||||
DECODE = enum.auto()
|
||||
PREFIX_PREFILL = enum.auto()
|
||||
|
||||
def is_prefill(self) -> bool:
|
||||
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForTPU(ModelRunnerInputBase):
|
||||
token_ids: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
attn_metadata: AttentionMetadata
|
||||
input_lens: torch.Tensor
|
||||
t: torch.Tensor
|
||||
p: torch.Tensor
|
||||
num_samples: int
|
||||
n: List[int]
|
||||
seq_groups: List[List[int]]
|
||||
is_first_multi_step: bool = True
|
||||
is_last_step: bool = True
|
||||
virtual_engine: int = 0
|
||||
async_callback: Optional[Callable] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
tensor_dict = {
|
||||
"token_ids": self.token_ids,
|
||||
"position_ids": self.position_ids,
|
||||
"input_lens": self.input_lens,
|
||||
"t": self.t,
|
||||
"p": self.p,
|
||||
"num_samples": self.num_samples,
|
||||
"n": self.n,
|
||||
"seq_groups": self.seq_groups,
|
||||
"is_first_multi_step": self.is_first_multi_step,
|
||||
"is_last_step": self.is_last_step,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type["ModelInputForTPU"],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForTPU":
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
self.block_size = self.cache_config.block_size
|
||||
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
|
||||
self.block_size)
|
||||
self.block_tables = np.zeros(
|
||||
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
|
||||
dtype=np.int32)
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
False,
|
||||
)
|
||||
self.cached_step_outputs: List[torch.Tensor] = []
|
||||
|
||||
smem_size = 512 * 1024
|
||||
block_table_size = 4 * self.block_tables.size
|
||||
if block_table_size >= smem_size:
|
||||
logger.warning(
|
||||
"The max_model_len (%d) is too large. This may degrade the "
|
||||
"performance due to the insufficient smem size. Consider "
|
||||
"setting --max-model-len to a smaller value, like %d.",
|
||||
self.model_config.max_model_len,
|
||||
self.model_config.max_model_len /
|
||||
(block_table_size / smem_size))
|
||||
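# Hedged worked example (hypothetical config): with max_num_seqs=256,
# max_model_len=8192 and block_size=16, max_num_blocks_per_seq is 512, so
# block_table_size = 4 * 256 * 512 = 512 KiB, which hits the 512 KiB smem
# threshold above and triggers this warning.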
|
||||
def load_model(self) -> None:
|
||||
self.device = self.device_config.device
|
||||
|
||||
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
|
||||
# process, the ranks can be different from the ranks internally assigned
|
||||
# by the xm runtime. Therefore, there is a mismatch in the rank
|
||||
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
|
||||
# This is not a problem in linear layers because all-reduce is
|
||||
# rank-agnostic. However, it matters for all-gather as the ranks
|
||||
# determine the order of concatenating the output tensors.
|
||||
# As a workaround, we use the xm's rank assignment only when loading
|
||||
# the embedding weights.
|
||||
xm_tp_rank = xr.global_ordinal()
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=xm_tp_rank):
|
||||
model = get_model(vllm_config=self.vllm_config)
|
||||
model = model.eval()
|
||||
xm.wait_device_ops()
|
||||
model = ModelWrapper(model)
|
||||
self.model = torch.compile(model,
|
||||
backend="openxla",
|
||||
fullgraph=True,
|
||||
dynamic=False)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model.model
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
exec_mode: ExecutionMode,
|
||||
) -> None:
|
||||
exec_mode = ExecutionMode(exec_mode)
|
||||
if exec_mode.is_prefill():
|
||||
seq_len = (seq_len + 15) // 16 * 16
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
if exec_mode == ExecutionMode.PREFILL:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=batch_size * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=None,
|
||||
context_lens=None,
|
||||
effective_query_lens=None,
|
||||
)
|
||||
else:
|
||||
context_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
block_tables = torch.tensor(self.block_tables[:batch_size],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
effective_query_lens = torch.ones_like(context_lens)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=batch_size * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=effective_query_lens,
|
||||
)
|
||||
else:
|
||||
assert seq_len == 1
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, self.max_num_blocks_per_seq),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
context_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size * seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
|
||||
|
||||
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
||||
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
||||
# overhead by reusing the FX graph for different shapes.
|
||||
# However, the XLA graph will still require static shapes and needs to
|
||||
# be re-compiled for every different shapes. This overhead is inevitable
|
||||
# in the first run, but can be skipped afterwards as we cache the XLA
|
||||
# graphs in the disk (VLLM_XLA_CACHE_PATH).
|
||||
if exec_mode.is_prefill():
|
||||
# Prefill
|
||||
torch._dynamo.mark_dynamic(token_ids, 1)
|
||||
torch._dynamo.mark_dynamic(position_ids, 1)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
||||
else:
|
||||
# Decode
|
||||
torch._dynamo.mark_dynamic(token_ids, 0)
|
||||
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||
torch._dynamo.mark_dynamic(input_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||
torch._dynamo.mark_dynamic(t, 0)
|
||||
torch._dynamo.mark_dynamic(p, 0)
|
||||
# Dummy run.
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
|
||||
kv_caches)
|
||||
|
||||
def warmup_model(
|
||||
self,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> None:
|
||||
# Prefill
|
||||
logger.info("Compiling the model with different input shapes...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for prefill done in %.2f s.", end - start)
|
||||
|
||||
# Prefix prefill
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
logger.info("Compiling the model with different input shapes for "
|
||||
"prefix prefill...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if (num_tokens
|
||||
>= self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
end = time.time()
|
||||
logger.info("Compilation for prefix prefill done in %.2f s.",
|
||||
end - start)
|
||||
|
||||
# Decode
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||
while True:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.DECODE)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for decode done in %.2f s.", end - start)
|
||||
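# Hedged summary of the warmup schedule above (assuming the defaults in this
# file): prefill warms up seq_len = 16, 32, 64, ... (doubling) until
# max_model_len or max_num_batched_tokens is reached, while decode keeps
# seq_len = 1 and warms up batch_size = 8, 16, 32, 48, 64, ... (then +16 each
# step) up to max_num_seqs.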
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
prompt_lens: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
|
||||
for batch_idx, seq_group_metadata in enumerate(
|
||||
seq_group_metadata_list):
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
# Could include output tokens when a request is preempted.
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
seq_len = len(prompt_tokens)
|
||||
|
||||
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
|
||||
num_computed_tokens = num_computed_blocks * self.block_size
|
||||
if num_computed_tokens > 0:
|
||||
prompt_tokens = prompt_tokens[num_computed_tokens:]
|
||||
context_lens.append(seq_len)
|
||||
else:
|
||||
context_lens.append(0)
|
||||
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
input_tokens.extend(prompt_tokens)
|
||||
input_positions.extend(range(num_computed_tokens, seq_len))
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
for i in range(num_computed_tokens, seq_len):
|
||||
block_number = block_table[i // self.block_size]
|
||||
block_offset = i % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
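# Hedged example (hypothetical numbers): with block_size=16 and token index
# i=37, the block is block_table[37 // 16] = block_table[2], block_offset is
# 37 % 16 = 5, and the slot is block_number * 16 + 5.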
slot_mapping.append(slot)
|
||||
if num_computed_tokens > 0:
|
||||
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||
|
||||
# Pad EACH prompt to the smallest power of 2 that is
|
||||
# greater than or equal to the prompt length.
|
||||
# We pad the seq_len to reduce the compilation overhead.
|
||||
# We execute each prompt individually (i.e., with batch_size 1)
|
||||
# because the FlashAttention kernel does not support ragged inputs.
|
||||
# TODO(woosuk): Use SplashAttention to support ragged inputs.
|
||||
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
||||
num_paddings = padded_prompt_len - prompt_len
|
||||
input_tokens += [0] * num_paddings
|
||||
input_positions += [0] * num_paddings
|
||||
slot_mapping += [_PAD_SLOT_ID] * num_paddings
|
||||
|
||||
assert len(prompt_lens) > 0
|
||||
num_prefills = len(prompt_lens)
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
prompt_lens = torch.tensor(prompt_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
block_tables = torch.tensor(self.block_tables[:num_prefills],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=0, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=prompt_lens,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, prompt_lens
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
context_lens: List[int] = []
|
||||
|
||||
batch_idx = 0
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
context_lens.append(seq_len)
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||
batch_idx += 1
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append([slot])
|
||||
|
||||
batch_size = _get_padded_batch_size(batch_idx)
|
||||
num_paddings = batch_size - batch_idx
|
||||
input_tokens = input_tokens + [[0]] * num_paddings
|
||||
input_positions = input_positions + [[0]] * num_paddings
|
||||
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||
context_lens = context_lens + [0] * num_paddings
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
block_tables = torch.tensor(self.block_tables[:batch_size],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_lens = torch.tensor([1] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, input_lens
|
||||
|
||||
def _prepare_sample(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
padded_batch_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
t = []
|
||||
p = []
|
||||
n = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
t.append(sampling_params.temperature)
|
||||
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
|
||||
raise NotImplementedError(
|
||||
"Top-p sampling is currently disabled for the TPU backend "
|
||||
"due to performance issues.")
|
||||
p.append(sampling_params.top_p)
|
||||
if sampling_params.top_k > 0:
|
||||
raise NotImplementedError(
|
||||
"Top-k sampling is currently disabled for the TPU backend "
|
||||
"due to performance issues.")
|
||||
if sampling_params.n > _MAX_NUM_SAMPLES:
|
||||
raise NotImplementedError(
|
||||
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
|
||||
"backend.")
|
||||
n.append(sampling_params.n)
|
||||
if sampling_params.logprobs is not None:
|
||||
raise NotImplementedError(
|
||||
"logprobs is not currently supported by the TPU backend.")
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
raise NotImplementedError(
|
||||
"prompt_logprobs is not currently supported by the TPU "
|
||||
"backend.")
|
||||
|
||||
# Repeat the sampling params if the seq group has multiple seqs.
|
||||
num_seqs = len(seq_group_metadata.seq_data)
|
||||
t += [t[-1]] * (num_seqs - 1)
|
||||
p += [p[-1]] * (num_seqs - 1)
|
||||
n += [n[-1]] * (num_seqs - 1)
|
||||
|
||||
num_paddings = padded_batch_size - len(t)
|
||||
t += [1.0] * num_paddings
|
||||
p += [1.0] * num_paddings
|
||||
|
||||
t = torch.tensor(t, dtype=torch.float32, device="cpu")
|
||||
p = torch.tensor(p, dtype=torch.float32, device="cpu")
|
||||
return t, p, n
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
) -> ModelInputForTPU:
|
||||
del finished_requests_ids # Unused.
|
||||
assert virtual_engine == 0
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
if is_prompt:
|
||||
inputs = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
inputs = self._prepare_decode(seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata, input_lens = inputs
|
||||
padded_batch_size = input_tokens.shape[0]
|
||||
t, p, n = self._prepare_sample(seq_group_metadata_list,
|
||||
padded_batch_size)
|
||||
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||
|
||||
seq_groups = [
|
||||
list(metadata.seq_data.keys())
|
||||
for metadata in seq_group_metadata_list
|
||||
]
|
||||
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
|
||||
input_lens, t, p, num_samples, n, seq_groups)
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
|
||||
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
|
||||
tensor_dict, attn_backend=self.attn_backend)
|
||||
return model_input
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForTPU,
|
||||
kv_caches: Optional[List[Any]],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> List[SamplerOutput]:
|
||||
assert intermediate_tensors is None
|
||||
if not model_input.is_first_multi_step:
|
||||
if not model_input.is_last_step:
|
||||
return []
|
||||
|
||||
use_async_out_proc = model_input.async_callback is not None
|
||||
sampler_outputs = []
|
||||
num_outputs = len(self.cached_step_outputs)
|
||||
for i in range(num_outputs):
|
||||
next_token_ids = self.cached_step_outputs.pop(0)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
sampler_output = _make_decode_output(next_token_ids,
|
||||
model_input.seq_groups)
|
||||
sampler_outputs.append(sampler_output)
|
||||
|
||||
if i < num_outputs - 1 and use_async_out_proc:
|
||||
assert model_input.async_callback is not None
|
||||
ctx = model_input.async_callback.keywords[ # type: ignore
|
||||
"ctx"]
|
||||
ctx.append_output(
|
||||
outputs=[sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=i == 0)
|
||||
model_input.async_callback()
|
||||
if use_async_out_proc:
|
||||
return [sampler_outputs[-1]]
|
||||
else:
|
||||
return sampler_outputs
|
||||
|
||||
is_prompt = model_input.attn_metadata.num_prefills > 0
|
||||
if is_prompt:
|
||||
assert num_steps == 1
|
||||
# NOTE(woosuk): Since the FlashAttention kernel does not support
|
||||
# ragged inputs, we split the prompts into different batches and
|
||||
# process them separately. This is a temporary hack that should be
|
||||
# optimized by using SplashAttention.
|
||||
orig_slot_mapping = model_input.attn_metadata.slot_mapping
|
||||
orig_block_tables = model_input.attn_metadata.block_tables
|
||||
orig_context_lens = model_input.attn_metadata.context_lens
|
||||
orig_effective_query_lens = \
|
||||
model_input.attn_metadata.effective_query_lens
|
||||
batch_size = model_input.input_lens.shape[0]
|
||||
start_idx = 0
|
||||
next_token_ids = []
|
||||
for i in range(batch_size):
|
||||
# Get the actual prefill_len.
|
||||
prefill_len = model_input.input_lens[i:i + 1].item()
|
||||
prefill_len = _get_padded_prefill_len(prefill_len)
|
||||
end_idx = start_idx + prefill_len
|
||||
|
||||
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
|
||||
self.device)
|
||||
position_ids = model_input.position_ids[None,
|
||||
start_idx:end_idx].to(
|
||||
self.device)
|
||||
attn_metadata = model_input.attn_metadata
|
||||
attn_metadata.num_prefills = 1
|
||||
attn_metadata.slot_mapping = orig_slot_mapping[
|
||||
None, start_idx:end_idx].to(self.device)
|
||||
if orig_context_lens[i].item() > 0:
|
||||
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
|
||||
self.device)
|
||||
attn_metadata.block_tables = orig_block_tables[
|
||||
i].unsqueeze(0).to(self.device)
|
||||
attn_metadata.effective_query_lens = \
|
||||
orig_effective_query_lens[i:i + 1].to(self.device)
|
||||
else:
|
||||
attn_metadata.context_lens = None
|
||||
attn_metadata.block_tables = None
|
||||
attn_metadata.effective_query_lens = None
|
||||
input_lens = model_input.input_lens[i:i + 1].to(self.device)
|
||||
t = model_input.t[i:i + 1].to(self.device)
|
||||
p = model_input.p[i:i + 1].to(self.device)
|
||||
with set_forward_context(model_input.attn_metadata,
|
||||
self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
output_token_ids = self.model(token_ids, position_ids,
|
||||
input_lens, t, p,
|
||||
model_input.num_samples,
|
||||
kv_caches)
|
||||
next_token_ids.append(output_token_ids[0])
|
||||
start_idx = end_idx
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids = [
|
||||
output_token_ids.cpu().tolist()
|
||||
for output_token_ids in next_token_ids
|
||||
]
|
||||
|
||||
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
||||
# The TPU backend does not reuse the sampler, since the TPU backend
|
||||
# does not support advanced sampling parameters such as logprobs.
|
||||
zero_logprob = Logprob(0.0)
|
||||
sampler_outputs = []
|
||||
for i, seq_group in enumerate(model_input.seq_groups):
|
||||
seq_ids = seq_group
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
seq_outputs = []
|
||||
for j in range(model_input.n[i]):
|
||||
next_token_id = next_token_ids[i][j]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
sampler_outputs.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||
return [SamplerOutput(sampler_outputs)]
|
||||
else:
|
||||
token_ids = model_input.token_ids.to(self.device)
|
||||
position_ids = model_input.position_ids.to(self.device)
|
||||
attn_metadata = model_input.attn_metadata
|
||||
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
|
||||
self.device)
|
||||
attn_metadata.block_tables = attn_metadata.block_tables.to(
|
||||
self.device)
|
||||
attn_metadata.context_lens = attn_metadata.context_lens.to(
|
||||
self.device)
|
||||
t = model_input.t.to(self.device)
|
||||
p = model_input.p.to(self.device)
|
||||
input_lens = model_input.input_lens.to(self.device)
|
||||
for i in range(num_steps):
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
with set_forward_context(model_input.attn_metadata,
|
||||
self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
output_token_ids = self.model(token_ids, position_ids,
|
||||
input_lens, t, p,
|
||||
model_input.num_samples,
|
||||
kv_caches)
|
||||
self.cached_step_outputs.append(output_token_ids)
|
||||
|
||||
if i < num_steps - 1:
|
||||
# Prepare the inputs for the next step.
|
||||
token_ids = output_token_ids.unsqueeze(dim=1).int()
|
||||
position_ids = position_ids + 1
|
||||
attn_metadata.context_lens = attn_metadata.context_lens + 1
|
||||
|
||||
block_tables = attn_metadata.block_tables
|
||||
block_number = block_tables.gather(
|
||||
1,
|
||||
position_ids.long() // self.block_size)
|
||||
block_offset = position_ids % self.block_size
|
||||
|
||||
is_padding = slot_mapping == _PAD_SLOT_ID
|
||||
slot_mapping = block_number * self.block_size + block_offset
|
||||
slot_mapping = slot_mapping.long()
|
||||
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
|
||||
slot_mapping)
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
if num_steps > 1:
|
||||
return []
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids = self.cached_step_outputs.pop(0)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
sampler_output = _make_decode_output(next_token_ids,
|
||||
model_input.seq_groups)
|
||||
return [sampler_output]
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_lens: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
num_samples: int,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
Args:
|
||||
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||
input_lens: The actual input lengths of shape [batch_size].
|
||||
t: The sampling temperature of shape [batch_size].
|
||||
p: The top-p probability of shape [batch_size].
|
||||
num_samples: Number of samples to draw from each logits vector.
|
||||
kv_caches: The key and value caches. They can be None during the
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
batch_size, seq_len = token_ids.shape
|
||||
# Calculate the positions to sample from.
|
||||
start_indices = torch.arange(
|
||||
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
|
||||
logits_indices = start_indices + input_lens - 1
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
# FIXME(woosuk): This is a temporary hack to avoid using the existing
|
||||
# sampler and sampling metadata.
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=[],
|
||||
selected_token_indices=logits_indices,
|
||||
categorized_sample_indices={},
|
||||
num_prompts=attn_metadata.num_prefills,
|
||||
)
|
||||
|
||||
# Skip this in memory profiling at initialization.
|
||||
if kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
# work, we need to flatten the first three dimensions and modify
|
||||
# the slot_mapping accordingly.
|
||||
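# Hedged illustration (hypothetical sizes, not taken from the code below):
# with num_kv_heads=2, num_blocks=4 and block_size=8, a token at slot 5 maps
# to flattened positions 5 (head 0) and 5 + 1 * 8 * 4 = 37 (head 1), since
# each head is offset by block_size * num_blocks in the flattened cache.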
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
head_indices = torch.arange(0,
|
||||
num_kv_heads,
|
||||
device=slot_mapping.device,
|
||||
dtype=slot_mapping.dtype)
|
||||
head_indices *= block_size * num_blocks
|
||||
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
|
||||
-1, num_kv_heads)
|
||||
slot_mapping = slot_mapping + head_indices.view(1, -1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
hidden_states = self.model(token_ids, position_ids)
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
# Argmax sampling.
|
||||
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
|
||||
|
||||
# Zero temperature means greedy decoding. Avoid division by zero.
|
||||
nonzero_t = torch.where(t != 0, t, 1.0)
|
||||
logits = logits / nonzero_t.unsqueeze(dim=1)
|
||||
if _ENABLE_TOP_P:
|
||||
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
||||
|
||||
# Random sampling.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
sampled_token_ids = torch.multinomial(probs,
|
||||
num_samples,
|
||||
replacement=True)
|
||||
if num_samples == 1:
|
||||
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
|
||||
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
|
||||
next_token_ids = torch.where(t != 0, sampled_token_ids,
|
||||
argmax_token_ids)
|
||||
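# Hedged example of the sampling rule above: with t = [0.0, 0.7], the first
# sequence uses the argmax token (greedy, t == 0) and the second uses the
# multinomial sample drawn from the temperature-scaled distribution.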
return next_token_ids
|
||||
|
||||
|
||||
def _get_padded_prefill_len(x: int) -> int:
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We pad the prompt length to the next
|
||||
# power of two (at least 16), which keeps it a multiple of 16 and is also good for performance.
|
||||
if x <= 16:
|
||||
return 16
|
||||
return 1 << (x - 1).bit_length()
|
||||
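# Hedged illustration of the padding rule above:
#   _get_padded_prefill_len(1)   -> 16
#   _get_padded_prefill_len(16)  -> 16
#   _get_padded_prefill_len(17)  -> 32
#   _get_padded_prefill_len(100) -> 128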
|
||||
|
||||
def _get_padded_batch_size(batch_size: int) -> int:
|
||||
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||
# To meet this requirement in the simplest way, we set the minimal batch
|
||||
# size to 8.
|
||||
if batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 15) // 16) * 16
|
||||
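# Hedged illustration of the batch padding rule above:
#   _get_padded_batch_size(3)  -> 8
#   _get_padded_batch_size(9)  -> 16
#   _get_padded_batch_size(17) -> 32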
|
||||
|
||||
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
||||
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
|
||||
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
|
||||
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
|
||||
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
||||
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
||||
return logits
|
||||
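# Hedged worked example for _apply_top_p: with sorted probabilities
# [0.5, 0.3, 0.2] and p = 0.7, the cumulative sums are [0.5, 0.8, 1.0], so
# cutoff_index = 1 and the cutoff logit is that of the 0.3 token; every logit
# strictly below it is masked to -inf, keeping the top-2 tokens.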
|
||||
|
||||
def _make_decode_output(
|
||||
next_token_ids: List[int],
|
||||
seq_groups: List[List[int]],
|
||||
) -> SamplerOutput:
|
||||
zero_logprob = Logprob(0.0)
|
||||
sampler_outputs = []
|
||||
batch_idx = 0
|
||||
for seq_group in seq_groups:
|
||||
seq_ids = seq_group
|
||||
seq_outputs = []
|
||||
for seq_id in seq_ids:
|
||||
next_token_id = next_token_ids[batch_idx]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
batch_idx += 1
|
||||
sampler_outputs.append(CompletionSequenceGroupOutput(
|
||||
seq_outputs, None))
|
||||
return SamplerOutput(sampler_outputs)
|
||||
341
vllm/worker/tpu_worker.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
|
||||
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
|
||||
LoRANotSupportedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
assert self.device_config.device_type == "tpu"
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
self.model_runner: TPUModelRunner = TPUModelRunner(
|
||||
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
if vllm_config.lora_config is not None:
|
||||
raise NotImplementedError(
|
||||
"The V0 TPU backend doesn't support LoRA serving")
|
||||
|
||||
def init_device(self) -> None:
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
init_distributed_environment(
|
||||
world_size=self.parallel_config.world_size,
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
distributed_init_method=self.distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
# Device initialization should happen after initializing the distributed
|
||||
# runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
||||
# Consequently, changes in optimization flags, which affect compilation
|
||||
# results, don't change the cache key. This can result in the wrong
|
||||
# compilation being used. To prevent this, disabling the XLA compilation
|
||||
# cache during development is recommended.We can disable it by
|
||||
# `export VLLM_XLA_CACHE_PATH=`
|
||||
if envs.VLLM_XLA_CACHE_PATH:
|
||||
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||
f"tp{world_size}_rank{rank}")
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
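# Hedged example (hypothetical path): with VLLM_XLA_CACHE_PATH=/tmp/xla_cache,
# world_size=8 and rank=0, the per-rank cache path above would be
# /tmp/xla_cache/tp8_rank0.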
|
||||
self.profiler = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
|
||||
# For TPU, we can only have one active profiler session per profiler
|
||||
# server, so we only profile on rank 0.
|
||||
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
self.profile_dir)
|
||||
self.profiler = xp.start_server(9012)
|
||||
|
||||
def start_profile(self):
|
||||
if self.rank < 1:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
xp.start_trace(self.profile_dir)
|
||||
|
||||
def stop_profile(self):
|
||||
if self.rank < 1:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
xp.stop_trace()
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
|
||||
# Use an empty tensor instead of `None` to force Dynamo to pass
|
||||
# it by reference, rather than by specializing on the value `None`.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||
device=self.device),
|
||||
torch.tensor([], dtype=torch.float32,
|
||||
device=self.device))
|
||||
for _ in range(num_layers)]
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
[kv_caches])
|
||||
self.model_runner._dummy_run(
|
||||
batch_size=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
kv_caches=kv_caches,
|
||||
exec_mode=ExecutionMode.PREFILL,
|
||||
)
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
dtype_bytes = get_dtype_size(self.cache_dtype)
|
||||
block_size_bytes = (dtype_bytes * self.cache_config.block_size *
|
||||
num_layers * 2 * head_size * num_kv_heads)
|
||||
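# Hedged worked example (hypothetical model): with a 2-byte cache dtype,
# block_size=16, 32 layers, head_size=128 and 8 KV heads, block_size_bytes =
# 2 * 16 * 32 * 2 * 128 * 8 = 2 MiB per block.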
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
|
||||
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
||||
|
||||
# Calculate the CPU KV cache size based on the config.
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
block_size_bytes)
|
||||
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
|
||||
return num_tpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
dtype = self.cache_dtype
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
|
||||
self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
for _ in range(num_layers):
|
||||
tpu_k_cache = torch.zeros(tpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
|
||||
cpu_k_cache = torch.zeros(cpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
cpu_v_cache = torch.zeros_like(cpu_k_cache)
|
||||
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
[self.tpu_cache])
|
||||
self._warmup_model()
|
||||
|
||||
def _warmup_model(self) -> None:
|
||||
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
|
||||
# for CUDA graphs. We should refactor this part.
|
||||
if not self.model_config.enforce_eager:
|
||||
# Warm up the model with all possible input shapes so that
|
||||
# compilation never happens during the actual execution.
|
||||
# This may take ~30 mins for the first run and ~20 mins for the
|
||||
# subsequent runs.
|
||||
# If `enforce_eager` is True, the ahead-of-time compilation is
|
||||
# skipped and the compilation happens during the actual execution,
|
||||
# which is bad for performance but useful for development.
|
||||
self.model_runner.warmup_model(self.tpu_cache)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
key_cache_block = self.cache_config.block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = get_dtype_size(self.cache_dtype)
|
||||
return dtype_size * total
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
|
||||
# parallelism.
|
||||
return [self.tpu_cache]
|
||||
|
||||
@property
|
||||
def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
return None
|
||||
|
||||
def prepare_worker_input(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
blocks_to_swap_in = _make_src_to_dst(
|
||||
execute_model_req.blocks_to_swap_in, "cpu", self.device)
|
||||
blocks_to_swap_out = _make_src_to_dst(
|
||||
execute_model_req.blocks_to_swap_out, self.device, "cpu")
|
||||
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
|
||||
self.device, self.device)
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
assert virtual_engine == 0
|
||||
attn_backend = self.model_runner.attn_backend
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
# Issue cache operations.
|
||||
if worker_input.blocks_to_swap_in is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_swap_in
|
||||
if src_indices.numel() > 0:
|
||||
# Swap from CPU to TPU.
|
||||
for i in range(num_layers):
|
||||
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
|
||||
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
|
||||
k = cpu_k_cache[:, src_indices].to(self.device)
|
||||
v = cpu_v_cache[:, src_indices].to(self.device)
|
||||
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
|
||||
|
||||
if worker_input.blocks_to_swap_out is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_swap_out
|
||||
if src_indices.numel() > 0:
|
||||
# Swap from TPU to CPU.
|
||||
for i in range(num_layers):
|
||||
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
|
||||
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
|
||||
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
|
||||
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
|
||||
|
||||
if worker_input.blocks_to_copy is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_copy
|
||||
if src_indices.numel() > 0:
|
||||
attn_backend.copy_blocks(self.tpu_cache,
|
||||
(src_indices, dst_indices))
|
||||
|
||||
|
||||
def _make_src_to_dst(
|
||||
mapping: List[Tuple[int, int]],
|
||||
src_device: Union[torch.device, str],
|
||||
dst_device: Union[torch.device, str],
|
||||
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if not mapping:
|
||||
return None
|
||||
|
||||
src_indices = [i for i, _ in mapping]
|
||||
dst_indices = [i for _, i in mapping]
|
||||
src_indices = torch.tensor(src_indices,
|
||||
device=src_device,
|
||||
dtype=torch.int64)
|
||||
dst_indices = torch.tensor(dst_indices,
|
||||
device=dst_device,
|
||||
dtype=torch.int64)
|
||||
return src_indices, dst_indices
|
||||
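# Hedged example: _make_src_to_dst([(0, 3), (1, 7)], "cpu", device) returns a
# pair of int64 tensors ([0, 1] on the source device, [3, 7] on the
# destination device); an empty mapping returns None.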
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
def _insert_kv(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
tpu_k_cache: torch.Tensor,
|
||||
tpu_v_cache: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
|
||||
tpu_k_cache[:, indices] = k
|
||||
tpu_v_cache[:, indices] = v
|
||||
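# Note on the design above (hedged): marking the TPU caches as buffer donors
# lets XLA update them in place instead of materializing full copies, e.g. a
# call like _insert_kv(k, v, torch.tensor([2, 5]), tpu_k_cache, tpu_v_cache)
# writes the swapped-in blocks into cache positions 2 and 5.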
53
vllm/worker/utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
'''
|
||||
Worker-related helper functions.
|
||||
'''
|
||||
|
||||
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase
|
||||
|
||||
|
||||
def assert_enc_dec_mr_supported_scenario(
|
||||
enc_dec_mr: GPUModelRunnerBase) -> None:
|
||||
'''
|
||||
Assert that the provided encoder/decoder model runner instance reflects
|
||||
a supported scenario.
|
||||
'''
|
||||
|
||||
# Reminder: Please update docs/features/compatibility_matrix.md
|
||||
# if the feature combo becomes valid.
|
||||
|
||||
if enc_dec_mr.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
|
||||
|
||||
if enc_dec_mr.sliding_window is not None:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
|
||||
|
||||
if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
|
||||
'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
|
||||
|
||||
if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
|
||||
None) is not None:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
|
||||
)
|
||||
|
||||
if enc_dec_mr.lora_config is not None:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
|
||||
|
||||
if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
|
||||
|
||||
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
|
||||
|
||||
if enc_dec_mr.prompt_adapter_config is not None:
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
|
||||
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
|
||||
588
vllm/worker/worker.py
Normal file
@@ -0,0 +1,588 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
|
||||
memory_profiling)
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
|
||||
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Worker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a GPU.
|
||||
|
||||
Each worker is associated with a single GPU. The worker is responsible for
|
||||
maintaining the KV cache and executing the model on the GPU. In case of
|
||||
distributed inference, each worker is assigned a partition of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config)
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Return hidden states from target model if the draft model is an
|
||||
# mlp_speculator
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type ==
|
||||
model_config.hf_config.model_type) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ("medusa",
|
||||
"mlp_speculator",
|
||||
"eagle",
|
||||
"deepseek_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"mimo_mtp")) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
if model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = PoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = EncoderDecoderModelRunner
|
||||
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
|
||||
vllm_config=self.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: List[CacheEngine]
|
||||
# Initialize gpu_cache to None, as pooling models don't initialize kv_caches
|
||||
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
print(
|
||||
self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
|
||||
# Save the buffers before level 2 sleep
|
||||
if level == 2:
|
||||
model = self.model_runner.model
|
||||
self._sleep_saved_buffers = {
|
||||
name: buffer.cpu().clone()
|
||||
for name, buffer in model.named_buffers()
|
||||
}
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
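# A brief note on the levels (matching the offload_tags above): level 1
# offloads tensors tagged "weights" to host memory so wake_up can restore
# them cheaply, while level 2 passes no offload tags, so weight memory is
# simply released and must be re-populated later (e.g. by a weight update);
# the non-persistent buffers saved above are restored in wake_up.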
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up(tags=tags)
|
||||
|
||||
# Restore the buffers after level 2 sleep
|
||||
if len(self._sleep_saved_buffers):
|
||||
model = self.model_runner.model
|
||||
for name, buffer in model.named_buffers():
|
||||
if name in self._sleep_saved_buffers:
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.baseline_snapshot = MemorySnapshot()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag="weights")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model_runner.save_sharded_state(
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
) -> None:
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculates the maximum possible number of GPU and CPU blocks
|
||||
that can be allocated with the remaining free memory.
|
||||
|
||||
Tip:
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
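# Illustrative numbers (hypothetical): with 80 GiB of GPU memory and
# gpu_memory_utilization=0.9 the instance may use 72 GiB; if profiling
# reports 20 GiB of non-KV-cache memory and a cache block is 2 MiB,
# roughly (72 - 20) * 1024 / 2 = 26624 GPU blocks can be allocated.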
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.baseline_snapshot,
|
||||
weights_memory=self.model_runner.model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
memory_for_current_instance = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
available_kv_cache_memory = (memory_for_current_instance -
|
||||
result.non_kv_cache_memory)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
if cache_block_size == 0:
|
||||
num_gpu_blocks = 0
|
||||
num_cpu_blocks = 0
|
||||
else:
|
||||
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
logger.info(msg)
|
||||
# Final cleanup
|
||||
gc.collect()
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def _assert_memory_footprint_increased_during_profiling(self):
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
free_gpu_memory, total = torch.cuda.mem_get_info()
|
||||
cuda_memory = total - free_gpu_memory
|
||||
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
|
||||
f"currently used memory {cuda_memory}. "
|
||||
f"This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
||||
|
||||
This also warms up the model, which may record CUDA graphs.
|
||||
"""
|
||||
raise_if_cache_size_invalid(
|
||||
num_gpu_blocks, self.cache_config.block_size,
|
||||
self.cache_config.is_attention_free,
|
||||
self.model_config.max_model_len,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
assert self.cache_config.num_gpu_blocks is not None
|
||||
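# One CacheEngine (and one gpu_cache entry) per virtual engine, i.e. per
# pipeline-parallel microbatch stream; execute_worker later indexes both
# lists by WorkerInput.virtual_engine.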
self.cache_engine = [
|
||||
CacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config, self.device_config)
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.gpu_cache)
|
||||
|
||||
def _warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
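# For instance (hypothetical config), with compile_sizes=[1, 8, 4096] and
# cudagraph_capture_sizes=[1, 8], only size 4096 is warmed up below; sizes
# 1 and 8 are exercised later during CUDA graph capture.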
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes if x not in
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
]
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model(self.gpu_cache)
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
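# With tensor parallelism the driver must broadcast request metadata to
# the other TP ranks; a single-rank worker can skip the broadcast.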
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.gpu_cache
|
||||
|
||||
@property
|
||||
def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
return self.cache_engine
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_steps = execute_model_req.num_steps
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are CPU tensors;
|
||||
# they contain the parameters used to launch cudaMemcpyAsync.
|
||||
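# After .view(-1, 2) each row is a (source_block, destination_block) pair,
# which is the layout the cache engine's swap operations expect.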
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
|
||||
device="cpu",
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
# `blocks_to_copy` is a GPU tensor. The src and tgt of the
|
||||
# blocks to copy are on the same device, and `blocks_to_copy`
|
||||
# can be used directly within CUDA kernels.
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device=self.device,
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_steps=num_steps,
|
||||
kvcache_slot_to_be_moved=execute_model_req.kvcache_slot_to_be_moved
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
# Issue cache operations.
|
||||
if (worker_input.blocks_to_swap_in is not None
|
||||
and worker_input.blocks_to_swap_in.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_in(
|
||||
worker_input.blocks_to_swap_in)
|
||||
if (worker_input.blocks_to_swap_out is not None
|
||||
and worker_input.blocks_to_swap_out.numel() > 0):
|
||||
self.cache_engine[virtual_engine].swap_out(
|
||||
worker_input.blocks_to_swap_out)
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
|
||||
|
||||
# Tree-style generation needs to move the KV cache to the correct position.
|
||||
if worker_input.kvcache_slot_to_be_moved is not None:
|
||||
self.cache_engine[virtual_engine].move_caches(self.kv_cache[virtual_engine],
|
||||
worker_input.kvcache_slot_to_be_moved)
|
||||
|
||||
def _get_cached_seq_group_metadata(
|
||||
self,
|
||||
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta]],
|
||||
finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
|
||||
"""Return a list of cached Sequence Group Metadata after updating its
|
||||
state.
|
||||
|
||||
It is used because the scheduler only sends deltas to workers to reduce
|
||||
the data payload size. The function also cleans up the cache based on
|
||||
a given `finished_request_ids`.
|
||||
"""
|
||||
new_seq_group_metadata_list = []
|
||||
for metadata_or_delta in seq_group_metadata_list:
|
||||
request_id = metadata_or_delta.request_id
|
||||
if request_id not in self._seq_group_metadata_cache:
|
||||
# The first prefill.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[request_id] = metadata_or_delta
|
||||
else:
|
||||
# The first prefill is already cached.
|
||||
if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
|
||||
self._seq_group_metadata_cache[request_id].apply_delta(
|
||||
metadata_or_delta)
|
||||
else:
|
||||
# If metadata snapshot is sent again, it is
|
||||
# preempted. Reset the cache because we need to start
|
||||
# from scratch.
|
||||
assert isinstance(metadata_or_delta, SequenceGroupMetadata)
|
||||
self._seq_group_metadata_cache[
|
||||
request_id] = metadata_or_delta
|
||||
|
||||
new_seq_group_metadata_list.append(
|
||||
self._seq_group_metadata_cache[request_id])
|
||||
|
||||
# Clean up finished ids
|
||||
for finished_id in finished_request_ids:
|
||||
del self._seq_group_metadata_cache[finished_id]
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _execute_model_spmd(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if execute_model_req is not None:
|
||||
new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.finished_requests_ids)
|
||||
|
||||
execute_model_req.seq_group_metadata_list = (
|
||||
new_seq_group_metadata_list)
|
||||
output = super()._execute_model_spmd(execute_model_req,
|
||||
intermediate_tensors)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return self.model_runner.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
return self.model_runner.list_prompt_adapters()
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_runner.vocab_size
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Get the size of the KV cache block size in bytes.
|
||||
"""
|
||||
return CacheEngine.get_cache_block_size(self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
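# For reference (an approximation, not the exact implementation): the block
# size is on the order of block_size * num_kv_heads * head_size * 2 (K and V)
# * num_attention_layers * dtype_size, e.g. 16 * 8 * 128 * 2 * 32 * 2 bytes
# = 2 MiB per block for a hypothetical fp16 model with 8 KV heads, 32 layers.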
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
max_model_len, pipeline_parallel_size) -> None:
|
||||
if is_attention_free and num_gpu_blocks != 0:
|
||||
raise ValueError("No memory should be allocated for the cache blocks "
|
||||
f"for an attention-free model, but {num_gpu_blocks} "
|
||||
"blocks are allocated.")
|
||||
if not is_attention_free and num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
|
||||
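# e.g. block_size=16, num_gpu_blocks=8192 and pipeline_parallel_size=1 allow
# at most 16 * 8192 = 131072 tokens in the KV cache, so any max_model_len
# above that is rejected below.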
if not is_attention_free and max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
706
vllm/worker/worker_base.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import numa
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import cloudpickle
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import (ObservabilityConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import (enable_trace_function_call_for_thread,
|
||||
resolve_obj_by_qualname, run_method,
|
||||
update_environment_variables,
|
||||
warn_for_unimplemented_methods)
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerBase,
|
||||
ModelRunnerInputBase)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Bind the current process to a NUMA node
|
||||
def bind_to_numa(local_rank):
|
||||
env_str = f"VLLM_RANK{local_rank}_NUMA"
|
||||
node_count = numa.get_max_node() + 1
|
||||
numa_node = int(os.getenv(env_str, -1))
|
||||
|
||||
# Skip binding if the env var is unset or misconfigured. TODO: bind automatically based on the hardware topology.
|
||||
if numa_node < 0:
|
||||
logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %d", env_str, env_str, numa_node)
|
||||
return
|
||||
|
||||
if numa_node > numa.get_max_node():
|
||||
raise ValueError(f"NUMA node {numa_node} is not available.")
|
||||
|
||||
numa.bind([numa_node])
|
||||
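# Example (assuming two workers on one node): launching with
# VLLM_RANK0_NUMA=0 and VLLM_RANK1_NUMA=1 binds rank 0 to NUMA node 0 and
# rank 1 to node 1; leaving the variables unset skips binding entirely.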
|
||||
|
||||
@warn_for_unimplemented_methods
|
||||
class WorkerBase:
|
||||
"""Worker interface that allows vLLM to cleanly separate implementations for
|
||||
different hardware. Also abstracts control plane communication, e.g., to
|
||||
communicate request metadata to other workers.
|
||||
"""
|
||||
|
||||
model_input: Optional[ModelRunnerInputBase] = None
|
||||
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
from vllm.platforms import current_platform
|
||||
self.current_platform = current_platform
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize device state, such as loading the model or other on-device
|
||||
memory allocations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache with the given size in blocks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load model onto target device."""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def start_worker_execution_loop(self) -> None:
|
||||
"""Execute model loop in parallel worker.
|
||||
|
||||
You can stop the loop by executing a driver worker with an empty output.
|
||||
See `stop_remote_worker_execution_loop` for more details.
|
||||
"""
|
||||
with self.current_platform.inference_mode():
|
||||
while True:
|
||||
output = self.execute_model(execute_model_req=None)
|
||||
if output is None:
|
||||
return None
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available blocks for the GPU KV cache and
|
||||
swappable CPU KV cache.
|
||||
|
||||
The implementation may run profiling or other heuristics to determine
|
||||
the size of caches.
|
||||
|
||||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||
are blocks that are "active" on the device and can be appended to.
|
||||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||
appended to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Return the size of a single cache block, in bytes. Used in
|
||||
speculative decoding.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
# @property
|
||||
# @abstractmethod
|
||||
# def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
# raise NotImplementedError
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Get vocabulary size from model configuration."""
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
|
||||
class DelegateWorkerBase(WorkerBase):
|
||||
"""
|
||||
A class that delegates all methods to another WorkerBase instance. This is
|
||||
useful for creating a WorkerBase that wraps another WorkerBase instance,
|
||||
e.g. speculative decoding.
|
||||
"""
|
||||
worker: WorkerBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
vllm_config: VllmConfig = kwargs.get("vllm_config")
|
||||
cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
|
||||
self.worker = cls(*args, **kwargs)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.worker.init_device()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
return self.worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load model onto target device."""
|
||||
self.worker.load_model()
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.worker.get_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
return self.worker.execute_model(execute_model_req)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
return self.worker.get_cache_block_size_bytes()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.worker.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.worker.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.worker.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.worker.list_loras()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.worker, attr)
|
||||
|
||||
|
||||
class LoRANotSupportedWorkerBase(WorkerBase):
|
||||
"""Partial implementation of WorkerBase that raises exceptions when LoRA
|
||||
methods are invoked.
|
||||
"""
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
@property
|
||||
def cache_engines(self) -> Optional[List[CacheEngine]]:
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WorkerInput:
|
||||
"""Local inputs to each worker. May contain device-specific data. These
|
||||
fields should be broadcastable to other workers.
|
||||
"""
|
||||
|
||||
num_seq_groups: Optional[int] = None
|
||||
blocks_to_swap_in: Optional[torch.Tensor] = None
|
||||
blocks_to_swap_out: Optional[torch.Tensor] = None
|
||||
blocks_to_copy: Optional[torch.Tensor] = None
|
||||
virtual_engine: int = 0
|
||||
num_steps: int = 1
|
||||
|
||||
# Optional slot mapping of KV cache entries pending to be moved, generated by the draft model.
|
||||
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type["WorkerInput"],
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> "WorkerInput":
|
||||
"""
|
||||
Pop fields from the given tensor_dict and populate a new instance of
|
||||
WorkerInput.
|
||||
"""
|
||||
return cls(
|
||||
num_seq_groups=tensor_dict.pop("num_seq_groups"),
|
||||
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
|
||||
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
|
||||
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
|
||||
virtual_engine=tensor_dict["virtual_engine"],
|
||||
num_steps=tensor_dict.pop("num_steps"),
|
||||
kvcache_slot_to_be_moved=tensor_dict.pop("kvcache_slot_to_be_moved"),
|
||||
)
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
"""
|
||||
Extract broadcastable fields.
|
||||
"""
|
||||
tensor_dict = {
|
||||
"num_seq_groups": self.num_seq_groups,
|
||||
"blocks_to_swap_in": self.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": self.blocks_to_swap_out,
|
||||
"blocks_to_copy": self.blocks_to_copy,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"num_steps": self.num_steps,
|
||||
"kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved
|
||||
}
|
||||
|
||||
return tensor_dict
|
||||
|
||||
|
||||
class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
"""
|
||||
Partial implementation of WorkerBase that has a default `execute_model`
|
||||
definition to perform metadata transfer between workers when in distributed
|
||||
mode. Subclasses of this interface should use model runners that inherit
|
||||
from ModelRunnerBase, and should only need to implement worker-local logic.
|
||||
If custom control plane logic is needed to transfer metadata, or if the
|
||||
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
|
||||
"""
|
||||
is_driver_worker: bool
|
||||
model_runner: ModelRunnerBase
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
"""
|
||||
Used by the default `execute_model` to check whether broadcast is
|
||||
needed to transfer request inputs from the driver worker to other
|
||||
workers in the TP group. If WorkerBase subclass only supports
|
||||
single-worker execution, then this method should return False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
"""
|
||||
Gets the list of kv caches to pass to the worker's model runner. Each
|
||||
element in the list is a kv cache corresponding to a particular virtual
|
||||
engine (PP stream). Used by the default `execute_model`. If the worker's
|
||||
model runner does not follow the ModelRunnerBase interface, then inherit
|
||||
from WorkerBase instead.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
"""
|
||||
Prepare the inputs to WorkerBase.execute_worker from an execution
|
||||
request. This method may move data to the worker's local device. It is
|
||||
not allowed to communicate with other workers or devices.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
"""
|
||||
Process an execution request.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_worker_input_from_broadcast(
|
||||
self
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
|
||||
str, torch.Tensor]]]:
|
||||
""" Get the worker input from the broadcasted tensor dict. """
|
||||
assert self.do_metadata_broadcast
|
||||
assert not self.is_driver_worker
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
|
||||
kwargs = extract_previous_hidden_states(broadcast_data)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
""" Get the driver input and broadcast it to other workers. """
|
||||
assert self.is_driver_worker
|
||||
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
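# Tree decoding (enabled via VLLM_TREE_DECODING=1): the request may carry
# position ids and attention masks describing a token tree proposed by the
# draft model; when present they are injected into the model input below so
# attention follows the tree structure.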
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
|
||||
execute_model_req.tree_attn_masks is not None:
|
||||
if hasattr(model_input, "input_positions") and \
|
||||
hasattr(model_input, "attn_metadata") and \
|
||||
hasattr(model_input.attn_metadata, "tree_attention_masks_tensor"):
|
||||
attn_metadata = model_input.attn_metadata
|
||||
attn_metadata.tree_attention_masks_tensor = execute_model_req.tree_attn_masks.contiguous()
|
||||
model_input = dataclasses.replace(model_input,
|
||||
input_positions=execute_model_req.tree_position_ids.contiguous(),
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
kwargs = extract_previous_hidden_states(execute_model_req)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_data.update(kwargs)
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input = dataclasses.replace( # type: ignore
|
||||
model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
|
||||
str, torch.Tensor]]]:
|
||||
"""
|
||||
Prepare the inputs to ModelRunner and workers.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there are no more requests to process for
|
||||
# now. All workers are running infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
return self._get_driver_input_and_broadcast(execute_model_req)
|
||||
else:
|
||||
return self._get_worker_input_from_broadcast()
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Executes at least one model step on the given sequences, unless no
|
||||
sequences are provided."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.prepare_input(execute_model_req)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
model_input, worker_input, kwargs = inputs
|
||||
num_steps = worker_input.num_steps
|
||||
if execute_model_req is not None and execute_model_req.spec_step_idx:
|
||||
kwargs["spec_step_idx"] = execute_model_req.spec_step_idx
|
||||
|
||||
self.model_input = model_input
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if worker_input.num_seq_groups == 0:
|
||||
return []
|
||||
|
||||
intermediate_tensors = None
|
||||
orig_model_execute_time = 0.0
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group()))
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
orig_model_execute_time = intermediate_tensors.tensors.get(
|
||||
"model_execute_time", torch.tensor(0)).item()
|
||||
|
||||
output = self.model_runner.execute_model(
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
num_steps=num_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_execute_time = time.perf_counter() - start_time
|
||||
if not get_pp_group().is_last_rank:
|
||||
# output is IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
output.tensors["model_execute_time"] = torch.tensor(
|
||||
model_execute_time + orig_model_execute_time)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return [None]
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time
|
||||
and output is not None):
|
||||
for o in output:
|
||||
o.model_execute_time = (orig_model_execute_time +
|
||||
model_execute_time)
|
||||
|
||||
# output is List[SamplerOutput]
|
||||
return output
|
||||
|
||||
def _execute_model_spmd(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""
|
||||
Execute model in Single Program Multiple Data (SPMD) fashion.
|
||||
All workers take the same request, prepare the input and
|
||||
execute the model.
|
||||
"""
|
||||
assert execute_model_req is not None, (
|
||||
"_execute_model_spmd() requires each worker to take in an "
|
||||
"ExecuteModelRequest")
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list))
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if worker_input.num_seq_groups == 0:
|
||||
return []
|
||||
|
||||
kwargs = extract_previous_hidden_states(execute_model_req)
|
||||
|
||||
return self.model_runner.execute_model(
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class WorkerWrapperBase:
|
||||
"""
|
||||
This class represents one process in an executor/engine. It is responsible
|
||||
for lazily initializing the worker and handling the worker's lifecycle.
|
||||
We first instantiate the WorkerWrapper, which remembers the worker module
|
||||
and class name. Then we call `update_environment_variables`, and the
|
||||
real initialization happens in `init_worker`.
|
||||
"""
|
||||
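# Typical lifecycle driven by the executor (a sketch; the exact order can
# vary by executor): update_environment_variables() -> init_worker() ->
# init_device() -> load_model(), with any other calls forwarded to the
# wrapped worker via __getattr__.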
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rpc_rank: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker wrapper with the given vllm_config and rpc_rank.
|
||||
Note: rpc_rank is the rank of the worker in the executor. In most cases,
|
||||
it is also the rank of the worker in the distributed group. However,
|
||||
when multiple executors work together, they can be different.
|
||||
e.g. in the case of SPMD-style offline inference with TP=2,
|
||||
users can launch 2 engines/executors, each with only 1 worker.
|
||||
All workers have rpc_rank=0, but they have different ranks in the TP
|
||||
group.
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.worker: Optional[WorkerBase] = None
|
||||
self.vllm_config: Optional[VllmConfig] = None
|
||||
# do not store this `vllm_config`, `init_worker` will set the final
|
||||
# one. TODO: investigate if we can remove this field in
|
||||
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
|
||||
# unnecessary now.
|
||||
if vllm_config.model_config is not None:
|
||||
# it can be None in tests
|
||||
trust_remote_code = vllm_config.model_config.trust_remote_code
|
||||
if trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
||||
"""
|
||||
Adjust the rpc_rank based on the given mapping.
|
||||
It is only used during the initialization of the executor,
|
||||
to adjust the rpc_rank of workers after we create all workers.
|
||||
"""
|
||||
if self.rpc_rank in rank_mapping:
|
||||
self.rpc_rank = rank_mapping[self.rpc_rank]
|
||||
|
||||
def update_environment_variables(self, envs_list: List[Dict[str,
|
||||
str]]) -> None:
|
||||
envs = envs_list[self.rpc_rank]
|
||||
key = 'CUDA_VISIBLE_DEVICES'
|
||||
if key in envs and key in os.environ:
|
||||
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
|
||||
# suppress the warning in `update_environment_variables`
|
||||
del os.environ[key]
|
||||
update_environment_variables(envs)
|
||||
|
||||
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Here we inject some common logic before initializing the worker.
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
kwargs = all_kwargs[self.rpc_rank]
|
||||
self.vllm_config = kwargs.get("vllm_config", None)
|
||||
assert self.vllm_config is not None, (
|
||||
"vllm_config is required to initialize the worker")
|
||||
enable_trace_function_call_for_thread(self.vllm_config)
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
|
||||
worker_class = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
else:
|
||||
logger.warning(
|
||||
"passing worker_cls as a class object is strongly deprecated,"
|
||||
" as the serialization of class objects can be tricky and"
|
||||
" error-prone. To be safe, please keep the class in a separate"
|
||||
" module and pass the qualified name of the class as a string."
|
||||
)
|
||||
assert isinstance(self.vllm_config.parallel_config.worker_cls,
|
||||
bytes)
|
||||
worker_class = cloudpickle.loads(
|
||||
self.vllm_config.parallel_config.worker_cls)
|
||||
if self.vllm_config.parallel_config.worker_extension_cls:
|
||||
worker_extension_cls = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_extension_cls)
|
||||
extended_calls = []
|
||||
if worker_extension_cls not in worker_class.__bases__:
|
||||
# check any conflicts between worker and worker_extension_cls
|
||||
for attr in dir(worker_extension_cls):
|
||||
if attr.startswith("__"):
|
||||
continue
|
||||
assert not hasattr(worker_class, attr), (
|
||||
f"Worker class {worker_class} already has an attribute"
|
||||
f" {attr}, which conflicts with the worker"
|
||||
f" extension class {worker_extension_cls}.")
|
||||
if callable(getattr(worker_extension_cls, attr)):
|
||||
extended_calls.append(attr)
|
||||
# dynamically inherit the worker extension class
|
||||
worker_class.__bases__ = worker_class.__bases__ + (
|
||||
worker_extension_cls, )
|
||||
logger.info(
|
||||
"Injected %s into %s for extended collective_rpc calls %s",
|
||||
worker_extension_cls, worker_class, extended_calls)
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
self.worker = worker_class(**kwargs)
|
||||
assert self.worker is not None
|
||||
|
||||
VLLM_NUMA_BIND = int(os.getenv("VLLM_NUMA_BIND", 1))
|
||||
if VLLM_NUMA_BIND > 0:
|
||||
# Bind the current process to the configured NUMA node
|
||||
bind_to_numa(kwargs['local_rank'])
|
||||
|
||||
pid = os.getpid()
|
||||
logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(kwargs['local_rank']), str(os.sched_getaffinity(pid)))
|
||||
logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(kwargs['local_rank']), str(numa.get_membind()))
|
||||
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.rpc_rank]
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
|
||||
def init_device(self):
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during device initialization
|
||||
self.worker.init_device() # type: ignore
|
||||
|
||||
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
||||
try:
|
||||
# method resolution order:
|
||||
# if a method is defined in this class, it will be called directly.
|
||||
# otherwise, since we define `__getattr__` and redirect attribute
|
||||
# query to `self.worker`, the method will be called on the worker.
|
||||
return run_method(self, method, args, kwargs)
|
||||
except Exception as e:
|
||||
# if the driver worker also execute methods,
|
||||
# exceptions in the rest worker may cause deadlock in rpc like ray
|
||||
# see https://github.com/vllm-project/vllm/issues/3455
|
||||
# print the error and inform the user to solve the error
|
||||
msg = (f"Error executing method {method!r}. "
|
||||
"This might cause deadlock in distributed execution.")
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.worker, attr)
|
||||
|
||||
|
||||
def extract_previous_hidden_states(
|
||||
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
|
||||
Dict[str, torch.Tensor]:
|
||||
"""If data contains previous_hidden_states, extract it. This returns a dict
|
||||
which can be used directly as additional kwargs in any following
|
||||
execute_model calls. This is used in draft models like EAGLE."""
|
||||
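# A sketch of the expected data (assumption): the driver passes an
# ExecuteModelRequest whose previous_hidden_states.hidden_states tensor is
# roughly [num_tokens, hidden_size]; non-driver workers receive the same
# tensor in the broadcast dict under the "previous_hidden_states" key.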
output = {}
|
||||
|
||||
# When called from non-driver worker, data is dict but when called from
|
||||
# driver worker, data is ExecuteModelRequest.
|
||||
if isinstance(data, dict):
|
||||
if "previous_hidden_states" in data:
|
||||
output["previous_hidden_states"] = data["previous_hidden_states"]
|
||||
elif data.previous_hidden_states is not None:
|
||||
output["previous_hidden_states"] = data.previous_hidden_states\
|
||||
.hidden_states
|
||||
|
||||
return output
|
||||
606
vllm/worker/xpu_model_runner.py
Normal file
@@ -0,0 +1,606 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import time
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadataCache
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs, MultiModalPlaceholderMap,
|
||||
MultiModalRegistry)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForXPU(ModelRunnerInputBase):
|
||||
"""
|
||||
Used by the XPUModelRunner.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||
virtual_engine: Optional[int] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
query_lens: Optional[List[int]] = None
|
||||
async_callback: Optional[Callable] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type[TModelInputForXPU],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> TModelInputForXPU:
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
|
||||
"""
|
||||
Used by the ModelRunner.
|
||||
"""
|
||||
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForXPUWithSamplingMetadata":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
|
||||
def __init__(self,
|
||||
runner: "XPUModelRunner",
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.runner = runner
|
||||
self.model_input_cls = self.runner._model_input_cls
|
||||
self.attn_backend = self.runner.attn_backend
|
||||
self.sliding_window = self.runner.sliding_window
|
||||
self.block_size = self.runner.block_size
|
||||
self.device = self.runner.device
|
||||
|
||||
def prepare(self,
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
|
||||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
|
||||
self.seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
def build(self) -> ModelInputForXPU:
|
||||
is_prompt = self.seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs) = self._prepare_prompt(
|
||||
self.seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(
|
||||
self.seq_group_metadata_list)
|
||||
seq_lens = None
|
||||
multi_modal_kwargs = None
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
|
||||
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            positions_range = range(computed_len, seq_len)
            input_positions.extend(list(positions_range))

            if seq_group_metadata.multi_modal_data:
                # NOTE: mm_kwargs only includes the subset of multi-modal items
                # that intersect with the current prefill positions.
                mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
                    .from_seq_group(seq_group_metadata, positions_range)

                multi_modal_kwargs_list.append(mm_kwargs)

                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
                        placeholder_map)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

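        # Illustration of the slot arithmetic above (assumed numbers, not from
        # any config): with block_size=16 and block_table=[84, 10], token i=20
        # lands in block_table[20 // 16] = 10 at offset 20 % 16 = 4, i.e. slot
        # 10 * 16 + 4 = 164 in the paged KV cache.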
        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)

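        # seqlen_q holds the cumulative prefix sums of the prompt lengths,
        # e.g. seq_lens=[4, 2, 3] -> seqlen_q=[0, 4, 6, 9] (illustrative
        # numbers); these are the per-sequence start offsets typically
        # consumed by varlen-style attention kernels.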
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)
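                # Illustration (assumed numbers): with sliding_window=1024 and
                # block_size=16, only the last 1024 // 16 = 64 blocks of the
                # table are kept, since older tokens fall outside the window.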

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )
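        # make_tensor_with_pad right-pads the ragged per-sequence block tables
        # into a rectangular tensor, e.g. [[3, 4], [7]] -> [[3, 4], [7, 0]]
        # (illustrative block numbers).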

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )


class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config
        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.
        self.sampler = get_sampler()

        self.sampling_metadata_cache: SamplingMetadataCache = \
            SamplingMetadataCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

        self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        with DeviceMemoryProfiler() as m:
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GiB",
                    self.model_memory_usage / GiB_bytes)

    def get_model(self) -> nn.Module:
        return self.model

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
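            # Illustration (assumed numbers): with max_num_batched_tokens=32768
            # and max_mm_tokens=4096, at most 32768 // 4096 = 8 sequences are
            # profiled; the clamp below only kicks in when max_mm_tokens alone
            # exceeds the token budget.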
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

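            # Illustration (assumed numbers): with max_num_batched_tokens=10
            # and max_num_seqs=4, the dummy sequences get lengths 3, 3, 2, 2,
            # so the entire token budget is exercised during profiling.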
            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, None, intermediate_tensors)
        torch.xpu.synchronize()
        return

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        return (
            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.
        """
        builder = self.builder
        builder.prepare(finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
            cache=self.sampling_metadata_cache)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start_time = time.time()
        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end_time = time.time()

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_time = (model_forward_end_time -
                                  model_forward_start_time)
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = model_forward_time

        return [output]
186
vllm/worker/xpu_worker.py
Normal file
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""An XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(LoRANotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        device_config = self.device_config
        parallel_config = self.parallel_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        self.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            vllm_config=vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
        ):
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # Keep this method for the `empty_cache` and `synchronize` APIs.
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
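        # Illustration (assumed numbers): with 16 GiB of device memory,
        # gpu_memory_utilization=0.9, a 6.4 GiB profiled peak, and a 2 MiB
        # cache block, (16 * 0.9 - 6.4) GiB // 2 MiB = 4096 GPU blocks.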
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX doesn't support graph capture yet.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Use sockets as the default Level Zero IPC exchange backend. By
            # default oneCCL uses `drmfd` as the mechanism, which needs extra
            # dependencies (libdrm and drm headers) on your system.
            ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                             str(parallel_config.world_size))
            os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
            os.environ["LOCAL_RANK"] = str(self.local_rank)
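            # The variables above are only defaults and can be overridden from
            # the launching shell, e.g.
            # `CCL_ATL_TRANSPORT=mpi LOCAL_WORLD_SIZE=4 python ...`
            # (illustrative values, not requirements of this worker).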
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        if parallel_config.pipeline_parallel_size > 1:
            # Add pp group init to avoid
            # p2p communication as the first call
            get_pp_group().all_reduce(torch.zeros(1).xpu())