Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
0
vllm/v1/worker/__init__.py
Normal file
0
vllm/v1/worker/__init__.py
Normal file
87
vllm/v1/worker/block_table.py
Normal file
87
vllm/v1/worker/block_table.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
block_ids: list[int],
|
||||
row_idx: int,
|
||||
) -> None:
|
||||
if not block_ids:
|
||||
return
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks_src = self.num_blocks_per_row[src]
|
||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
|
||||
def get_device_tensor(self) -> torch.Tensor:
|
||||
"""Ruturns the device tensor of the block table."""
|
||||
return self.block_table
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table_cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table_np
|
||||
678
vllm/v1/worker/gpu_input_batch.py
Normal file
678
vllm/v1/worker/gpu_input_batch.py
Normal file
@@ -0,0 +1,678 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Datastructures defining an input batch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
prompt: Optional[str]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: SamplingParams
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
block_ids: list[int]
|
||||
num_computed_tokens: int
|
||||
output_token_ids: list[int]
|
||||
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self.prompt_token_ids[idx]
|
||||
else:
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
|
||||
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_blocks_per_req: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[Optional[str]] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, ),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = \
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = BlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_num_blocks_per_req=max_num_blocks_per_req,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
self.presence_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[Optional[dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: Optional[int] = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = len(request.prompt_token_ids)
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
# Number of token ids in token_ids_cpu.
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
||||
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
||||
self.lora_id_to_request_ids.pop(lora_id)
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporiarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
# is sorted in descending order.
|
||||
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||
while empty_req_indices:
|
||||
# Find the largest non-empty index.
|
||||
while last_req_index in empty_req_indices:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = empty_req_indices.pop()
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
self._req_ids[empty_index] = req_id
|
||||
self._req_ids[last_req_index] = None
|
||||
self.req_output_token_ids[empty_index] = output_token_ids
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
self.num_tokens[empty_index] = num_tokens
|
||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||
last_req_index]
|
||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
||||
last_req_index]
|
||||
self.num_computed_tokens_cpu[
|
||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||
self.block_table.move_row(last_req_index, empty_index)
|
||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||
last_req_index]
|
||||
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
||||
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
||||
self.frequency_penalties_cpu[
|
||||
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
||||
self.presence_penalties_cpu[
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
last_req_index]
|
||||
|
||||
bad_words_token_ids = self.bad_words_token_ids.pop(
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
|
||||
def refresh_sampling_metadata(self):
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
if not self.all_greedy:
|
||||
temperature = copy_slice(self.temperature_cpu_tensor,
|
||||
self.temperature, num_reqs)
|
||||
else:
|
||||
temperature = None
|
||||
if not self.no_top_p:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
if not self.no_min_p:
|
||||
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
||||
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
# if necessary i.e. if there are requests which require
|
||||
# penalties to be applied during sampling.
|
||||
copy_slice(self.frequency_penalties_cpu_tensor,
|
||||
self.frequency_penalties, num_reqs)
|
||||
copy_slice(self.presence_penalties_cpu_tensor,
|
||||
self.presence_penalties, num_reqs)
|
||||
copy_slice(self.repetition_penalties_cpu_tensor,
|
||||
self.repetition_penalties, num_reqs)
|
||||
|
||||
# The prompt tokens are used only for applying penalties during
|
||||
# the sampling process. Hence copy these tensors only when
|
||||
# there are requests which need penalties to be applied.
|
||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
else:
|
||||
prompt_token_ids = None
|
||||
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
if not self.no_allowed_token_ids:
|
||||
assert self.allowed_token_ids_mask is not None
|
||||
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
||||
self.allowed_token_ids_mask, num_reqs)
|
||||
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
min_p=None if self.no_min_p else self.min_p[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
min_tokens=self.min_tokens,
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
(self.num_reqs, max_prompt_len),
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
||||
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
||||
num_reqs, :max_prompt_len]
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
for i in range(self.num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
||||
non_blocking=True)
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
||||
"""
|
||||
Given the num_scheduled_tokens for each request in the batch, return
|
||||
datastructures used to activate the current LoRAs.
|
||||
Returns:
|
||||
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
||||
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
||||
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
||||
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
||||
3. lora_requests: Set of relevant LoRA requests.
|
||||
"""
|
||||
|
||||
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
||||
prompt_lora_mapping = tuple(req_lora_mapping)
|
||||
token_lora_mapping = tuple(
|
||||
req_lora_mapping.repeat(num_scheduled_tokens))
|
||||
active_lora_requests: set[LoRARequest] = set(
|
||||
self.lora_id_to_lora_request.values())
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
@property
|
||||
def all_greedy(self) -> bool:
|
||||
return len(self.random_reqs) == 0
|
||||
|
||||
@property
|
||||
def all_random(self) -> bool:
|
||||
return len(self.greedy_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_p(self) -> bool:
|
||||
return len(self.top_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
and len(self.frequency_penalties_reqs) == 0
|
||||
and len(self.repetition_penalties_reqs) == 0)
|
||||
|
||||
@property
|
||||
def max_num_logprobs(self) -> Optional[int]:
|
||||
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
||||
|
||||
@property
|
||||
def no_prompt_logprob(self) -> bool:
|
||||
return not self.num_prompt_logprobs
|
||||
|
||||
@property
|
||||
def no_allowed_token_ids(self) -> bool:
|
||||
return len(self.has_allowed_token_ids) == 0
|
||||
1717
vllm/v1/worker/gpu_model_runner.py
Normal file
1717
vllm/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
320
vllm/v1/worker/gpu_worker.py
Normal file
320
vllm/v1/worker/gpu_worker.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class Worker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up(tags)
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag="weights")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Profiles the peak memory usage of the model to determine how much
|
||||
memory can be used for KV cache without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculate the free memory that can be used for KV cache in
|
||||
bytes.
|
||||
|
||||
.. tip::
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
_, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
self.model_runner.profile_run()
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_gpu_memory > free_gpu_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
# Get the peak memory allocation recorded by torch
|
||||
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
||||
|
||||
# Check for any memory left around that may have been allocated on the
|
||||
# gpu outside of `torch`. NCCL operations, for example, can use a few
|
||||
# GB during a forward pass
|
||||
torch.cuda.empty_cache()
|
||||
torch_allocated_bytes = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = torch.cuda.mem_get_info(
|
||||
)[1] - torch.cuda.mem_get_info()[0]
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory = (
|
||||
total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
|
||||
return int(available_kv_cache_memory)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes if x not in
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
]
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
if get_pp_group().is_last_rank:
|
||||
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs))
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.model_runner._dummy_run(1)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.model_loader.loader import ShardedStateLoader
|
||||
ShardedStateLoader.save_model(
|
||||
self.model_runner.model,
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank)
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
149
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
149
vllm/v1/worker/lora_model_runner_mixin.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a mixin for GPUModelRunner
|
||||
class LoRAModelRunnerMixin:
|
||||
|
||||
LORA_WARMUP_RANK = 8
|
||||
|
||||
def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_config: LoRAConfig, device: str) -> nn.Module:
|
||||
|
||||
assert supports_lora(
|
||||
model), f"{model.__class__.__name__} does not support LoRA yet."
|
||||
|
||||
if supports_multimodal(model):
|
||||
logger.warning("Regarding multimodal models, vLLM currently "
|
||||
"only supports adding LoRA to language model.")
|
||||
|
||||
# It's necessary to distinguish between the max_position_embeddings
|
||||
# of VLMs and LLMs.
|
||||
if hasattr(model.config, "max_position_embeddings"):
|
||||
max_pos_embeddings = model.config.max_position_embeddings
|
||||
else:
|
||||
max_pos_embeddings = (
|
||||
model.config.text_config.max_position_embeddings)
|
||||
|
||||
# Add LoRA Manager to the Model Runner
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
scheduler_config.max_num_seqs,
|
||||
scheduler_config.max_num_batched_tokens,
|
||||
model_config.get_vocab_size(),
|
||||
lora_config,
|
||||
device,
|
||||
model.embedding_modules,
|
||||
model.embedding_padding_modules,
|
||||
max_position_embeddings=max_pos_embeddings,
|
||||
)
|
||||
return self.lora_manager.create_lora_manager(model)
|
||||
|
||||
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest]) -> None:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
|
||||
# Set is_prefill to True, so we always use the SGMV kernels on
|
||||
# non-cuda platforms.
|
||||
# On cuda platforms we use the same kernels for prefill and
|
||||
# decode and this flag is generally ignored.
|
||||
lora_mapping = LoRAMapping(token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def set_active_loras(self, input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray) -> None:
|
||||
|
||||
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
||||
token_lora_mapping: tuple[int,
|
||||
...] # of size np.sum(num_scheduled_tokens)
|
||||
lora_requests: set[LoRARequest]
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests = \
|
||||
input_batch.make_lora_inputs(num_scheduled_tokens)
|
||||
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests)
|
||||
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
# __enter__ code
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
num_loras = lora_config.max_loras
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
|
||||
num_loras) + 1
|
||||
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping,
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path")
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
}
|
||||
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
# Add the dummy LoRAs here so _set_active_loras doesn't try to
|
||||
# load from disk.
|
||||
for lr in lora_requests:
|
||||
self.lora_manager.add_dummy_lora(
|
||||
lr, rank=self.LORA_WARMUP_RANK)
|
||||
|
||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping),
|
||||
lora_requests)
|
||||
|
||||
yield
|
||||
|
||||
# __exit__ code
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.list_adapters()
|
||||
1037
vllm/v1/worker/tpu_model_runner.py
Normal file
1037
vllm/v1/worker/tpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
250
vllm/v1/worker/tpu_worker.py
Normal file
250
vllm/v1/worker/tpu_worker.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A TPU worker class."""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Delay profiler initialization to the start of the profiling.
|
||||
# This is because in vLLM V1, MP runtime is initialized before the
|
||||
# TPU Worker is initialized. The profiler server needs to start after
|
||||
# MP runtime is initialized.
|
||||
self.profiler = None
|
||||
self.profile_dir = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
|
||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
||||
# server. So we only profile on rank0.
|
||||
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
self.profile_dir)
|
||||
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||
# ring, the xla tpu compiler flag
|
||||
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
||||
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
init_tpu_worker_distributed_environment(self.parallel_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
|
||||
# Device initialization should happen after initializing
|
||||
# the distributed runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
if self.model_config.seed is not None:
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
||||
# Re-evaluate limit, with MM we may get close to this limit.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
||||
# Consequently, changes in optimization flags, which affect compilation
|
||||
# results, don't change the cache key. This can result in the wrong
|
||||
# compilation being used. To prevent this, disabling the XLA compilation
|
||||
# cache during development is recommended.We can disable it by
|
||||
# `export VLLM_XLA_CACHE_PATH=`
|
||||
if envs.VLLM_XLA_CACHE_PATH:
|
||||
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||
f"tp{world_size}_rank{rank}")
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, AttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_kv_cache = torch.tensor([],
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = tpu_kv_cache
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported KV cache spec '{type(layer_spec)}'")
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches)
|
||||
|
||||
self.model_runner._dummy_run(
|
||||
runner_kv_caches,
|
||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
current_mem = m["bytes_used"]
|
||||
# Ideally we would use profiled = m["peak_bytes_used"] to
|
||||
# get weights + activations. But there is memory used during
|
||||
# compilation / weight loading that impacts the peak and
|
||||
# there is no way to reset peak memory in XLA, So we
|
||||
# use the heuristic of 2% of weights.
|
||||
profiled = current_mem * 1.02
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
|
||||
return int(tpu_kv_cache_bytes)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.rank < 1:
|
||||
if self.profile_dir is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
if self.profiler is None:
|
||||
self.profiler = xp.start_server(9012)
|
||||
xp.start_trace(self.profile_dir)
|
||||
else:
|
||||
xp.stop_trace()
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
|
||||
def init_tpu_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
29
vllm/v1/worker/utils.py
Normal file
29
vllm/v1/worker/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: object,
|
||||
expected_num_items: int,
|
||||
) -> None:
|
||||
"""
|
||||
Perform sanity checks for the result of
|
||||
:meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
|
||||
"""
|
||||
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
|
||||
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
|
||||
f"or a single 3D tensor, but got {type(mm_embeddings)} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert len(mm_embeddings) == expected_num_items, (
|
||||
"Expected number of multimodal embeddings to match number of "
|
||||
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
assert all(e.ndim == 2 for e in mm_embeddings), (
|
||||
"Expected multimodal embeddings to be a sequence of 2D tensors, "
|
||||
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
64
vllm/v1/worker/worker_base.py
Normal file
64
vllm/v1/worker/worker_base.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerBase(WorkerBaseV0):
|
||||
"""
|
||||
Abstract class for v1 worker, mainly define some methods for v1.
|
||||
For methods shared by v0 and v1, define them in v0 WorkerBase
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize common worker components.
|
||||
|
||||
Args:
|
||||
vllm_config: Complete vLLM configuration
|
||||
local_rank: Local device index
|
||||
rank: Global rank in distributed setup
|
||||
distributed_init_method: Distributed initialization method
|
||||
is_driver_worker: Whether this worker handles driver
|
||||
responsibilities
|
||||
"""
|
||||
# Configuration storage
|
||||
super().__init__(vllm_config=vllm_config)
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
# Device and model state
|
||||
self.device: Optional[torch.device] = None
|
||||
self.model_runner: Optional[nn.Module] = None
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""Get specifications for KV cache implementation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
"""Prepare model for execution through compilation/warmup."""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Basic health check (override for device-specific checks)."""
|
||||
return
|
||||
Reference in New Issue
Block a user