# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from dataclasses import dataclass
|
|
from functools import cached_property
|
|
from typing import TYPE_CHECKING
|
|
|
|
from typing_extensions import deprecated
|
|
|
|
from vllm._bc_linter import bc_linter_include
|
|
|
|
if TYPE_CHECKING:
    # These imports are needed only for static type checking. Keeping them
    # out of the runtime path avoids import cost and potential circular
    # imports between the scheduler and the modules below.
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
    from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
    from vllm.lora.request import LoRARequest
    from vllm.multimodal.inputs import MultiModalFeatureSpec
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request
else:
    # At runtime, alias the annotation-only names to `object` so that the
    # dataclass field annotations below still evaluate without importing
    # the heavy modules above.
    ECConnectorMetadata = object
    KVConnectorMetadata = object
    LoRARequest = object
    MultiModalFeatureSpec = object
    PoolingParams = object
    SamplingParams = object
    Request = object
|
|
|
|
'''
=============================
Modified by vllm_mlu
=============================
@brief: Add new_token_ids to pass the first token generated
by the prefiller to the decoder's model_runner.
'''
|
|
@bc_linter_include
@dataclass
class NewRequestData:
    """Scheduler payload describing a request newly added to the running set.

    MLU modification: carries ``new_token_ids`` so that token(s) already
    generated by the prefiller can be handed to the decoder's model_runner.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    lora_request: LoRARequest | None
    # MLU addition: output tokens produced so far (e.g. by the prefiller),
    # forwarded to the decoder's model runner.
    new_token_ids: list[list[int]]
    prompt_embeds: "torch.Tensor | None" = None

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> "NewRequestData":
        """Build a NewRequestData from a scheduler Request and its KV block ids."""
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            prompt_embeds=request.prompt_embeds,
            # NOTE(review): reaches into Request's private attribute; confirm
            # there is no public accessor for the output token ids.
            new_token_ids=request._output_token_ids,
        )

    def __repr__(self) -> str:
        # Explicit `is not None`: truthiness of a torch.Tensor with more than
        # one element raises RuntimeError, so `if self.prompt_embeds` would
        # crash repr() for any real prompt-embeds tensor.
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids={self.prompt_token_ids},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape},"
            f"new_token_ids={self.new_token_ids}"
            ")"
        )

    # Version of __repr__ with the prompt data obfuscated: only the prompt
    # length and embeds shape are shown, never the token ids themselves.
    def anon_repr(self) -> str:
        prompt_token_ids_len = (
            len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
        )
        # Same `is not None` guard as __repr__ (tensor truthiness raises).
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids_len={prompt_token_ids_len},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )
|
|
'''
==================
End of MLU Hijack
==================
'''