optimize the function of computing top-k and top-p in sampler. (#970)

### What this PR does / why we need it?
Optimize the performance of the top-k/top-p calculation logic in the sampler and in deepseekv2.

### Does this PR introduce _any_ user-facing change?
Added the VLLM_ASCEND_ENABLE_TOPK_OPTIMZE config for the sampler.

### How was this patch tested?
pytest test_sampler.py

Signed-off-by: wangxiaoxin (A) <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin (A) <wangxiaoxin7@huawei.com>
Co-authored-by: ZhengWG <zwg0606@gmail.com>
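For orientation (not part of the commit itself): the optimized top-k/top-p path is opt-in through the environment variable added below. A minimal sketch of how a user might enable it, assuming an Ascend environment with vllm and vllm-ascend installed; the model name and sampling parameters mirror the tests added in this PR.

```python
import os

# Set the flag before the engine is created: the patched sampler is installed
# when vllm-ascend's patch modules are imported (see patch_sampler.py below).
os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", dtype="float16", enforce_eager=True)
sampling_params = SamplingParams(max_tokens=5, temperature=0.0, top_k=50, top_p=0.9)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```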
@@ -21,8 +21,10 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import vllm  # noqa: F401
+from vllm import SamplingParams

 from tests.conftest import VllmRunner

@@ -57,3 +59,25 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_distributed_topk() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
@@ -21,9 +21,11 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import pytest
 import vllm  # noqa: F401
+from vllm import SamplingParams
 from vllm.assets.image import ImageAsset

 import vllm_ascend  # noqa: F401
@@ -81,3 +83,24 @@ def test_multimodal(model, prompt_template, vllm_runner):
         vllm_model.generate_greedy(prompts=prompts,
                                    images=images,
                                    max_tokens=64)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_topk() -> None:
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
+                    max_model_len=8192,
+                    dtype="float16",
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
tests/singlecard/test_sampler.py (new file, 147 lines)
@@ -0,0 +1,147 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional
+
+import torch
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
+
+# Set tolerance to 1 for quant ops
+DEFAULT_ATOL = 1e-3
+DEFAULT_RTOL = 1e-3
+
+
+def apply_min_p_new(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    if min_p == 0:
+        return logits
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_top_p_new(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        cutoff = top_k_mask.sum(dim=-1).min()
+        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = True
+        strides = torch.arange(0,
+                               batch_size * vocab_size,
+                               vocab_size,
+                               device=logits.device)
+        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+        logits_flatten = logits.flatten()
+        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+        logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_min_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    min_p = torch.Tensor([0.01]).npu()
+    logits_new = apply_min_p_new(logits, min_p)
+    sampler = Sampler()
+    logits_old = sampler.apply_min_p(logits, min_p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_top_k_top_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    k = torch.Tensor([-1]).int().npu()
+    p = torch.Tensor([1]).int().npu()
+    logits_new = apply_top_k_top_p_new(logits, k, p)
+    logits_old = apply_top_k_top_p(logits, k, p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
@@ -36,6 +36,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ASCEND_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":
@@ -238,8 +238,7 @@ class CustomDeepseekV2MoE(nn.Module):

         num_tokens, hidden_size = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.clone()

         if self.tp_size > 1:
             if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
@@ -288,6 +287,9 @@ class CustomDeepseekV2MoE(nn.Module):
             if num_padding_tokens > 0:
                 hidden_states = hidden_states[:-num_padding_tokens]

+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             hidden_states = hidden_states + shared_output

@@ -363,7 +363,7 @@ def fused_experts(
                             num_experts)).to(topk_ids.dtype)

     # Sort by local expert IDs
-    sort_indices = torch.argsort(filtered_experts)
+    sort_indices = torch.argsort(filtered_experts.view(torch.float32))
     sorted_token_indices = token_indices[sort_indices]
     sorted_weights = filtered_weights[sort_indices]

@@ -166,3 +166,30 @@
 #   Future Plan:
 #       Revert it when the ascend support triton kernel.
 #
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
+#    Why:
+#       We need to use the patched `apply_top_k_top_p` in `sample`.
+#       The main reason to overwrite `apply_top_k_top_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_top_k_top_p` function in PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend scatter performance improves.
+#
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
+#    Why:
+#       We need to use the patched `apply_min_p` in `sample`.
+#       The main reason to overwrite `apply_min_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_min_p` function in PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend indexput performance improves.
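As a side note (illustrative only, not part of the diff): the adaptive thresholding that the patched `apply_min_p` performs can be seen on a toy batch; tokens whose probability falls below `min_p` times the probability of the most likely token are masked out.

```python
import torch

# Toy example of min-p filtering on one sequence over a four-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
min_p = torch.tensor([0.2])

probs = torch.softmax(logits, dim=-1)                               # ~[0.66, 0.24, 0.09, 0.004]
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)   # ~0.13
filtered = logits.masked_fill(probs < threshold, float("-inf"))
print(filtered)                                                     # tensor([[2., 1., -inf, -inf]])
```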
@@ -23,4 +23,5 @@ import vllm_ascend.patch.worker.patch_common.patch_eagle  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
vllm_ascend/patch/worker/patch_common/patch_sampler.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import torch
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.sampler import Sampler
+
+from vllm_ascend import envs
+
+
+def apply_min_p(
+    self,
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    p: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
+
+
+def topk_topp_forward_native(
+    self,
+    logits: torch.Tensor,
+    generators: dict[int, torch.Generator],
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = _apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs, generators)
+
+
+Sampler.apply_min_p = apply_min_p
+if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMZE:
+    TopKTopPSampler.forward_native = topk_topp_forward_native
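For intuition, here is a small, CPU-runnable walk-through (illustrative only, not part of the commit) of the masking semantics that the top-k/top-p path implements. It follows the reference `apply_top_k_top_p` quoted in tests/singlecard/test_sampler.py above, on a toy vocabulary of four tokens.

```python
import torch

# Toy batch of one sequence over a four-token vocabulary.
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
k = torch.tensor([2])    # keep only the 2 largest logits
p = torch.tensor([0.9])  # then keep the smallest set of tokens with mass >= 0.9

# Top-k: mask everything below the k-th largest logit.
kth_value = logits.topk(int(k[0]), dim=-1).values[..., -1, None]
logits = logits.masked_fill(logits < kth_value, float("-inf"))

# Top-p: sort ascending, drop the low-probability tail whose cumulative mass
# is <= 1 - p, always keep at least one token, then scatter back to the
# original token order.
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
probs_sort = logits_sort.softmax(dim=-1)
drop = probs_sort.cumsum(dim=-1) <= 1 - p.unsqueeze(-1)
drop[..., -1] = False
logits_sort = logits_sort.masked_fill(drop, float("-inf"))
logits = torch.empty_like(logits_sort).scatter_(-1, logits_idx, logits_sort)

print(logits)  # tensor([[-inf, -inf, 3., 4.]]) -> only the two top tokens can be sampled
```

The patched `_apply_top_k_top_p` above arrives at the same mask without the final scatter: it derives per-row probability cutoffs from the sorted probabilities and masks the original (unsorted) logits directly, which is what avoids the scatter that is slow on Ascend.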