update

vllm/v1/worker/gpu/states.py (new file, +123 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class RequestState:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        num_speculative_steps: int,
        vocab_size: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.num_speculative_steps = num_speculative_steps
        self.vocab_size = vocab_size
        self.device = device

        # Bookkeeping that maps request IDs to fixed slot indices in the
        # per-request buffers allocated below.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # depending on the configured max_num_reqs and max_model_len.
        # To save GPU memory, we use UVA instead of GPU for this tensor.
        self.all_token_ids = StagedWriteTensor(
            (self.max_num_reqs, self.max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
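        # Rough sizing sketch (illustrative numbers, not from this commit):
        # with max_num_reqs=1024 and max_model_len=131072, this int32 tensor
        # takes 1024 * 131072 * 4 bytes = 512 MiB, hence UVA (host-pinned
        # memory addressable from the GPU) rather than device memory.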
        # NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
        # - prompt_len: Number of tokens in the user-provided prompt.
        # - prefill_len: Number of tokens passed into the model runner.
        #   This can include the prompt and additional partial output tokens,
        #   so prefill_len >= prompt_len.
        # Usually, prefill_len equals prompt_len, but in cases such as
        # resumption after preemption, prefill_len may be greater.
        # Differentiating between these values is crucial, as certain features
        # such as prompt logprobs or frequency penalties must treat prompt and
        # output tokens separately.
        self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        # total_len = prompt_len + output_len. It grows as the request
        # progresses.
        self.total_len = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )
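        # Worked example (hypothetical request): a 100-token prompt preempted
        # after generating 10 output tokens is resumed with prompt_len=100,
        # prefill_len=110, and total_len=110; total_len then keeps growing as
        # further output tokens are sampled.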

        # Number of computed tokens.
        self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = StagedWriteTensor(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Last sampled tokens.
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs, 1, dtype=torch.int64, device=device
        )

        # Draft tokens.
        self.draft_tokens = torch.zeros(
            self.max_num_reqs,
            self.num_speculative_steps,
            dtype=torch.int64,
            device=device,
        )
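        # Presumably one column per speculative step: row i holds up to
        # num_speculative_steps draft tokens proposed for request i, awaiting
        # verification in the next model step.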
        self.next_prefill_tokens = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        all_token_ids: list[int],
        num_computed_tokens: int,
    ) -> None:
        assert len(self.free_indices) > 0, "No free indices"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        self.prompt_len.np[req_idx] = prompt_len
        prefill_len = len(all_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        self.prefill_len.np[req_idx] = prefill_len
        self.total_len.stage_write_elem(req_idx, prefill_len)
        self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
        self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
        self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
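    # Example (hypothetical values): resuming a preempted request whose
    # 5-token prompt has already produced 3 output tokens:
    #   state.add_request("req-0", prompt_len=5,
    #                     all_token_ids=[p0, ..., p4, o0, o1, o2],
    #                     num_computed_tokens=0)
    # leaves prompt_len=5 and prefill_len=8 in the assigned slot.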

    def apply_staged_writes(self) -> None:
        self.prompt_len.copy_to_uva()
        self.prefill_len.copy_to_uva()
        self.total_len.apply_write()
        self.all_token_ids.apply_write()
        self.num_computed_tokens.apply_write()
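    # The staged-write pattern above appears intended to batch many small
    # per-request updates into a few bulk copies applied once per scheduling
    # step, rather than issuing one tiny host-to-device copy per field.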

    def remove_request(self, req_id: str) -> None:
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
        # True if any request in the given index mapping still has
        # uncomputed prefill tokens.
        return bool(
            np.any(
                self.num_computed_prefill_tokens[idx_mapping_np]
                < self.prefill_len.np[idx_mapping_np]
            )
        )
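
A minimal usage sketch of the request lifecycle (hypothetical values; assumes a
CUDA device and the StagedWriteTensor/UvaBackedTensor behavior used above):

    state = RequestState(
        max_num_reqs=8,
        max_model_len=1024,
        max_num_batched_tokens=2048,
        num_speculative_steps=2,
        vocab_size=32000,
        device=torch.device("cuda"),
    )
    state.add_request(
        "req-0", prompt_len=4, all_token_ids=[11, 12, 13, 14],
        num_computed_tokens=0,
    )
    state.apply_staged_writes()  # flush staged writes for this step
    idx = np.array([state.req_id_to_index["req-0"]], dtype=np.int32)
    assert state.any_prefills(idx)  # prompt tokens not yet computed
    state.remove_request("req-0")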