Files
xc-llm-ascend/tests/e2e/singlecard/test_sampler.py
wangxiyuan 69b817ed65 [CI] Add unit test framework (#1201)
This PR adds a unit test framework to enable unit tests (UT) for vLLM Ascend.
Unit tests run on CPU machines. They are run once the lint check has passed,
the same as the e2e tests.

For unit tests, this PR creates a new folder called `ut` under the `tests`
module. The test files in `ut` should mirror the code in `vllm-ascend`, and
each file name should start with the `test_` prefix. For example, in this PR
`test_ascend_config.py` is added to test `ascend_config.py`.

A new file, `worker/test_worker_v1.py`, is also added as a placeholder.
This file should hold the unit tests for `vllm-ascend/worker/worker_v1.py`.

Additionally, a new `fake_weight` folder is added. It contains the
config.json from `facebook/opt-125m`, so that the tests do not have to
access Hugging Face every time.

TODO:
We should add the remaining unit test files one by one in the future.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-06-16 18:32:28 +08:00

148 lines
5.3 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
import torch
from vllm.v1.sample.sampler import Sampler # noqa: F401
# Absolute and relative tolerances passed to torch.testing.assert_close
# by the tests in this file.
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def apply_min_p_new(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask out tokens whose probability falls below a min-p threshold.

    For every row, the threshold is ``min_p`` multiplied by that row's
    largest softmax probability. Logits of tokens whose probability is
    strictly below the threshold are replaced with ``-inf``.

    Args:
        logits: Raw scores; softmax is taken over the last dimension.
            The test in this file passes a 2-D ``(num_tokens, vocab)``
            tensor.
        min_p: 1-D tensor of min-p values, either a single element or one
            element per row of ``logits``.

    Returns:
        ``logits`` itself (no copy) when every entry of ``min_p`` is zero,
        otherwise a new tensor with the filtered positions set to ``-inf``.
    """
    # Reduce before truth-testing: a bare `if min_p == 0` raises for a
    # tensor holding more than one element (e.g. one min_p per row).
    if bool(torch.all(min_p == 0)):
        return logits

    # Convert logits to a probability distribution.
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Largest probability of each row, kept as a column for broadcasting.
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Turn min_p into a column so it scales each row's maximum.
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Out-of-place fill: the caller's tensor is left untouched.
    return logits.masked_fill(probability_values < adjusted_min_p,
                              -float("inf"))
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Reference top-k / top-p filter built on a full ascending sort.

    The rows of ``logits`` are sorted, the requested masks are applied to
    the sorted copy (filtered entries become ``-inf``), and the result is
    scattered back into the original column order. Either filter is
    skipped when its argument is ``None``.

    Args:
        logits: 2-D tensor of scores, one row per request.
        k: Optional tensor of top-k sizes; cast to int64 before indexing.
        p: Optional tensor of top-p thresholds.

    Returns:
        A new tensor shaped like ``logits`` with filtered entries set to
        ``-inf``.
    """
    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Column of the k-th largest value in the ascending sort.
        kth_column = sorted_logits.size(1) - k.to(torch.long)  # shape: B
        kth_value = sorted_logits.gather(1, kth_column.unsqueeze(dim=1))
        below_kth = sorted_logits < kth_value
        sorted_logits.masked_fill_(below_kth, -float("inf"))

    if p is not None:
        # Cumulative probability mass, accumulated in place from the
        # smallest logit upwards.
        cumulative = sorted_logits.softmax(dim=-1)
        torch.cumsum(cumulative, dim=-1, out=cumulative)
        outside_nucleus = cumulative <= 1 - p.unsqueeze(dim=1)
        # Always keep at least the highest-scoring token.
        outside_nucleus[:, -1] = False
        sorted_logits.masked_fill_(outside_nucleus, -float("inf"))

    # Undo the sort so each value returns to its original column.
    return sorted_logits.scatter(dim=-1, index=sort_order, src=sorted_logits)
def apply_top_k_top_p_new(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and/or top-p filtering, gathering only surviving tokens.

    Rows are sorted ascending. Top-k masks everything below the k-th
    largest value. Top-p then keeps the tokens whose cumulative
    probability (accumulated from the smallest logit) exceeds ``1 - p``,
    and writes only those tokens into an otherwise ``-inf`` output.

    Args:
        logits: 2-D tensor of scores, shape ``(batch_size, vocab_size)``.
        k: Optional tensor of top-k sizes; ``None`` disables top-k.
        p: Optional tensor of top-p thresholds; ``None`` disables top-p.

    Returns:
        ``logits`` itself when both ``k`` and ``p`` are ``None``, otherwise
        a new ``(batch_size, vocab_size)`` tensor with filtered entries set
        to ``-inf``.
    """
    if k is None and p is None:
        # Nothing to filter; avoid the sort entirely.
        return logits

    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    top_k_mask = None
    if k is not None:
        # gather needs an int64 index, matching the reference
        # implementation's explicit cast.
        kth_column = (vocab_size - k.to(torch.long)).unsqueeze(dim=1)
        boundary = logits_sort.gather(1, kth_column)
        top_k_mask = logits_sort < boundary
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is None:
        # Top-k only: the mask lives on the sorted copy, so it has to be
        # scattered back into the original column order to take effect.
        return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)

    # Columns left of `cutoff` are masked by top-k in every row, so the
    # top-p work can skip them.
    cutoff = 0 if top_k_mask is None else top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
    # Always keep at least the highest-scoring token.
    top_p_mask[:, -1] = True

    # Translate per-row column indices into indices of the flattened
    # logits so all surviving tokens can be copied in one shot.
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)

    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    filtered = torch.empty_like(logits_flatten).fill_(-float("inf"))
    filtered[valid_idx] = valid_logits
    return filtered.reshape(batch_size, vocab_size)
# Exercise min-p on a 2-D input whose leading dimension stands for
# num_tokens (seqlen and batch_size merged).
@torch.inference_mode()
def test_apply_min_p() -> None:
    """Check apply_min_p_new against vLLM's Sampler.apply_min_p on NPU."""
    scores = torch.randn((128, 7168)).npu()
    threshold = torch.Tensor([0.01]).npu()

    # Run the local implementation first, then the vLLM one, on the same
    # input tensor.
    candidate = apply_min_p_new(scores, threshold)
    reference = Sampler().apply_min_p(scores, threshold)

    torch.testing.assert_close(
        candidate,
        reference,
        rtol=DEFAULT_RTOL,
        atol=DEFAULT_ATOL,
    )
# Exercise top-k/top-p on a 2-D input whose leading dimension stands for
# num_tokens (seqlen and batch_size merged).
@torch.inference_mode()
def test_apply_top_k_top_p() -> None:
    """Check apply_top_k_top_p_new against apply_top_k_top_p on NPU."""
    scores = torch.randn((128, 7168)).npu()
    # NOTE(review): k = -1 makes both helpers gather at column
    # vocab_size + 1, and the integer p = 1 makes the top-p threshold
    # `1 - p` equal to 0 -- confirm these inputs are intended.
    top_k = torch.Tensor([-1]).int().npu()
    top_p = torch.Tensor([1]).int().npu()

    candidate = apply_top_k_top_p_new(scores, top_k, top_p)
    reference = apply_top_k_top_p(scores, top_k, top_p)

    torch.testing.assert_close(
        candidate,
        reference,
        rtol=DEFAULT_RTOL,
        atol=DEFAULT_ATOL,
    )