Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
120
vllm/v1/sample/tpu/metadata.py
Normal file
120
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
# top_k=-1,
|
||||
# top_p=0.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
# repetition_penalties=0.0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TPUSupportedSamplingMetadata:
|
||||
# This class exposes a more xla-friendly interface than SamplingMetadata
|
||||
# on TPU, in particular all arguments should be traceable and no optionals
|
||||
# are allowed, to avoid graph recompilation on Nones.
|
||||
temperature: torch.Tensor
|
||||
|
||||
min_p: torch.Tensor
|
||||
# Still too slow on forward_native!
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
# Greedy sampling flag for compiling single xla graph.
|
||||
all_greedy: torch.Tensor = None
|
||||
|
||||
# Generator not supported by xla
|
||||
generators: dict[int,
|
||||
torch.Generator] = field(default_factory=lambda: dict())
|
||||
|
||||
# unsupported, you need to return an extra tensor of static size BxV
|
||||
max_num_logprobs = None
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
prompt_token_ids = None
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
# should use tensor
|
||||
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
|
||||
|
||||
min_tokens = None # impl is not vectorized
|
||||
|
||||
logit_bias: list[Optional[dict[int, float]]] = field(
|
||||
default_factory=lambda: list())
|
||||
|
||||
allowed_token_ids_mask = None
|
||||
bad_words_token_ids = None
|
||||
indices_do_sample: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
def from_input_batch(
|
||||
cls, input_batch: InputBatch,
|
||||
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||
|
||||
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
|
||||
also reuses the on-device persistent tensors managed in `input_batch`
|
||||
to reduce waste.
|
||||
|
||||
`indices_do_sample` contains the indices to be fed to the Sampler,
|
||||
normally one per request, here padded to the closest pre-compiled shape
|
||||
We expect sampling params tensors to be padded to the same fixed shape.
|
||||
|
||||
Eg. 3 requests, tensors padded to 4
|
||||
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
|
||||
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
|
||||
"""
|
||||
num_reqs = input_batch.num_reqs
|
||||
padded_num_reqs = len(indices_do_sample)
|
||||
|
||||
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
|
||||
fill_val) -> torch.Tensor:
|
||||
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
|
||||
# Pad value is the default one.
|
||||
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
|
||||
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
|
||||
|
||||
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
|
||||
# consistent. We can't have flags to skip copies or we'll end up
|
||||
# recompiling.
|
||||
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
|
||||
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
# TODO Temporarily disabled until sampling options are enabled
|
||||
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
|
||||
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
|
||||
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
|
||||
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
temperature=input_batch.temperature[:padded_num_reqs],
|
||||
# Scalar tensor for xla-friendly tracing.
|
||||
all_greedy=torch.tensor(input_batch.all_greedy,
|
||||
dtype=torch.bool,
|
||||
device=input_batch.device),
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=None, # input_batch.top_p[:padded_num_reqs],
|
||||
top_k=None, # input_batch.top_k[:padded_num_reqs],
|
||||
min_p=input_batch.min_p[:padded_num_reqs],
|
||||
generators=input_batch.generators,
|
||||
indices_do_sample=indices_do_sample)
|
||||
Reference in New Issue
Block a user