[Perf][MTP] Optimize the rejection sampler for the all-greedy case. (#2137)
This PR ports the optimization from PR #2002 to main and makes the implementation cleaner.
- vLLM version: v0.10.0
- vLLM main: afa5b7ca0b
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):
|
|||||||
|
|
||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||||
device=logits.device)
|
device=logits.device,
|
||||||
|
dtype=torch.int32)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):
|
|||||||
|
|
||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||||
device=logits.device)
|
device=logits.device,
|
||||||
|
dtype=torch.int32)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
|
|||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor(
|
bonus_token_tensor = torch.tensor(
|
||||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
[output_tokens[0][-1], output_tokens[1][-1]],
|
||||||
|
device=logits.device,
|
||||||
|
dtype=torch.int32).unsqueeze(1)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):
|
|||||||
|
|
||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||||
device=logits.device)
|
device=logits.device,
|
||||||
|
dtype=torch.int32)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):
|
|||||||
|
|
||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
|
||||||
device=logits.device)
|
device=logits.device,
|
||||||
|
dtype=torch.int32)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
|
|||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor(
|
bonus_token_tensor = torch.tensor(
|
||||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
[output_tokens[0][-1], output_tokens[1][-1]],
|
||||||
|
device=logits.device,
|
||||||
|
dtype=torch.int32).unsqueeze(1)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
|||||||
metadata = create_sampling_metadata(all_greedy=True)
|
metadata = create_sampling_metadata(all_greedy=True)
|
||||||
logits = create_logits_tensor(output_tokens)
|
logits = create_logits_tensor(output_tokens)
|
||||||
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
||||||
device=logits.device)
|
device=logits.device,
|
||||||
|
dtype=torch.int32).unsqueeze(1)
|
||||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||||
device=logits.device)
|
device=logits.device)
|
||||||
|
|
||||||
|
|||||||
@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
|
|||||||
def test_rejection_greedy_sample_pytorch(self):
|
def test_rejection_greedy_sample_pytorch(self):
|
||||||
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
max_spec_len = 3
|
max_spec_len = 2
|
||||||
output_token_ids = torch.full((batch_size, max_spec_len + 1),
|
output_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||||
PLACEHOLDER_TOKEN_ID)
|
PLACEHOLDER_TOKEN_ID)
|
||||||
|
|
||||||
cu_num_draft_tokens = torch.tensor([2, 4])
|
cu_num_draft_tokens = torch.tensor([2, 4])
|
||||||
|
num_draft_tokens = [2, 2]
|
||||||
draft_token_ids = torch.tensor([10, 11, 20, 21])
|
draft_token_ids = torch.tensor([10, 11, 20, 21])
|
||||||
target_argmax = torch.tensor([10, 99, 20, 22])
|
target_argmax = torch.tensor([10, 99, 20, 22])
|
||||||
bonus_token_ids = torch.tensor([[100], [200]])
|
bonus_token_ids = torch.tensor([[100], [200]])
|
||||||
@@ -49,8 +50,9 @@ class TestAscendRejectionSampler(TestBase):
|
|||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
target_argmax,
|
target_argmax,
|
||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
is_greedy,
|
num_draft_tokens,
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
|
is_greedy,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert output_token_ids[0, 0].item() == 10
|
assert output_token_ids[0, 0].item() == 10
|
||||||
|
|||||||
@@ -147,16 +147,25 @@ def rejection_sample(
|
|||||||
if not sampling_metadata.all_random:
|
if not sampling_metadata.all_random:
|
||||||
# Rejection sampling for greedy sampling requests.
|
# Rejection sampling for greedy sampling requests.
|
||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_probs.argmax(dim=-1)
|
||||||
rejection_greedy_sample_pytorch(
|
if min(num_draft_tokens) == 1 and max(
|
||||||
output_token_ids,
|
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
||||||
cu_num_draft_tokens,
|
rejection_greedy_sample_spec_len_1_pytorch(
|
||||||
draft_token_ids,
|
output_token_ids,
|
||||||
target_argmax,
|
draft_token_ids,
|
||||||
bonus_token_ids,
|
target_argmax,
|
||||||
is_greedy,
|
bonus_token_ids,
|
||||||
max_spec_len,
|
)
|
||||||
# num_warps=1,
|
else:
|
||||||
)
|
rejection_greedy_sample_pytorch(
|
||||||
|
output_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
target_argmax,
|
||||||
|
bonus_token_ids,
|
||||||
|
num_draft_tokens,
|
||||||
|
max_spec_len,
|
||||||
|
is_greedy,
|
||||||
|
)
|
||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
@@ -284,47 +293,89 @@ def sample_recovered_tokens(
|
|||||||
return recovered_token_ids
|
return recovered_token_ids
|
||||||
|
|
||||||
|
|
||||||
def rejection_greedy_sample_pytorch(
|
def rejection_greedy_sample_spec_len_1_pytorch(
|
||||||
output_token_ids, # [batch_size, max_spec_len + 1]
|
output_token_ids, # [batch_size, 2]
|
||||||
cu_num_draft_tokens, # [batch_size]
|
draft_token_ids, # [num_tokens]
|
||||||
draft_token_ids, # [num_tokens]
|
target_argmax, # [num_tokens]
|
||||||
target_argmax, # [num_tokens]
|
bonus_token_ids, # [batch_size]
|
||||||
bonus_token_ids, # [batch_size]
|
|
||||||
is_greedy=None, # [batch_size] or None
|
|
||||||
max_spec_len=None,
|
|
||||||
):
|
):
|
||||||
batch_size = output_token_ids.shape[0]
|
batch_size = output_token_ids.size(0)
|
||||||
|
num_tokens = draft_token_ids.size(0)
|
||||||
|
assert batch_size == num_tokens
|
||||||
|
accept_req_mask = draft_token_ids == target_argmax
|
||||||
|
output_token_ids[:, 0] = target_argmax
|
||||||
|
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||||
|
output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]
|
||||||
|
|
||||||
|
|
||||||
|
def rejection_greedy_sample_pytorch(
|
||||||
|
output_token_ids, # [batch_size, max_spec_len + 1]
|
||||||
|
cu_num_draft_tokens, # [batch_size]
|
||||||
|
draft_token_ids, # [num_tokens]
|
||||||
|
target_argmax, # [num_tokens]
|
||||||
|
bonus_token_ids, # [batch_size]
|
||||||
|
draft_tokens_per_req, # [batch_size], list
|
||||||
|
max_spec_len,
|
||||||
|
is_greedy=None, # [batch_size] or None
|
||||||
|
):
|
||||||
|
batch_size = output_token_ids.size(0)
|
||||||
|
num_tokens = draft_token_ids.size(0)
|
||||||
|
device = output_token_ids.device
|
||||||
|
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
|
||||||
|
device, non_blocking=True)
|
||||||
if is_greedy is None:
|
if is_greedy is None:
|
||||||
is_greedy = torch.ones(batch_size,
|
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||||
dtype=torch.bool,
|
|
||||||
device=output_token_ids.device)
|
|
||||||
|
|
||||||
for req_idx in range(batch_size):
|
start_indices = cu_num_draft_tokens - draft_tokens_per_req
|
||||||
if not is_greedy[req_idx]:
|
req_ids = torch.arange(batch_size, device=device)
|
||||||
continue
|
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
|
||||||
|
token_positions = torch.arange(
|
||||||
|
num_tokens, device=device) - start_indices[token_req_ids]
|
||||||
|
|
||||||
if req_idx == 0:
|
# Find the first mismatch position of each request.
|
||||||
start_idx = 0
|
mismatch_global = (draft_token_ids != target_argmax)
|
||||||
else:
|
if max_spec_len == 0:
|
||||||
start_idx = cu_num_draft_tokens[req_idx - 1].item()
|
first_mismatch_pos_per_req = torch.zeros(batch_size,
|
||||||
end_idx = cu_num_draft_tokens[req_idx].item()
|
dtype=torch.long,
|
||||||
num_draft_tokens = end_idx - start_idx
|
device=device)
|
||||||
|
else:
|
||||||
|
# [bs, max_spec_len]
|
||||||
|
pos_matrix = torch.full((batch_size, max_spec_len),
|
||||||
|
-1,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device)
|
||||||
|
pos_matrix[token_req_ids, token_positions] = token_positions
|
||||||
|
mismatch_matrix = torch.full((batch_size, max_spec_len),
|
||||||
|
False,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device)
|
||||||
|
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
|
||||||
|
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
|
||||||
|
max_spec_len * 2)
|
||||||
|
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
|
||||||
|
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
|
||||||
|
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
|
||||||
|
no_mismatch_mask]
|
||||||
|
|
||||||
rejected = False
|
# Copy matched target tokens into output.
|
||||||
for pos in range(num_draft_tokens):
|
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
|
||||||
if not rejected:
|
draft_tokens_per_req)
|
||||||
draft_token_id = draft_token_ids[start_idx + pos].item()
|
copy_indices = torch.arange(max_spec_len + 1,
|
||||||
target_argmax_id = target_argmax[start_idx + pos].item()
|
device=device).expand(batch_size, -1)
|
||||||
|
copy_mask = copy_indices < copy_len.unsqueeze(1)
|
||||||
output_token_ids[req_idx, pos] = target_argmax_id
|
greedy_mask = is_greedy.unsqueeze(1)
|
||||||
|
final_copy_mask = copy_mask & greedy_mask
|
||||||
if draft_token_id != target_argmax_id:
|
global_idx = start_indices.unsqueeze(1) + copy_indices
|
||||||
rejected = True
|
output_token_ids[final_copy_mask] = target_argmax[
|
||||||
|
global_idx[final_copy_mask]].to(output_token_ids.dtype)
|
||||||
if not rejected:
|
# Fill bonus token.
|
||||||
bonus_token_id = bonus_token_ids[req_idx].item()
|
needs_bonus = is_greedy & (first_mismatch_pos_per_req
|
||||||
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
|
>= draft_tokens_per_req)
|
||||||
|
if torch.any(needs_bonus):
|
||||||
|
bonus_rows = torch.where(needs_bonus)[0]
|
||||||
|
bonus_cols = draft_tokens_per_req[bonus_rows]
|
||||||
|
bonus_token_ids = bonus_token_ids.squeeze(1)
|
||||||
|
output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
|
||||||
|
|
||||||
|
|
||||||
def rejection_random_sample_pytorch(
|
def rejection_random_sample_pytorch(
|
||||||
|
|||||||
Reference in New Issue
Block a user