[Perf][MTP] Optimize reject sampler in greedy situation. (#2137)

This PR port optimization in PR #2002 to main and makes it cleaner.

- vLLM version: v0.10.0
- vLLM main:
afa5b7ca0b

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-08-11 17:37:49 +08:00
committed by GitHub
parent ca274001b0
commit 29aaba5f84
3 changed files with 120 additions and 58 deletions

View File

@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
def test_rejection_greedy_sample_pytorch(self):
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
batch_size = 2
max_spec_len = 3
max_spec_len = 2
output_token_ids = torch.full((batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID)
cu_num_draft_tokens = torch.tensor([2, 4])
num_draft_tokens = [2, 2]
draft_token_ids = torch.tensor([10, 11, 20, 21])
target_argmax = torch.tensor([10, 99, 20, 22])
bonus_token_ids = torch.tensor([[100], [200]])
@@ -49,8 +50,9 @@ class TestAscendRejectionSampler(TestBase):
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
num_draft_tokens,
max_spec_len,
is_greedy,
)
assert output_token_ids[0, 0].item() == 10