Files
enginex-mthreads-vllm/vllm/sequence.py

99 lines
3.4 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
"""Sequence and its related classes."""
2026-01-19 10:38:50 +08:00
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
import torch
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# Import the real KVConnectorOutput only for static type checkers; at runtime
# the name is bound to `Any`, so annotations that mention it still evaluate
# without importing the worker module.
if TYPE_CHECKING:
    from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
else:
    KVConnectorOutput = Any
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# NOTE(review): "l" is the signed-long type code of the standard-library
# `array` module; presumably this is the type code used for arrays of token
# IDs. No use is visible in this section -- confirm against callers.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# NOTE(review): presumably a sentinel marking an invalid or absent token ID.
# No use is visible in this section -- confirm against callers.
VLLM_INVALID_TOKEN_ID = -1
2026-01-09 13:34:11 +08:00
@dataclass
class RequestMetrics:
"""Metrics associated with a request.
Attributes:
arrival_time: The time when the request arrived.
first_scheduled_time: The time when the request was first scheduled.
first_token_time: The time when the first token was generated.
time_in_queue: The time the request spent in the queue.
finished_time: The time when the request was finished.
2026-01-19 10:38:50 +08:00
scheduler_time: The time spent in the scheduler when this request was
being considered by the scheduler.
model_forward_time: The time spent in the model forward pass when this
request was in the batch.
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
2026-01-09 13:34:11 +08:00
"""
2026-01-19 10:38:50 +08:00
2026-01-09 13:34:11 +08:00
arrival_time: float
last_token_time: float
2026-01-19 10:38:50 +08:00
first_scheduled_time: float | None
first_token_time: float | None
time_in_queue: float | None
finished_time: float | None = None
scheduler_time: float | None = None
model_forward_time: float | None = None
model_execute_time: float | None = None
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# cannot use msgspec.Struct here because Dynamo does not support it
2026-01-09 13:34:11 +08:00
@dataclass
2026-01-19 10:38:50 +08:00
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
Each stage also needs to handle its own kv_connector_output.
2026-01-09 13:34:11 +08:00
"""
2026-01-19 10:38:50 +08:00
tensors: dict[str, torch.Tensor]
kv_connector_output: KVConnectorOutput | None
2026-01-09 13:34:11 +08:00
def __init__(
self,
2026-01-19 10:38:50 +08:00
tensors: dict[str, torch.Tensor],
kv_connector_output: KVConnectorOutput | None = None,
2026-01-09 13:34:11 +08:00
) -> None:
2026-01-19 10:38:50 +08:00
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
self.kv_connector_output = kv_connector_output
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
def __getitem__(self, key: str | slice):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
def __setitem__(self, key: str, value: torch.Tensor):
self.tensors[key] = value
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
def items(self):
return self.tensors.items()
2026-01-09 13:34:11 +08:00
def __len__(self):
2026-01-19 10:38:50 +08:00
return len(self.tensors)
2026-01-09 13:34:11 +08:00
def __eq__(self, other: object):
2026-01-19 10:38:50 +08:00
if not isinstance(other, self.__class__):
return False
if self.tensors.keys() != other.tensors.keys():
return False
return all(torch.equal(self.tensors[k], other.tensors[k]) for k in self.tensors)
2026-01-09 13:34:11 +08:00
def __repr__(self) -> str:
2026-01-19 10:38:50 +08:00
return f"IntermediateTensors(tensors={self.tensors})"