forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
158
vllm-v0.6.2/vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,158 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        logits = self.apply_top_k_top_p(logits, sampling_metadata)

        probs = self.get_probs(logits)
        sampled = self.sample(probs, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        if sampling_metadata.max_num_logprobs > 0:
            logprobs = self.get_logprobs(logits)
            # FIXME: Mask the sampled token_id, get topk logprobs,
            # and concatenate the topk with the sampled token_id.
            topk_logprobs, topk_indices = torch.topk(
                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
            # Use int32 to reduce the tensor size.
            topk_indices = topk_indices.to(torch.int32)
        else:
            topk_logprobs = None
            topk_indices = None

        sampler_output = SamplerOutput(
            sampled_token_ids=sampled,
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use float32 to apply temperature scaling.
        logits = logits.to(torch.float32)
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
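        # Requests with a near-zero temperature take the greedy branch in
        # sample() below, so replacing their temperature with 1.0 here only
        # keeps the division well defined; it does not change which token
        # they receive.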
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        return _apply_top_k_top_p(
            logits,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )

    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.softmax(logits, dim=-1, dtype=torch.float32)

    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
        return probs.argmax(dim=-1).view(-1)

    def random_sample(
        self,
        probs: torch.Tensor,
        generators: Dict[int, torch.Generator],
    ) -> torch.Tensor:
        q = torch.empty_like(probs)
        # NOTE(woosuk): To batch-process the requests without their own seeds,
        # which is the common case, we first assume that every request does
        # not have its own seed. Then, we overwrite the values for the requests
        # that have their own seeds.
        if len(generators) != probs.shape[0]:
            # This might still be done here unnecessarily if there are greedies
            q.exponential_()
        if generators:
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)
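        # Dividing the probabilities by i.i.d. Exponential(1) noise and taking
        # the argmax draws from the categorical distribution defined by probs
        # (an "exponential race", equivalent to the Gumbel-max trick), so no
        # explicit multinomial sampling call is needed.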
        return probs.div_(q).argmax(dim=-1).view(-1)

    def sample(
        self,
        probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(probs)
        if sampling_metadata.all_random:
            return self.random_sample(probs, sampling_metadata.generators)

        greedy_sampled = self.greedy_sample(probs)
        random_sampled = self.random_sample(probs,
                                            sampling_metadata.generators)
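        # Mixed batch: each request keeps the greedy token if its temperature
        # is (near) zero and the randomly sampled token otherwise.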
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
        )
        return sampled


# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
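        # The sort is ascending, so the k largest logits occupy the last k
        # positions of each row; the value at index (vocab_size - k) is the
        # k-th largest logit, and everything strictly below it gets masked.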
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
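        # Probabilities are accumulated from the smallest logit upwards, so a
        # token is dropped only while the total mass at or below it still fits
        # inside the (1 - p) tail; what survives is the smallest set of top
        # tokens whose combined probability reaches p.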
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
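Not part of the commit itself, but a minimal sketch of how the new Sampler might be exercised. It assumes only the attributes the code above actually reads: the SimpleNamespace below is a hypothetical stand-in for vllm.v1.sample.metadata.SamplingMetadata (whose real constructor may differ), vocab size and sampling parameters are illustrative, and running it still requires vllm to be importable for SamplerOutput.

from types import SimpleNamespace

import torch

batch_size, vocab_size = 2, 1000
logits = torch.randn(batch_size, vocab_size)

# Stand-in metadata: request 0 is effectively greedy (temperature 0, no
# top-k/top-p restriction), request 1 samples with temperature 0.8,
# top_k=50 and top_p=0.9.
metadata = SimpleNamespace(
    temperature=torch.tensor([0.0, 0.8]),
    all_greedy=False,
    all_random=False,
    no_top_k=False,
    top_k=torch.tensor([vocab_size, 50]),
    no_top_p=False,
    top_p=torch.tensor([1.0, 0.9]),
    generators={},       # no per-request seeds
    max_num_logprobs=0,  # skip the logprob path
)

sampler = Sampler()
output = sampler(logits, metadata)
print(output.sampled_token_ids)  # int32 tensor of shape [batch_size]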