from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import (generate_uniform_probs,
                                              compute_probs,
                                              rejection_random_sample_kernel,
                                              sample_recovered_tokens)
from vllm.distributed import get_tensor_model_parallel_rank

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32

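# NOTE: PLACEHOLDER_TOKEN_ID pads output rows past the last emitted token, and
# GREEDY_TEMPERATURE is the sentinel value stored in
# SamplingMetadata.temperature for greedy requests (see the is_greedy
# computation in rejection_sample below).

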
def rejection_greedy_sample_python(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
    num_warps,  # Unused; kept to mirror the Triton kernel's launch signature.
):
    # Pure-Python fallback for the greedy rejection-sampling kernel. Only the
    # single-draft-token case (max_spec_len == 1) is implemented.
    if max_spec_len == 1:
        for bi in range(output_token_ids_ptr.shape[0]):
            # The first output position always takes the target model's argmax.
            output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
            # If the draft token matches the target argmax, the draft is
            # accepted and the bonus token is appended.
            if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
                output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
    else:
        raise NotImplementedError(
            "MTP with more than one draft token (max_spec_len > 1) is not "
            "supported yet.")


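# Worked example for the k == 1 greedy path above (values are illustrative):
#   draft_token_ids = [5, 9], target_argmax = [5, 7], bonus_token_ids = [11, 12]
#   Request 0: draft token 5 matches the target argmax, so the output row
#     becomes [5, 11] (accepted draft token followed by the bonus token).
#   Request 1: draft token 9 differs from the target argmax 7, so the row stays
#     [7, PLACEHOLDER_TOKEN_ID]: the target argmax replaces the rejected draft
#     and no bonus token is emitted.

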
class RejectionSampler(nn.Module):

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits. Shape is [num_tokens, vocab_size].
                Logits from different requests are flattened into a single
                tensor because this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. This allows for more flexibility in the sampling
                process, such as top_p and top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature
                and top-k/top-p parameters.

        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
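        # A hypothetical usage sketch (names and tensor contents are
        # illustrative, not taken from a real run):
        #   sampler = RejectionSampler()
        #   output = sampler(
        #       metadata,                 # SpecDecodeMetadata for the batch
        #       draft_probs=None,         # e.g. ngram drafts carry no probs
        #       target_logits=logits,     # [num_tokens, vocab_size]
        #       bonus_token_ids=bonus,    # [batch_size, 1]
        #       sampling_metadata=sampling_metadata,
        #   )
        # `output` has shape [batch_size, max_spec_len + 1]; positions after
        # the last emitted token are padded with PLACEHOLDER_TOKEN_ID.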
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
        if metadata.max_spec_len == 1:
            # Single draft token per request: dispatch to the backend's fused
            # rejection-sampling op.
            output_token_ids = torch.vacc.rejection_sampler_v1(
                target_logits.to(torch.float32),
                metadata.draft_token_ids,
                bonus_token_ids,
                sampling_metadata.temperature,
                sampling_metadata.top_p,
                sampling_metadata.top_k,
                sampling_metadata.all_greedy,
                sampling_metadata.all_random,
                sampling_metadata.generators,
            )
        else:
            target_probs = compute_probs(
                target_logits.to(torch.float32),
                metadata.cu_num_draft_tokens,
                sampling_metadata,
            )
            output_token_ids = rejection_sample(
                metadata.draft_token_ids,
                metadata.num_draft_tokens,
                metadata.max_spec_len,
                metadata.cu_num_draft_tokens,
                draft_probs,
                target_probs,
                bonus_token_ids,
                sampling_metadata,
            )
        return output_token_ids


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        # Python fallback replacing the Triton rejection_greedy_sample_kernel.
        rejection_greedy_sample_python(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids
        else:
            # TODO: support batches that mix greedy and random requests.
            raise NotImplementedError(
                "Batches mixing greedy and random sampling requests are not "
                "supported yet.")

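    # The random-sampling path below follows the standard speculative-decoding
    # acceptance rule (as implemented by the imported vLLM kernels): draft
    # token d_i is accepted iff u_i < p_target(d_i) / p_draft(d_i) for a
    # uniform sample u_i (see generate_uniform_probs). On rejection, a
    # "recovered" token is drawn from the normalized residual distribution
    # max(p_target - p_draft, 0) (see sample_recovered_tokens). When
    # draft_probs is None (NO_DRAFT_PROBS below), the draft distribution is
    # treated as a point mass, so acceptance reduces to u_i < p_target(d_i).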
    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids