2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,91 @@
import torch
import torch_vacc  # noqa: F401 - side-effect import: registers the VACC torch backend
class BinCountTensorPooler:
"""
sampler类中bin-count tensor的缓冲器。
核心功能为apply_penalties 提供创建好的bin-count, 为避免重复重头计算。
注意: 如果系统中的req_id未能确保在max_count的数量下的唯一性要考虑是否适用该pooler
pooler = BinCountTensorPooler(15169, "cpu")
list_1 = pooler.request_tensors(['1','3','aad'])
bin_count_buffers = pooler.request_tensors(req_ids)
"""
def __init__(
self,
vocab_size: int,
device: torch.device,
max_cache_count: int = 20,
):
        # The cache keeps request IDs and their tensors in two parallel lists (FIFO eviction).
self._cached_tensors: list[torch.Tensor] = []
self._cached_req_ids: list[str] = []
        # Tensor width (+1 is typically reserved for a padding / special token).
self.vocab_size: int = vocab_size + 1
self.device: torch.device = device
self.max_cache_count: int = max_cache_count
def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
"""
批量请求张量:已缓存的直接返回,未缓存的创建并加入缓存。
Args:
req_ids: 待请求的请求ID列表
Returns:
与req_ids顺序对应的bin-count tensor
"""
if not req_ids:
            return []  # fast path for empty input; avoids a useless loop
out_tensors = []
for req_id in req_ids:
if req_id not in self._cached_req_ids:
                # Register the new request and allocate its cached tensor.
self._add_new_cache(req_id)
            # Look up the cache index (a dict mapping could make this O(1) later).
cache_idx = self._cached_req_ids.index(req_id)
out_tensors.append(self._cached_tensors[cache_idx])
return out_tensors
def _add_new_cache(self, req_id: str) -> None:
"""
新增缓存项创建张量并加入缓存超出最大数量时按FIFO淘汰最早项。
私有方法:封装内部逻辑,避免外部直接调用
"""
# FIFO淘汰缓存满时移除头部最早加入的项
if len(self._cached_req_ids) >= self.max_cache_count: # >= 更严谨(避免边界值问题)
self._cached_tensors.pop(0)
self._cached_req_ids.pop(0)
        # Create the bin-count tensor (int32, zero-initialized).
bin_count_tensor = torch.zeros(
size=(1, self.vocab_size),
dtype=torch.int32,
device=self.device,
            requires_grad=False,  # counting tensors never need gradients
)
        # Append the ID and tensor together so the two lists stay index-aligned.
self._cached_req_ids.append(req_id)
self._cached_tensors.append(bin_count_tensor)
def clear_cache(self) -> None:
"""清空所有缓存(可选扩展方法,方便外部手动清理)"""
self._cached_tensors.clear()
self._cached_req_ids.clear()
def get_cache_status(self) -> dict:
"""获取缓存状态(可选扩展方法,方便监控)"""
return {
"current_cache_count": len(self._cached_req_ids),
"max_cache_count": self.max_cache_count,
"cached_req_ids": self._cached_req_ids.copy(), # 返回副本避免外部修改
"tensor_shape": (1, self.vocab_size),
"tensor_dtype": torch.int32,
"device": str(self.device),
}
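# Illustrative usage (not part of the original commit): a minimal sketch of the
# FIFO eviction behaviour, using a deliberately small max_cache_count. Request
# IDs and sizes below are made up for demonstration only.
def _demo_bin_count_pooler() -> None:
    pooler = BinCountTensorPooler(vocab_size=7, device=torch.device("cpu"),
                                  max_cache_count=2)
    t_a, t_b = pooler.request_tensors(["req-a", "req-b"])
    # Requesting an already-cached id returns the same buffer object.
    assert pooler.request_tensors(["req-a"])[0] is t_a
    # A third id evicts the oldest entry ("req-a") FIFO-style ...
    pooler.request_tensors(["req-c"])
    # ... so a later request for "req-a" allocates a fresh zero tensor.
    assert pooler.request_tensors(["req-a"])[0] is not t_a
    assert t_b.shape == (1, 8)  # vocab_size + 1 columns
if __name__ == "__main__":
    _demo_bin_count_pooler()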

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
temperature: Optional[torch.Tensor]
all_greedy: bool
all_random: bool
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
generators: dict[int, torch.Generator]
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: list[list[int]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]
# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessors
temperature_cpu: Optional[torch.Tensor]
top_p_cpu: Optional[torch.Tensor]
top_k_cpu: Optional[torch.Tensor]
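# Illustrative sketch (not part of the original commit): how a caller might
# populate this dataclass for a purely greedy batch with penalties disabled.
# The no-argument LogitsProcessors() construction is an assumption; check its
# actual signature in vllm.v1.sample.logits_processor before relying on it.
def _make_greedy_metadata(batch_size: int,
                          device: torch.device) -> SamplingMetadata:
    zeros = torch.zeros(batch_size, dtype=torch.float32, device=device)
    return SamplingMetadata(
        temperature=None,  # unused when all_greedy is True
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=None,  # None means no logprobs are returned
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=zeros,
        presence_penalties=zeros,
        repetition_penalties=torch.ones(batch_size, device=device),
        output_token_ids=[[] for _ in range(batch_size)],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),  # assumed default-constructible
        temperature_cpu=None,
        top_p_cpu=None,
        top_k_cpu=None,
    )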

View File

@@ -0,0 +1,230 @@
from typing import Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import (compute_probs,
                                              generate_uniform_probs,
                                              rejection_random_sample_kernel,
                                              sample_recovered_tokens)
from vllm.distributed import get_tensor_model_parallel_rank
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
def rejection_greedy_sample_python(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
num_warps
):
if max_spec_len == 1:
for bi in range(output_token_ids_ptr.shape[0]):
output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
else:
        raise NotImplementedError("MTP with max_spec_len > 1 is not supported yet")
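# Illustrative check (not part of the original commit): with max_spec_len == 1
# the Python fallback above accepts a draft token only when it matches the
# target argmax, and only then emits the bonus token; rejected positions keep
# the PLACEHOLDER_TOKEN_ID (-1) fill value.
def _demo_greedy_rejection() -> None:
    out = torch.full((2, 2), -1, dtype=torch.int32)
    draft = torch.tensor([5, 7])          # proposed draft token per request
    target_argmax = torch.tensor([5, 9])  # target model's greedy choice
    bonus = torch.tensor([11, 12])
    rejection_greedy_sample_python(out, None, draft, target_argmax, bonus,
                                   None, max_spec_len=1, num_warps=1)
    # Request 0: draft accepted (5 == 5) -> [5, 11]; request 1: rejected -> [9, -1].
    assert out.tolist() == [[5, 11], [9, -1]]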
class RejectionSampler(nn.Module):
def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
                The target model's logits over the vocabulary.
                Shape is [num_tokens, vocab_size]. Here, logits from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
# print(sampling_metadata)
# rank_id = get_tensor_model_parallel_rank()
if metadata.max_spec_len == 1:
output_token_ids = torch.vacc.rejection_sampler_v1(
target_logits.to(torch.float32),
metadata.draft_token_ids,
bonus_token_ids,
sampling_metadata.temperature,
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.all_greedy,
sampling_metadata.all_random,
sampling_metadata.generators
)
else:
target_probs = compute_probs(
target_logits.to(torch.float32),
metadata.cu_num_draft_tokens,
sampling_metadata,
)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
return output_token_ids
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
# [batch_size]
num_draft_tokens: list[int],
max_spec_len: int,
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
        # Launch the Python fallback in place of the Triton
        # rejection_greedy_sample_kernel.
        rejection_greedy_sample_python(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
num_warps=1,
)
if sampling_metadata.all_greedy:
return output_token_ids
    else:
        # TODO: enable the upstream random-sampling path below once supported.
        raise NotImplementedError(
            "Mixed greedy/random rejection sampling is not supported yet")
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
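# Reference sketch (not part of the original commit) of the standard
# speculative-sampling acceptance rule that the currently unreachable random
# path above would implement: a draft token x is accepted with probability
# min(1, p_target(x) / p_draft(x)); on rejection, a token is resampled from the
# renormalised positive part of (p_target - p_draft). Names are illustrative.
def _accept_or_recover(
    draft_token: int,
    draft_prob: torch.Tensor,   # [vocab_size], proposal distribution
    target_prob: torch.Tensor,  # [vocab_size], target distribution
    uniform: float,             # U(0, 1) sample
) -> tuple[bool, int]:
    accept_prob = min(1.0, (target_prob[draft_token] /
                            draft_prob[draft_token]).item())
    if uniform < accept_prob:
        return True, draft_token
    residual = torch.clamp(target_prob - draft_prob, min=0)
    residual = residual / residual.sum()
    recovered = int(torch.multinomial(residual, num_samples=1).item())
    return False, recovered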

View File

@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import torch
import torch.nn as nn
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from .cached_pooler import BinCountTensorPooler
_SAMPLING_EPS = 1e-5
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int32, # init with int32
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
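# Illustrative sketch (not part of the original commit): the expected effect of
# _convert_to_tensors on a ragged batch, reproduced with plain torch so the
# padding convention is visible. make_tensor_with_pad is assumed to right-pad
# each row with the pad value (vocab_size) up to the longest row.
def _pad_like_convert_to_tensors(output_token_ids: list[list[int]],
                                 vocab_size: int) -> torch.Tensor:
    max_len = max(len(row) for row in output_token_ids)
    padded = [row + [vocab_size] * (max_len - len(row))
              for row in output_token_ids]
    return torch.tensor(padded, dtype=torch.int32)
# Example: [[1, 2, 3], [4]] with vocab_size=10 -> [[1, 2, 3], [4, 10, 10]].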
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
        is_first_calculate: bool,
req_ids: list[str],
) -> SamplerOutput:
        # NOTE: temperature/top_p/top_k are consumed through their *_cpu
        # counterparts further below, so no device transfer is done here.
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# TODO(rob): provide option for logprobs post sampling.
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
        # Apply penalties (e.g., frequency/presence penalties) using the
        # device-side torch.vacc kernel instead of the generic
        # self.apply_penalties path.
if not sampling_metadata.no_penalties:
if (not hasattr(self, "bin_count_pooler")):
self.bin_count_pooler = BinCountTensorPooler(logits.shape[-1], logits.device)
assert sampling_metadata.prompt_token_ids is not None
buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
            batch, vocab_size = logits.shape
            logits_res = []
            for i in range(batch):
                output_tokens_t = _convert_to_tensors(
                    [sampling_metadata.output_token_ids[i]], vocab_size,
                    logits.device).to(torch.int32)
                output_tokens_t_lastone = output_tokens_t[:, -1:]
                # A single-request batch passes the full logits tensor instead
                # of a one-row slice; the kernel call is otherwise identical.
                logits_i = torch.vacc.apply_penalties(
                    logits[i:i + 1] if batch > 1 else logits,
                    output_tokens_t_lastone,
                    buf_bin_buffer[i],
                    vocab_size,
                    output_tokens_t_lastone.shape[-1],
                    [sampling_metadata.frequency_penalties[i]],
                    [sampling_metadata.presence_penalties[i]],
                    is_first_calculate)
                logits_res.append(logits_i)
            logits = (torch.concat(logits_res)
                      if len(logits_res) > 1 else logits_res[0])
        # Sample the next token with the fused device-side sampler
        # (replaces the generic self.sample path).
        sampled, _ = torch.vacc.sampler_v1(
            logits,
            sampling_metadata.top_p_cpu,
            sampling_metadata.top_k_cpu,
            sampling_metadata.temperature_cpu,
            int(sampling_metadata.all_greedy),
            int(sampling_metadata.all_random),
            sampling_metadata.generators)
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = None if num_logprobs is None else \
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled.long())
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled
assert sampling_metadata.temperature is not None
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
if greedy_sampled is None:
return random_sampled
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
def gather_logprobs(
self,
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == torch.int64
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
        # Gather the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size.
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids,
)
return logits
def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
float("-inf"))
return logits
def apply_bad_words(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
sampling_metadata.output_token_ids,
)
return logits
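# Reference sketch (not part of the original commit) of the penalty formulation
# that torch.vacc.apply_penalties is presumed to implement incrementally with
# the pooled bin-count buffer: counts of previously generated tokens scale the
# frequency penalty and gate the presence penalty. Names are illustrative.
def _apply_penalties_reference(
    logits: torch.Tensor,          # [1, vocab_size]
    output_token_ids: list[int],   # tokens generated so far for this request
    frequency_penalty: float,
    presence_penalty: float,
) -> torch.Tensor:
    vocab_size = logits.shape[-1]
    bin_count = torch.bincount(
        torch.tensor(output_token_ids, dtype=torch.long),
        minlength=vocab_size,
    )[:vocab_size].to(logits.device)
    logits = logits - frequency_penalty * bin_count
    logits = logits - presence_penalty * (bin_count > 0).to(logits.dtype)
    return logits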