[Perf][MTP] Optimize reject sampler in greedy situation. (#2137)
This PR ports the optimization from PR #2002 to main and makes it cleaner.
- vLLM version: v0.10.0
- vLLM main: afa5b7ca0b
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
[output_tokens[0][-1], output_tokens[1][-1]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32).unsqueeze(1)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
[output_tokens[0][-1], output_tokens[1][-1]],
|
||||
device=logits.device,
|
||||
dtype=torch.int32).unsqueeze(1)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
||||
device=logits.device)
|
||||
device=logits.device,
|
||||
dtype=torch.int32).unsqueeze(1)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
|
||||
@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
|
||||
def test_rejection_greedy_sample_pytorch(self):
|
||||
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
||||
batch_size = 2
|
||||
max_spec_len = 3
|
||||
max_spec_len = 2
|
||||
output_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID)
|
||||
|
||||
cu_num_draft_tokens = torch.tensor([2, 4])
|
||||
num_draft_tokens = [2, 2]
|
||||
draft_token_ids = torch.tensor([10, 11, 20, 21])
|
||||
target_argmax = torch.tensor([10, 99, 20, 22])
|
||||
bonus_token_ids = torch.tensor([[100], [200]])
|
||||
@@ -49,8 +50,9 @@ class TestAscendRejectionSampler(TestBase):
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
num_draft_tokens,
|
||||
max_spec_len,
|
||||
is_greedy,
|
||||
)
|
||||
|
||||
assert output_token_ids[0, 0].item() == 10
|
||||
|
||||
Reference in New Issue
Block a user