[Perf][MTP] Optimize reject sampler in greedy situation. (#2137)

This PR ports the optimization from PR #2002 to main and makes it cleaner.

- vLLM version: v0.10.0
- vLLM main:
afa5b7ca0b

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-08-11 17:37:49 +08:00
committed by GitHub
parent ca274001b0
commit 29aaba5f84
3 changed files with 120 additions and 58 deletions

View File

@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]],
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]],
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device)
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)