init
0
vllm_vacc/vllm/v1/sample/__init__.py
Normal file
BIN
vllm_vacc/vllm/v1/sample/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/sample/__pycache__/metadata.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm_vacc/vllm/v1/sample/__pycache__/sampler.cpython-312.pyc
Normal file
Binary file not shown.
91
vllm_vacc/vllm/v1/sample/cached_pooler.py
Normal file
@@ -0,0 +1,91 @@
import torch
import torch_vacc


class BinCountTensorPooler:
    """
    Pooler of bin-count tensors used by the Sampler class.

    Core purpose: provide pre-built bin-count tensors to apply_penalties so
    they do not have to be recomputed from scratch on every step.

    Note: if req_id values are not guaranteed to be unique within
    max_cache_count entries, reconsider whether this pooler is suitable.

    Example:
        pooler = BinCountTensorPooler(15169, "cpu")
        list_1 = pooler.request_tensors(['1', '3', 'aad'])
        bin_count_buffers = pooler.request_tensors(req_ids)
    """

    def __init__(
        self,
        vocab_size: int,
        device: torch.device,
        max_cache_count: int = 20,
    ):
        # Cache containers: lists keep request IDs and their tensors in
        # insertion order (FIFO eviction).
        self._cached_tensors: list[torch.Tensor] = []
        self._cached_req_ids: list[str] = []

        # Tensor width (+1 is typically used for the padding/special token).
        self.vocab_size: int = vocab_size + 1
        self.device: torch.device = device
        self.max_cache_count: int = max_cache_count

    def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
        """
        Request tensors in batch: cached entries are returned directly,
        uncached entries are created and added to the cache.

        Args:
            req_ids: list of request IDs to look up.

        Returns:
            Bin-count tensors in the same order as req_ids.
        """
        if not req_ids:
            return []  # Fast path for empty input; skips the loop entirely.

        out_tensors = []
        for req_id in req_ids:
            if req_id not in self._cached_req_ids:
                # Register the new request and allocate its cached tensor.
                self._add_new_cache(req_id)

            # Look up the cache index (could later be replaced by a dict
            # mapping for better performance).
            cache_idx = self._cached_req_ids.index(req_id)
            out_tensors.append(self._cached_tensors[cache_idx])

        return out_tensors

    def _add_new_cache(self, req_id: str) -> None:
        """
        Add a cache entry: create the tensor and store it; when the cache
        exceeds its maximum size, evict the oldest entry (FIFO).
        Private method: keeps the internal bookkeeping out of the public API.
        """
        # FIFO eviction: when the cache is full, drop the head (oldest entry).
        if len(self._cached_req_ids) >= self.max_cache_count:  # >= avoids boundary issues.
            self._cached_tensors.pop(0)
            self._cached_req_ids.pop(0)

        # Create the bin-count tensor (int32, zero-initialized).
        bin_count_tensor = torch.zeros(
            size=(1, self.vocab_size),
            dtype=torch.int32,
            device=self.device,
            requires_grad=False,  # Count tensors never need gradients.
        )

        # Append the ID and tensor together so the two lists stay index-aligned.
        self._cached_req_ids.append(req_id)
        self._cached_tensors.append(bin_count_tensor)

    def clear_cache(self) -> None:
        """Clear all cached entries (optional helper for manual cleanup)."""
        self._cached_tensors.clear()
        self._cached_req_ids.clear()

    def get_cache_status(self) -> dict:
        """Return the current cache status (optional helper for monitoring)."""
        return {
            "current_cache_count": len(self._cached_req_ids),
            "max_cache_count": self.max_cache_count,
            "cached_req_ids": self._cached_req_ids.copy(),  # Copy so callers cannot mutate the cache.
            "tensor_shape": (1, self.vocab_size),
            "tensor_dtype": torch.int32,
            "device": str(self.device),
        }
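The docstring above describes per-request reuse and FIFO eviction; below is a minimal standalone sketch of that behaviour. The request IDs and the small max_cache_count are made up for illustration, and the import path assumes the vllm_vacc tree is installed as the vllm package (cached_pooler itself imports torch_vacc, so a VACC-enabled environment is assumed).

import torch

from vllm.v1.sample.cached_pooler import BinCountTensorPooler

pooler = BinCountTensorPooler(vocab_size=7, device=torch.device("cpu"), max_cache_count=2)

# First lookup allocates a zeroed (1, vocab_size + 1) int32 buffer per request.
a, b = pooler.request_tensors(["req-a", "req-b"])
assert a.shape == (1, 8) and a.dtype == torch.int32

# A repeated request ID returns the same cached tensor object.
a_again = pooler.request_tensors(["req-a"])[0]
assert a_again is a

# Exceeding max_cache_count evicts the oldest entry ("req-a" here), so a
# later lookup for it would allocate a fresh buffer.
pooler.request_tensors(["req-c"])
assert pooler.get_cache_status()["cached_req_ids"] == ["req-b", "req-c"]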
48
vllm_vacc/vllm/v1/sample/metadata.py
Normal file
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.v1.sample.logits_processor import LogitsProcessors


@dataclass
class SamplingMetadata:

    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

    temperature_cpu: Optional[torch.Tensor]
    top_p_cpu: Optional[torch.Tensor]
    top_k_cpu: Optional[torch.Tensor]
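The dataclass above has no defaults, so every field must be supplied when the model runner builds it. A hedged sketch of a two-request, all-greedy, no-penalty batch follows; the empty LogitsProcessors() constructor is an assumption and all field values are illustrative only.

import torch

from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata

batch_size = 2
metadata = SamplingMetadata(
    temperature=None,                 # unused when every request is greedy
    all_greedy=True,
    all_random=False,
    top_p=None,
    top_k=None,
    generators={},                    # req_index -> torch.Generator; none needed here
    max_num_logprobs=None,            # None disables logprob gathering
    no_penalties=True,
    prompt_token_ids=None,
    frequency_penalties=torch.zeros(batch_size),
    presence_penalties=torch.zeros(batch_size),
    repetition_penalties=torch.ones(batch_size),
    output_token_ids=[[] for _ in range(batch_size)],
    allowed_token_ids_mask=None,
    bad_words_token_ids={},
    logitsprocs=LogitsProcessors(),   # assumed to be empty-constructible
    temperature_cpu=None,
    top_p_cpu=None,
    top_k_cpu=None,
)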
230
vllm_vacc/vllm/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,230 @@
from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import (generate_uniform_probs,
                                              compute_probs,
                                              rejection_random_sample_kernel,
                                              sample_recovered_tokens)
from vllm.distributed import get_tensor_model_parallel_rank

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


def rejection_greedy_sample_python(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
    num_warps,
):
    # print('max_spec_len', max_spec_len)
    if max_spec_len == 1:
        for bi in range(output_token_ids_ptr.shape[0]):
            output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
            if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
                output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
    else:
        raise ValueError('TODO mtp k > 1')


class RejectionSampler(nn.Module):

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits probability distribution.
                Shape is [num_tokens, vocab_size]. Here, probabilities from
                different requests are flattened into a single tensor because
                this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.

        # print(sampling_metadata)
        # rank_id = get_tensor_model_parallel_rank()
        if metadata.max_spec_len == 1:
            output_token_ids = torch.vacc.rejection_sampler_v1(
                target_logits.to(torch.float32),
                metadata.draft_token_ids,
                bonus_token_ids,
                sampling_metadata.temperature,
                sampling_metadata.top_p,
                sampling_metadata.top_k,
                sampling_metadata.all_greedy,
                sampling_metadata.all_random,
                sampling_metadata.generators,
            )
        else:
            target_probs = compute_probs(
                target_logits.to(torch.float32),
                metadata.cu_num_draft_tokens,
                sampling_metadata,
            )
            output_token_ids = rejection_sample(
                metadata.draft_token_ids,
                metadata.num_draft_tokens,
                metadata.max_spec_len,
                metadata.cu_num_draft_tokens,
                draft_probs,
                target_probs,
                bonus_token_ids,
                sampling_metadata,
            )

        # output_token_ids_cpu = output_token_ids.cpu().tolist()
        # output_token_ids_dev_cpu = output_token_ids_dev.cpu().tolist()
        # for i in range(len(output_token_ids_cpu)):
        #     for j in range(len(output_token_ids_cpu[0])):
        #         if output_token_ids_cpu[i][j] != output_token_ids_dev_cpu[i][j]:
        #             # print(output_token_ids_cpu)
        #             # print(output_token_ids_dev_cpu)
        #             print("cpu \n", output_token_ids, "vacc \n", output_token_ids_dev)
        #             exit()
        # print("cpu \n", output_token_ids, "vacc \n", output_token_ids_dev)
        return output_token_ids


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        # rejection_greedy_sample_kernel[(batch_size, )](
        rejection_greedy_sample_python(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids
        else:
            # TODO
            raise ValueError('not support yet')

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids
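For max_spec_len == 1, rejection_greedy_sample_python reduces to a simple rule: position 0 always takes the target model's argmax, and the bonus token is emitted only when the draft token matched that argmax. A standalone toy illustration of that rule with plain torch follows (the values are made up; it mirrors the function above rather than importing the module, which pulls in vacc/triton dependencies):

import torch

PLACEHOLDER_TOKEN_ID = -1

draft_token_ids = torch.tensor([5, 9])    # one draft token per request
target_argmax = torch.tensor([5, 7])      # target model's greedy choice
bonus_token_ids = torch.tensor([11, 12])  # pre-sampled bonus tokens

output = torch.full((2, 2), PLACEHOLDER_TOKEN_ID, dtype=torch.int64)
for bi in range(output.shape[0]):
    output[bi, 0] = target_argmax[bi]
    if target_argmax[bi].item() == draft_token_ids[bi].item():
        output[bi, 1] = bonus_token_ids[bi]

# Request 0 accepted its draft token, so it also gets the bonus token;
# request 1 was rejected and keeps the placeholder in the second slot.
assert output.tolist() == [[5, 11], [7, -1]]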
276
vllm_vacc/vllm/v1/sample/sampler.py
Normal file
@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

from .cached_pooler import BinCountTensorPooler

_SAMPLING_EPS = 1e-5


def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int32,  # init with int32
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)


class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        is_first_calculate,
        req_ids: list[str],
    ) -> SamplerOutput:
        # print(sampling_metadata.generators, len(sampling_metadata.generators), sampling_metadata.temperature_cpu)
        # sampling_metadata.temperature = sampling_metadata.temperature.to(logits.device)
        # sampling_metadata.top_p = sampling_metadata.top_p.to(logits.device)
        # sampling_metadata.top_k = sampling_metadata.top_k.to(logits.device)
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        # logits = self.apply_penalties(logits, sampling_metadata)

        if not sampling_metadata.no_penalties:
            if not hasattr(self, "bin_count_pooler"):
                self.bin_count_pooler = BinCountTensorPooler(
                    logits.shape[-1], logits.device)
            assert sampling_metadata.prompt_token_ids is not None
            buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
            batch, vocab_size = logits.shape
            logits_res = []

            for i in range(batch):
                output_tokens_t = _convert_to_tensors(
                    [sampling_metadata.output_token_ids[i]], vocab_size,
                    logits.device).to(torch.int32)
                output_tokens_t_lastone = output_tokens_t[:, -1:]
                # Both branches call the same vacc kernel; only the logits
                # slice differs for single-row batches.
                logits_slice = logits[i:i + 1] if logits.shape[0] > 1 else logits
                logits_i = torch.vacc.apply_penalties(
                    logits_slice,
                    output_tokens_t_lastone,
                    buf_bin_buffer[i],
                    vocab_size,
                    output_tokens_t_lastone.shape[-1],
                    [sampling_metadata.frequency_penalties[i]],
                    [sampling_metadata.presence_penalties[i]],
                    is_first_calculate)
                logits_res.append(logits_i)

            if len(logits_res) > 1:
                logits = torch.concat(logits_res)
            else:
                logits = logits_res[0]

        # Sample the next token.
        # sampled = self.sample(logits, sampling_metadata)
        sampled, _ = torch.vacc.sampler_v1(
            logits,
            sampling_metadata.top_p_cpu,
            sampling_metadata.top_k_cpu,
            sampling_metadata.temperature_cpu,
            int(sampling_metadata.all_greedy),
            int(sampling_metadata.all_random),
            sampling_metadata.generators)

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested).
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs,
                                 token_ids=sampled.long())

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant).
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
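gather_logprobs concatenates the sampled token's logprob with the top-k logprobs and computes the sampled token's rank as the number of vocabulary entries whose logprob is greater than or equal to it. A small self-contained illustration of that rank/concat logic with toy numbers follows; it reimplements the relevant lines rather than calling the class, so none of the vacc-specific imports are needed:

import torch

logprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0, -1.0]]), dim=-1)
token_ids = torch.tensor([2])  # pretend token 2 was sampled

topk_logprobs, topk_indices = torch.topk(logprobs, k=2, dim=-1)
token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))

# Rank 1 means the sampled token had the highest logprob; token 2 is the
# second most likely entry here, so its rank is 2.
token_ranks = (logprobs >= token_logprobs).sum(-1)
assert token_ranks.item() == 2

indices = torch.cat((token_ids.unsqueeze(-1), topk_indices), dim=1)
print(indices)  # sampled token id followed by the top-2 ids: [[2, 0, 2]]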