[Test] add rejection sampler ut (#2084)
### What this PR does / why we need it?
add rejection sampler ut.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT passed
- vLLM version: v0.10.0
- vLLM main:
586f286789
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
201
tests/ut/sample/test_rejection_sampler.py
Normal file
201
tests/ut/sample/test_rejection_sampler.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.sample.rejection_sampler import (
|
||||
expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
|
||||
rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)
|
||||
|
||||
# Global constants
|
||||
PLACEHOLDER_TOKEN_ID = -1
|
||||
GREEDY_TEMPERATURE = 0.0
|
||||
MAX_SPEC_LEN = 8 # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
|
||||
|
||||
|
||||
class TestAscendRejectionSampler(TestBase):
|
||||
|
||||
def test_rejection_greedy_sample_pytorch(self):
|
||||
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
||||
batch_size = 2
|
||||
max_spec_len = 3
|
||||
output_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
cu_num_draft_tokens = torch.tensor([2, 4])
|
||||
draft_token_ids = torch.tensor([10, 11, 20, 21])
|
||||
target_argmax = torch.tensor([10, 99, 20, 22])
|
||||
bonus_token_ids = torch.tensor([[100], [200]])
|
||||
|
||||
is_greedy = torch.tensor([True, True])
|
||||
|
||||
rejection_greedy_sample_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
)
|
||||
|
||||
assert output_token_ids[0, 0].item() == 10
|
||||
assert output_token_ids[0, 1].item() == 99
|
||||
assert output_token_ids[1, 0].item() == 20
|
||||
assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
|
||||
|
||||
def test_rejection_random_sample_pytorch(self):
|
||||
"""Test random rejection sampling: accept based on uniform probability"""
|
||||
batch_size = 2
|
||||
max_spec_len = 3
|
||||
output_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
cu_num_draft_tokens = torch.tensor([2, 1])
|
||||
draft_token_ids = torch.tensor([1, 0, 2])
|
||||
draft_probs = torch.tensor([
|
||||
[0.0, 0.6, 0.0, 0.4], # vocab_size=4
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.5, 0.0, 0.0],
|
||||
])
|
||||
target_probs = torch.tensor([
|
||||
[0.0, 0.8, 0.0, 0.2],
|
||||
[0.2, 0.1, 0.3, 0.4],
|
||||
[0.9, 0.1, 0.0, 0.0],
|
||||
])
|
||||
bonus_token_ids = torch.tensor([[100], [200]])
|
||||
recovered_token_ids = torch.tensor([1, 2, 3])
|
||||
uniform_probs = torch.tensor([0.7, 0.6, 0.5])
|
||||
is_greedy = torch.tensor([False, False])
|
||||
vocab_size = 4
|
||||
|
||||
rejection_random_sample_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
)
|
||||
|
||||
assert output_token_ids[0, 0].item() == 1
|
||||
assert output_token_ids[0, 1].item() == 0
|
||||
assert output_token_ids[0, 2].item() == 100
|
||||
|
||||
def test_expand_pytorch(self):
|
||||
"""Test expand_pytorch functionality"""
|
||||
input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
|
||||
cu_num_tokens_ptr = torch.tensor([2, 5, 7])
|
||||
output_ptr = torch.empty(7, dtype=torch.int32)
|
||||
|
||||
expand_pytorch(
|
||||
output_ptr,
|
||||
input_ptr,
|
||||
cu_num_tokens_ptr,
|
||||
replace_from=0,
|
||||
replace_to=0,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN,
|
||||
)
|
||||
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(output_ptr, expected)
|
||||
|
||||
def test_expand_batch_to_tokens(self):
|
||||
"""Test expand_batch_to_tokens wrapper"""
|
||||
x = torch.tensor([10, 20, 30])
|
||||
cu_num_tokens = torch.tensor([2, 5, 7])
|
||||
num_tokens = 7
|
||||
|
||||
with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
|
||||
) as mock_kernel:
|
||||
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
|
||||
mock_kernel.assert_called_once()
|
||||
args = mock_kernel.call_args[0]
|
||||
assert (args[1] == x).all()
|
||||
assert (args[2] == cu_num_tokens).all()
|
||||
|
||||
# Run actual function
|
||||
result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
|
||||
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
|
||||
assert torch.equal(result, expected)
|
||||
|
||||
def test_sample_recovered_tokens_pytorch_ngram(self):
|
||||
"""Test recovered token sampling under n-gram mode"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
cu_num_draft_tokens = torch.tensor([1, 2])
|
||||
draft_token_ids = torch.tensor([1, 2])
|
||||
draft_probs = None
|
||||
target_probs = torch.tensor([
|
||||
[0.1, 0.2, 0.7],
|
||||
[0.3, 0.3, 0.4],
|
||||
])
|
||||
q = torch.tensor([
|
||||
[0.1, 0.2, 0.7],
|
||||
[0.5, 0.4, 0.1],
|
||||
])
|
||||
vocab_size = 3
|
||||
|
||||
sample_recovered_tokens_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
IS_NGRAM=True,
|
||||
)
|
||||
|
||||
assert output_token_ids[0].item() == 0
|
||||
assert output_token_ids[1].item() == 1
|
||||
|
||||
def test_sample_recovered_tokens_pytorch_autoregressive(self):
|
||||
"""Test recovered token sampling for autoregressive models"""
|
||||
output_token_ids = torch.empty(2, dtype=torch.int32)
|
||||
cu_num_draft_tokens = torch.tensor([1, 1])
|
||||
draft_token_ids = torch.tensor([0, 1])
|
||||
draft_probs = torch.tensor([
|
||||
[0.6, 0.1, 0.3],
|
||||
[0.2, 0.7, 0.1],
|
||||
])
|
||||
target_probs = torch.tensor([
|
||||
[0.8, 0.1, 0.1],
|
||||
[0.3, 0.6, 0.1],
|
||||
])
|
||||
q = torch.tensor([
|
||||
[0.5, 0.3, 0.2],
|
||||
[0.1, 0.8, 0.1],
|
||||
])
|
||||
vocab_size = 3
|
||||
|
||||
sample_recovered_tokens_pytorch(
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
vocab_size,
|
||||
IS_NGRAM=False,
|
||||
)
|
||||
assert output_token_ids[0].item() == 0
|
||||
Reference in New Issue
Block a user