88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import torch
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config.model import LogprobsMode
|
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
|
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
|
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
|
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
|
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
|
|
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
|
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
|
|
|
|
|
|
class Sampler:
    """Sample one token per request from logits, optionally with top-k logprobs.

    The pipeline applied in ``sample`` is: FP32 copy -> penalties and
    temperature -> min_p (if configured) -> top-k/top-p -> Gumbel sampling.
    """

    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        """Configure the sampler.

        Args:
            logprobs_mode: Selects which logits the returned logprobs are
                computed from. ``"raw_logprobs"`` uses the logits passed to
                ``__call__`` unchanged. ``"processed_logprobs"`` uses the
                logits produced by ``sample``, i.e. after penalties,
                temperature, min_p and top-k/top-p.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is any other value.
        """
        supported_modes = ("processed_logprobs", "raw_logprobs")
        if logprobs_mode not in supported_modes:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
        # Opt-in NaN counting, controlled by an env var (False by default).
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS

    def __call__(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample tokens and package them, with optional logprobs, as output.

        Args:
            logits: Model logits to sample from. ``sample`` works on an FP32
                copy of this tensor. Presumably shaped
                (num_requests, vocab_size); TODO confirm against the caller.
            sampling_metadata: Per-batch sampling parameters (penalties,
                temperature, min_p, top-k/top-p, seeds, positions, and
                ``max_num_logprobs``).

        Returns:
            A ``SamplerOutput`` containing:
            - ``sampled_token_ids``: the sampled tokens, reshaped to
              ``[num_requests, 1]``;
            - ``logprobs_tensors``: ``None`` unless
              ``sampling_metadata.max_num_logprobs`` is set;
            - ``num_nans``: ``None`` unless NaN counting is enabled.
        """
        # NOTE(woosuk): num_nans is deliberately computed before sampling,
        # so it reflects the logits before penalties and temperature.
        num_nans = None
        if self.compute_nans:
            num_nans = get_num_nans(logits)

        sampled, processed_logits = self.sample(logits, sampling_metadata)

        max_num_logprobs = sampling_metadata.max_num_logprobs
        logprobs_tensors = None
        if max_num_logprobs is not None:
            use_processed = self.logprobs_mode == "processed_logprobs"
            logprobs_source = processed_logits if use_processed else logits
            logprobs_tensors = compute_topk_logprobs(
                logprobs_source,
                max_num_logprobs,
                sampled,
            )

        # All fields are GPU tensors. Each row of sampled_token_ids holds the
        # single token generated for one request.
        return SamplerOutput(
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
            num_nans=num_nans,
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the logits-processing pipeline and draw one token per row.

        Args:
            logits: Input logits. They are copied into a new FP32 tensor, and
                every in-place step below operates on that copy.
            sampling_metadata: Per-batch sampling parameters.

        Returns:
            A ``(sampled, processed_logits)`` pair:
            - ``sampled``: the output of ``gumbel_sample``;
            - ``processed_logits``: the FP32 logits after penalties,
              temperature, min_p and top-k/top-p.
        """
        # Allocate an FP32 buffer shaped like the input, then fill it.
        processed = torch.empty_like(logits, dtype=torch.float32)
        processed.copy_(logits)

        # Penalties and temperature are applied in place on the FP32 copy.
        apply_penalties_and_temperature(processed, sampling_metadata)

        # min_p is optional; when present it is applied in place as well.
        min_p = sampling_metadata.min_p
        if min_p is not None:
            apply_min_p(processed, min_p)

        # top-k / top-p filtering may hand back a different tensor object.
        processed = apply_top_k_top_p(
            processed,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        # Temperature was already applied above, so gumbel_sample must not
        # apply it a second time.
        sampled = gumbel_sample(
            processed,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        return sampled, processed
|