forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
This commit is contained in:
@@ -24,8 +24,9 @@ import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems, PlaceholderRange)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
@@ -37,9 +38,9 @@ from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
from vllm_ascend.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -47,10 +48,6 @@ class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
mm_kwargs: list[MultiModalKwargsItem]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
# TODO: remove Optional after 0.10.1.1
|
||||
mm_hashes: Optional[list[str]]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
generator: Optional[torch.Generator]
|
||||
@@ -62,6 +59,12 @@ class CachedRequestState:
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]] = None
|
||||
# for back-compatibility, will be removed in next major release
|
||||
mm_kwargs: Optional[list[MultiModalKwargsItem]] = None
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None
|
||||
mm_hashes: Optional[list[PlaceholderRange]] = None
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -75,8 +78,18 @@ class CachedRequestState:
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargs]:
|
||||
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
|
||||
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
||||
if vllm_version_is("0.10.2"):
|
||||
assert self.mm_kwargs is not None
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([item])
|
||||
for item in self.mm_kwargs
|
||||
]
|
||||
assert self.mm_features is not None
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
|
||||
if f.data is not None
|
||||
]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
@@ -88,18 +101,19 @@ class CachedRequestState:
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
):
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
kernel_block_sizes: Optional[list[list[int]]] = None):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
@@ -143,7 +157,8 @@ class InputBatch:
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
kernel_sizes=kernel_block_sizes)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
@@ -218,6 +233,14 @@ class InputBatch:
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# Speculative decoding
|
||||
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.num_accepted_tokens_cpu = \
|
||||
self.num_accepted_tokens_cpu_tensor.numpy()
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
@@ -266,6 +289,11 @@ class InputBatch:
|
||||
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
# Cached reference to the GPU tensor of previously sampled tokens
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
@@ -407,6 +435,9 @@ class InputBatch:
|
||||
else:
|
||||
raise NotImplementedError(request)
|
||||
|
||||
# Speculative decoding: by default 1 token is generated.
|
||||
self.num_accepted_tokens_cpu[req_index] = 1
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
@@ -506,6 +537,8 @@ class InputBatch:
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
|
||||
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
@@ -612,6 +645,8 @@ class InputBatch:
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.num_accepted_tokens_cpu[
|
||||
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
@@ -726,20 +761,13 @@ class InputBatch:
|
||||
pooling_params = [
|
||||
self.pooling_params[req_id] for req_id in self.req_ids
|
||||
]
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
else:
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
|
||||
Reference in New Issue
Block a user