# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

class RequestState:
    """Per-request state for the GPU model runner.

    Maintains a fixed-size pool of ``max_num_reqs`` request slots. Each
    active request is assigned a slot index, and all per-request data
    (token ids, lengths, sampled/draft tokens) lives in preallocated
    tensors indexed by that slot. Slot indices are recycled via
    ``free_indices`` as requests are added and removed.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        num_speculative_steps: int,
        vocab_size: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.num_speculative_steps = num_speculative_steps
        self.vocab_size = vocab_size
        self.device = device

        # Bidirectional mapping between request ids and slot indices.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        # Pool of unassigned slot indices.
        self.free_indices = list(range(max_num_reqs))

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # depending on the configured max_num_reqs and max_model_len.
        # To save GPU memory, we use UVA instead of GPU for this tensor.
        self.all_token_ids = StagedWriteTensor(
            (self.max_num_reqs, self.max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        # NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
        # - prompt_len: Number of tokens in the user-provided prompt.
        # - prefill_len: Number of tokens passed into the model runner.
        #   This can include the prompt and additional partial output tokens,
        #   so prefill_len >= prompt_len.
        # Usually, prefill_len equals prompt_len, but in cases such as resumption
        # after preemption, prefill_len may be greater. Differentiating between
        # these values is crucial, as certain features such as prompt logprobs or
        # frequency penalties must treat prompt and output tokens separately.
        self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        # total_len = prompt_len + output_len. It grows as the request progresses.
        self.total_len = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Number of computed tokens.
        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Last sampled tokens.
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs, 1, dtype=torch.int64, device=device
        )

        # Draft tokens.
        self.draft_tokens = torch.zeros(
            self.max_num_reqs,
            self.num_speculative_steps,
            dtype=torch.int64,
            device=device,
        )
        self.next_prefill_tokens = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

    @property
    def num_reqs(self) -> int:
        """Number of currently active requests."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        all_token_ids: list[int],
        num_computed_tokens: int,
    ) -> None:
        """Assign a free slot to a new request and stage its initial state.

        Writes to the staged tensors are not visible until
        ``apply_staged_writes`` is called.

        Args:
            req_id: Unique identifier of the request.
            prompt_len: Number of tokens in the user-provided prompt.
            all_token_ids: All tokens to prefill; may include partial output
                tokens (e.g., after preemption), so it can be longer than
                the prompt itself.
            num_computed_tokens: Number of tokens already computed
                (e.g., via prefix caching).
        """
        assert len(self.free_indices) > 0, "No free indices"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        self.prompt_len.np[req_idx] = prompt_len
        prefill_len = len(all_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        self.prefill_len.np[req_idx] = prefill_len
        self.total_len.stage_write_elem(req_idx, prefill_len)
        self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
        self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)

    def apply_staged_writes(self) -> None:
        """Flush all staged per-request writes to their backing tensors."""
        self.prompt_len.copy_to_uva()
        self.prefill_len.copy_to_uva()
        self.total_len.apply_write()
        self.all_token_ids.apply_write()
        self.num_computed_tokens.apply_write()

    def remove_request(self, req_id: str) -> None:
        """Release the slot held by ``req_id``; no-op if the id is unknown."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any request in ``idx_mapping_np`` still has
        uncomputed prefill tokens.

        Args:
            idx_mapping_np: Integer array of slot indices to check.
        """
        # Wrap in bool() so callers get a plain Python bool rather than
        # np.bool_, matching the annotated return type (e.g., so identity
        # checks like `x is True` behave as expected).
        return bool(
            np.any(
                self.num_computed_prefill_tokens[idx_mapping_np]
                < self.prefill_len.np[idx_mapping_np]
            )
        )
|