# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.v1.spec_decode.eagle import EagleProposer
def prepare_inputs(
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compact a packed token layout by dropping each request's rejected tokens.

    Given cumulative target query lengths, e.g. [0, a, a + b, a + b + c],
    and per-request rejected-token counts [n1, n2, n3], compute:

      * cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
      * token_indices: indices into the original packed layout keeping only
        the accepted tokens of every request:
        [0, 1, ..., a - n1 - 1,
         a, a + 1, ..., a + b - n2 - 1,
         a + b, a + b + 1, ..., a + b + c - n3 - 1]

    Returns (cu_num_tokens, token_indices).
    """
    # [0, a, a + b, a + b + c] -> [a, b, c]
    per_request_len = torch.diff(cu_target_query_lens)
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    accepted_per_request = per_request_len - num_rejected_tokens

    # New cumulative offsets; zeros_like gives the leading 0, cumsum fills
    # the rest in place through the slice view.
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    torch.cumsum(accepted_per_request, dim=0, out=cu_num_tokens[1:])

    # FIXME(woosuk): Avoid synchronization.
    total_tokens = cu_num_tokens[-1].item()
    token_indices = torch.empty(total_tokens,
                                dtype=torch.int32,
                                device=cu_num_tokens.device)

    BLOCK_SIZE = 1024
    prepare_input_pytorch(token_indices,
                          cu_target_query_lens,
                          cu_num_tokens,
                          block_size=BLOCK_SIZE)
    return cu_num_tokens, token_indices
def prepare_input_pytorch(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
                          cu_num_tokens: torch.Tensor, block_size: int):
    """Fill ``out_ptr`` with accepted-token indices, one request at a time.

    PyTorch port of the Triton kernel: for request ``pid``, write
    ``cu_query_lens[pid] + [0, 1, ..., num_tokens - 1]`` into
    ``out_ptr[cu_num_tokens[pid] : cu_num_tokens[pid + 1]]``,
    processed in chunks of ``block_size`` elements.

    Args:
        out_ptr: 1-D int tensor of length ``cu_num_tokens[-1]``; written in place.
        cu_query_lens: [batch_size + 1] cumulative target query lengths.
        cu_num_tokens: [batch_size + 1] cumulative accepted-token counts.
        block_size: chunk width used per inner-loop iteration.
    """
    num_pids = cu_num_tokens.shape[0] - 1

    # Loop-invariant lane offsets [0, block_size); shifted per chunk below.
    base_offset = torch.arange(0,
                               block_size,
                               dtype=out_ptr.dtype,
                               device=cu_query_lens.device)

    for pid in range(num_pids):
        start_pos = cu_num_tokens[pid].item()
        end_pos = cu_num_tokens[pid + 1].item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens[pid].item()
        # Ceiling division: number of block_size-wide chunks covering
        # num_tokens. (The original dropped the "// block_size", looping
        # num_tokens + block_size - 1 times over the same chunk.)
        num_blocks = (num_tokens + block_size - 1) // block_size

        for i in range(num_blocks):
            # Advance by i * block_size each chunk; the original omitted
            # this, so tokens past the first block were never written.
            offset = i * block_size + base_offset
            mask = offset < num_tokens
            out_ptr[start_pos + offset[mask]] = index_start + offset[mask]
# Monkey-patch: replace EagleProposer's original prepare_inputs with the
# pure-PyTorch implementation defined above (applies module-wide on import).
EagleProposer.prepare_inputs = prepare_inputs