first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import ops # noqa: F401

Binary file not shown.

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import logprobs, topk_topp_sampler # noqa: F401

View File

@@ -0,0 +1,40 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Some utilities for logprobs, including logits."""
import torch
from fastcore.basics import patch_to
from vllm.v1.sample.ops import logprobs
@patch_to(logprobs)
def batched_count_greater_than(x: torch.Tensor,
values: torch.Tensor) -> torch.Tensor:
"""
Counts elements in each row of x that are greater than the corresponding
value in values. Use torch.compile to generate an optimized kernel for
this function. otherwise, it will create additional copies of the input
tensors and cause memory issues.
Args:
x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
Returns:
torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
"""
return (x >= values).sum(-1)

View File

@@ -0,0 +1,138 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from fastcore.basics import patch_to
from vllm.v1.sample.ops import topk_topp_sampler
def topk_topp_sampler_forward_native(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
# scatter usage not support on br, need fix.
@patch_to(topk_topp_sampler)
def apply_top_k_top_p(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if p is None:
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.clone()
logits = logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
return logits
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
# vllm.v1.sample.ops.topk_topp_sampler.TopKTopPSampler.forward_native = topk_topp_sampler_forward_native