from typing import Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import (
    generate_uniform_probs,
    compute_probs,
    rejection_random_sample_kernel,
    sample_recovered_tokens,
)
from vllm.distributed import get_tensor_model_parallel_rank

PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


def rejection_greedy_sample_python(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
    num_warps,
):
    # Pure-Python fallback for the greedy rejection path. Only the
    # single-draft-token case (max_spec_len == 1) is implemented: slot 0 always
    # receives the target argmax, and the bonus token is appended only when the
    # draft token was accepted (i.e. it matched the target argmax).
    if max_spec_len == 1:
        for bi in range(output_token_ids_ptr.shape[0]):
            output_token_ids_ptr[bi, 0] = target_argmax_ptr[bi]
            if target_argmax_ptr[bi].item() == draft_token_ids_ptr[bi].item():
                output_token_ids_ptr[bi, 1] = bonus_token_ids_ptr[bi]
    else:
        raise NotImplementedError("TODO: MTP with k > 1 is not supported yet.")


class RejectionSampler(nn.Module):

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Args:
            metadata:
                Metadata for spec decoding.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits. Shape is [num_tokens, vocab_size].
                Here, logits from different requests are flattened into a
                single tensor because this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
        '''
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
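        # Dispatch: with a single draft token per request (max_spec_len == 1),
        # the whole accept/reject decision is delegated to the fused
        # `torch.vacc.rejection_sampler_v1` op below; for longer speculation
        # lengths, target probabilities are materialized with `compute_probs`
        # and the generic `rejection_sample` routine is used instead.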
        if metadata.max_spec_len == 1:
            output_token_ids = torch.vacc.rejection_sampler_v1(
                target_logits.to(torch.float32),
                metadata.draft_token_ids,
                bonus_token_ids,
                sampling_metadata.temperature,
                sampling_metadata.top_p,
                sampling_metadata.top_k,
                sampling_metadata.all_greedy,
                sampling_metadata.all_random,
                sampling_metadata.generators,
            )
        else:
            target_probs = compute_probs(
                target_logits.to(torch.float32),
                metadata.cu_num_draft_tokens,
                sampling_metadata,
            )
            output_token_ids = rejection_sample(
                metadata.draft_token_ids,
                metadata.num_draft_tokens,
                metadata.max_spec_len,
                metadata.cu_num_draft_tokens,
                draft_probs,
                target_probs,
                bonus_token_ids,
                sampling_metadata,
            )
        return output_token_ids


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests, using the Python
        # fallback instead of the Triton rejection_greedy_sample_kernel.
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_python(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids
        else:
            # TODO: batches mixing greedy and random requests.
            raise NotImplementedError("not supported yet")

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
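    # This launch mirrors the upstream vLLM rejection rule (one Triton program
    # per request): for each draft position i, draft token x_i is accepted when
    # p_target(x_i) / p_draft(x_i) >= u_i, with p_draft treated as 1 when
    # NO_DRAFT_PROBS (e.g. ngram proposals); on the first rejection the
    # pre-sampled recovered token is emitted and the rest of the draft is
    # discarded, and the bonus token is appended only when every draft token
    # was accepted.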
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids
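

# ---------------------------------------------------------------------------
# Illustration only: `_greedy_accept_reject_example` is a hypothetical helper
# (not part of vLLM) that sketches the greedy accept/reject rule implemented
# by `rejection_greedy_sample_python` for max_spec_len == 1, using small CPU
# tensors. Slot 0 always receives the target argmax; the bonus token lands in
# slot 1 only when the single draft token was accepted (matched the argmax).
def _greedy_accept_reject_example() -> torch.Tensor:
    draft_token_ids = torch.tensor([7, 3])    # one draft token per request
    target_argmax = torch.tensor([7, 9])      # target model's greedy pick
    bonus_token_ids = torch.tensor([11, 12])  # sampled outside the sampler
    output_token_ids = torch.full((2, 2), PLACEHOLDER_TOKEN_ID,
                                  dtype=torch.int32)
    rejection_greedy_sample_python(
        output_token_ids,
        None,  # cu_num_draft_tokens is unused when max_spec_len == 1
        draft_token_ids,
        target_argmax,
        bonus_token_ids,
        None,  # is_greedy is unused in the Python fallback
        max_spec_len=1,
        num_warps=1,
    )
    # Request 0 accepted its draft token (7 == 7) -> [7, 11];
    # request 1 rejected it (3 != 9) -> [9, PLACEHOLDER_TOKEN_ID].
    return output_token_ids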