[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
0
vllm/v1/sample/tpu/__init__.py
Normal file
0
vllm/v1/sample/tpu/__init__.py
Normal file
124
vllm/v1/sample/tpu/metadata.py
Normal file
124
vllm/v1/sample/tpu/metadata.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
DEFAULT_SAMPLING_PARAMS = dict(
|
||||
temperature=-1.0,
|
||||
min_p=0.0,
|
||||
# strictly disabled for now
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
# frequency_penalties=0.0,
|
||||
# presence_penalties=0.0,
|
||||
# repetition_penalties=0.0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TPUSupportedSamplingMetadata:
|
||||
# This class exposes a more xla-friendly interface than SamplingMetadata
|
||||
# on TPU, in particular all arguments should be traceable and no optionals
|
||||
# are allowed, to avoid graph recompilation on Nones.
|
||||
temperature: torch.Tensor = None
|
||||
|
||||
min_p: torch.Tensor = None
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
all_greedy: bool = True
|
||||
|
||||
# Whether logprobs are to be gathered in this batch of request. To balance
|
||||
# out compile time and runtime, a fixed `max_number_logprobs` value is used
|
||||
# when gathering logprobs, regardless of the values specified in the batch.
|
||||
logprobs: bool = False
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
prompt_token_ids = None
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
# should use tensor
|
||||
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
|
||||
|
||||
min_tokens = None # impl is not vectorized
|
||||
|
||||
logit_bias: list[Optional[dict[int, float]]] = field(
|
||||
default_factory=lambda: list())
|
||||
|
||||
allowed_token_ids_mask = None
|
||||
bad_words_token_ids = None
|
||||
|
||||
# Generator not supported by xla
|
||||
_generators: dict[int,
|
||||
torch.Generator] = field(default_factory=lambda: dict())
|
||||
|
||||
@property
|
||||
def generators(self) -> dict[int, torch.Generator]:
|
||||
# Generator not supported by torch/xla. This field must be immutable.
|
||||
return self._generators
|
||||
|
||||
@classmethod
|
||||
def from_input_batch(
|
||||
cls,
|
||||
input_batch: InputBatch,
|
||||
padded_num_reqs: int,
|
||||
xla_device: torch.device,
|
||||
generate_params_if_all_greedy: bool = False
|
||||
) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||
|
||||
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
|
||||
|
||||
Args:
|
||||
input_batch: The input batch containing sampling parameters.
|
||||
padded_num_reqs: The padded number of requests.
|
||||
xla_device: The XLA device.
|
||||
generate_params_if_all_greedy: If True, generate sampling parameters
|
||||
even if all requests are greedy. this is useful for cases where
|
||||
we want to pre-compile a graph with sampling parameters, even if
|
||||
they are not strictly needed for greedy decoding.
|
||||
"""
|
||||
needs_logprobs = input_batch.max_num_logprobs>0 if \
|
||||
input_batch.max_num_logprobs else False
|
||||
# Early return to avoid unnecessary cpu to tpu copy
|
||||
if (input_batch.all_greedy is True
|
||||
and generate_params_if_all_greedy is False):
|
||||
return cls(all_greedy=True, logprobs=needs_logprobs)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
|
||||
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
|
||||
# Pad value is the default one.
|
||||
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||
|
||||
fill_slice(input_batch.temperature_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
fill_slice(input_batch.min_p_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||
fill_slice(input_batch.top_k_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["top_k"])
|
||||
fill_slice(input_batch.top_p_cpu_tensor,
|
||||
DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
|
||||
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||
return cls(
|
||||
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
|
||||
to(xla_device),
|
||||
all_greedy=input_batch.all_greedy,
|
||||
# TODO enable more and avoid returning None values
|
||||
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
|
||||
xla_device),
|
||||
logprobs=needs_logprobs)
|
||||
145
vllm/v1/sample/tpu/sampler.py
Normal file
145
vllm/v1/sample/tpu/sampler.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampler layer implementing TPU supported operations."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
|
||||
# These are TPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=None)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply min_p.
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled, random_sampled)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logits: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Filters logits using adaptive probability thresholding.
|
||||
"""
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Reshape min_p for broadcasting
|
||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
|
||||
# Identify valid tokens using threshold comparison
|
||||
valid_token_mask = probability_values >= adjusted_min_p
|
||||
# Apply mask using boolean indexing (xla friendly)
|
||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||
return logits
|
||||
Reference in New Issue
Block a user