first commit
This commit is contained in:
17
vllm_br/v1/sample/__init__.py
Normal file
17
vllm_br/v1/sample/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import ops # noqa: F401
|
||||
BIN
vllm_br/v1/sample/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
17
vllm_br/v1/sample/ops/__init__.py
Normal file
17
vllm_br/v1/sample/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import logprobs, topk_topp_sampler # noqa: F401
|
||||
BIN
vllm_br/v1/sample/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/sample/ops/__pycache__/logprobs.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/ops/__pycache__/logprobs.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
40
vllm_br/v1/sample/ops/logprobs.py
Normal file
40
vllm_br/v1/sample/ops/logprobs.py
Normal file
@@ -0,0 +1,40 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
"""Some utilities for logprobs, including logits."""
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.v1.sample.ops import logprobs
|
||||
|
||||
|
||||
@patch_to(logprobs)
|
||||
def batched_count_greater_than(x: torch.Tensor,
|
||||
values: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Counts elements in each row of x that are greater than the corresponding
|
||||
value in values. Use torch.compile to generate an optimized kernel for
|
||||
this function. otherwise, it will create additional copies of the input
|
||||
tensors and cause memory issues.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
|
||||
values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
|
||||
"""
|
||||
return (x >= values).sum(-1)
|
||||
138
vllm_br/v1/sample/ops/topk_topp_sampler.py
Normal file
138
vllm_br/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,138 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.v1.sample.ops import topk_topp_sampler
|
||||
|
||||
|
||||
def topk_topp_sampler_forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
|
||||
def apply_top_k_only(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k mask to the logits.
|
||||
|
||||
This implementation doesn't involve sorting the entire vocab.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
no_top_k_mask = k == logits.shape[1]
|
||||
# Set non-top-k rows to 1 so that we can gather.
|
||||
k = k.masked_fill(no_top_k_mask, 1)
|
||||
max_top_k = k.max()
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
|
||||
# scatter usage not support on br, need fix.
|
||||
@patch_to(topk_topp_sampler)
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
If a top-p is used, this function will sort the logits tensor,
|
||||
which can be slow for large batches.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
if p is None:
|
||||
if k is None:
|
||||
return logits
|
||||
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
# Re-sort the probabilities.
|
||||
logits = logits_sort.clone()
|
||||
logits = logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
|
||||
|
||||
def random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Randomly sample from the probabilities.
|
||||
|
||||
We use this function instead of torch.multinomial because torch.multinomial
|
||||
causes CPU-GPU synchronization.
|
||||
"""
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
if len(generators) != probs.shape[0]:
|
||||
q.exponential_()
|
||||
if generators:
|
||||
# TODO(woosuk): This can be slow because we handle each request
|
||||
# one by one. Optimize this.
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
# vllm.v1.sample.ops.topk_topp_sampler.TopKTopPSampler.forward_native = topk_topp_sampler_forward_native
|
||||
Reference in New Issue
Block a user