# SPDX-License-Identifier: Apache-2.0
"""Rejection sampling for speculative decoding on Ascend NPUs.

Every public entry point has two implementations: a Triton kernel, used when
``HAS_TRITON`` is true, and a vectorized PyTorch fallback.
"""
from typing import Optional

import torch
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE,
                                              generate_uniform_probs)

from vllm_ascend.sample.sampler import apply_top_k_top_p

PLACEHOLDER_TOKEN_ID = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32

vectorcore_num = None
device_properties = None

if HAS_TRITON:
    from triton.runtime import driver  # type: ignore
    device_properties = driver.active.utils.get_device_properties(
        torch.npu.current_device())
    # Number of vector cores; used later to size the kernel launch grids.
    vectorcore_num = device_properties['num_vectorcore']


def _vector_launch_config(n: int) -> tuple[int, int]:
    """Return ``(grid, BLOCK_SIZE)`` for a 1-D kernel over ``n`` requests.

    Small batches use blocks of two requests. Once there are at least
    ``vectorcore_num`` requests, the grid is pinned to the vector-core count
    and each program handles a power-of-two number of requests.
    Only meaningful when ``HAS_TRITON`` is true (``vectorcore_num`` is set).
    """
    block_size = 2
    grid = triton.cdiv(n, block_size)
    if n >= vectorcore_num:
        grid = vectorcore_num  # Empirically tuned value
        block_size = triton.next_power_of_2(triton.cdiv(n, grid))
    return grid, block_size


def apply_sampling_constraints(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Process logits based on sampling metadata.

    This function applies temperature scaling to the logits, as well as
    top-k and top-p. When every request is greedy, it returns the original
    logits untouched.

    Args:
        logits: Input logits tensor to be processed. Scaled in place.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Processed logits if non-greedy sampling is used,
        otherwise returns the original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    num_tokens = logits.shape[0]
    # Greedy requests carry GREEDY_TEMPERATURE; map it to 1 so the division
    # below leaves their logits unchanged.
    temperature = expand_batch_to_tokens(
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        num_tokens,
        replace_from=GREEDY_TEMPERATURE,
        replace_to=1,
    )
    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
    logits.div_(temperature.unsqueeze(-1))

    # Get expanded top_k and top_p tensors.
    top_k = None
    if sampling_metadata.top_k is not None:
        top_k = expand_batch_to_tokens(
            sampling_metadata.top_k,
            cu_num_draft_tokens,
            num_tokens,
        )
    top_p = None
    if sampling_metadata.top_p is not None:
        top_p = expand_batch_to_tokens(
            sampling_metadata.top_p,
            cu_num_draft_tokens,
            num_tokens,
        )

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    return apply_top_k_top_p(logits, top_k, top_p)


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run rejection sampling over a batch of drafted tokens.

    Greedy requests are resolved by comparing each draft token with the
    target argmax; random requests use the acceptance test
    ``target_prob / draft_prob >= uniform`` and fall back to a recovered
    token at the first rejection. A request whose drafts are all accepted
    additionally receives its bonus token.

    Args:
        draft_token_ids: Flattened draft tokens of all requests.
        num_draft_tokens: Number of draft tokens per request.
        max_spec_len: Maximum number of draft tokens of any request; sets
            the output width to ``max_spec_len + 1``.
        cu_num_draft_tokens: Inclusive cumulative sum of
            ``num_draft_tokens``.
        draft_probs: Draft distribution per token, or ``None`` when the
            drafter provides no probabilities (treated as probability 1).
        target_probs: Target distribution per token.
        bonus_token_ids: One bonus token per request.
        sampling_metadata: Supplies ``all_greedy``, ``all_random``,
            ``temperature`` and ``generators``.

    Returns:
        An int32 tensor of shape ``[batch_size, max_spec_len + 1]``;
        positions for which no token was produced hold
        ``PLACEHOLDER_TOKEN_ID``.
    """
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    if batch_size == 0:
        # Nothing to sample; also avoids launching kernels on an empty grid.
        return output_token_ids

    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE

    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        # Fast path: every request is greedy and drafted exactly one token.
        # (`num_draft_tokens` is non-empty here.)
        use_spec_len_1 = sampling_metadata.all_greedy and all(
            n == 1 for n in num_draft_tokens)
        if HAS_TRITON:
            vec_len = batch_size
            grid, BLOCK_SIZE = _vector_launch_config(
                cu_num_draft_tokens.numel())
            if use_spec_len_1:
                rejection_greedy_sample_spec_len_1_triton[(grid, )](
                    output_token_ids,
                    draft_token_ids,
                    target_argmax,
                    bonus_token_ids,
                    vec_len,
                    BLOCK_SIZE=BLOCK_SIZE,
                )
            else:
                rejection_greedy_sample_triton[(grid, )](
                    output_token_ids,
                    cu_num_draft_tokens,
                    draft_token_ids,
                    target_argmax,
                    bonus_token_ids,
                    is_greedy,
                    vec_len,
                    max_spec_len,
                    BLOCK_SIZE=BLOCK_SIZE,
                )
        else:
            if use_spec_len_1:
                rejection_greedy_sample_spec_len_1_pytorch(
                    output_token_ids,
                    draft_token_ids,
                    target_argmax,
                    bonus_token_ids,
                )
            else:
                rejection_greedy_sample_pytorch(
                    output_token_ids,
                    cu_num_draft_tokens,
                    draft_token_ids,
                    target_argmax,
                    bonus_token_ids,
                    num_draft_tokens,
                    max_spec_len,
                    is_greedy,
                )
        if sampling_metadata.all_greedy:
            return output_token_ids

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    if HAS_TRITON:
        rejection_random_sample_kernel[(batch_size, )](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            bonus_token_ids,
            recovered_token_ids,
            uniform_probs.to(torch.float32),
            is_greedy,
            max_spec_len,
            vocab_size,
            NO_DRAFT_PROBS=draft_probs is None,
        )
    else:
        rejection_random_sample_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            bonus_token_ids,
            recovered_token_ids,
            uniform_probs,
            is_greedy,
            max_spec_len,
            vocab_size,
            IS_NGRAM=draft_probs is None,
        )
    return output_token_ids


def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Expand [batch_size] tensor to [num_tokens] tensor based on the number
    of tokens per batch in cu_num_tokens.

    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor to expand.
        cu_num_tokens: [batch_size] tensor containing the cumulative number
            of tokens per batch. Each element represents the total number of
            tokens up to and including that batch.
        num_tokens: Total number of tokens.
        replace_from: int = 0
            Value to be replaced if it is found in x.
        replace_to: int = 0
            Value to replace with when replace_from is found.

    Returns:
        expanded_x: [num_tokens] tensor.
    """
    batch_size = x.shape[0]
    assert cu_num_tokens.shape[0] == batch_size
    expanded_x = x.new_empty(num_tokens)
    if batch_size == 0 or num_tokens == 0:
        # Same guard as `expand_pytorch`; nothing to write, so skip the
        # kernel launch entirely.
        return expanded_x

    if HAS_TRITON:
        vec_len = batch_size
        grid, BLOCK_SIZE = _vector_launch_config(cu_num_tokens.numel())
        expand_kernel[(grid, )](
            expanded_x,
            x,
            cu_num_tokens,
            replace_from,
            replace_to,
            vec_len,
            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
            BLOCK_SIZE=BLOCK_SIZE,
        )
    else:
        expand_pytorch(
            expanded_x,
            x,
            cu_num_tokens,
            replace_from,
            replace_to,
            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
        )
    return expanded_x


def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: torch.Tensor,
    draft_token_ids: torch.Tensor,
    draft_probs: Optional[torch.Tensor],
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Sample, for every draft position, the token to emit if it is rejected.

    One row of exponential noise ``q`` is drawn per request (using the
    request's own generator when one is registered in
    ``sampling_metadata.generators``) and the recovered token is
    ``argmax(prob / q)`` over the vocabulary.

    Returns:
        A tensor shaped and typed like ``draft_token_ids`` holding the
        recovered token id for each draft position.
    """
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]

    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()

    num_draft_tensor = torch.tensor(num_draft_tokens,
                                    pin_memory=True).to(device,
                                                        non_blocking=True)
    has_draft_mask = num_draft_tensor > 0

    # Requests with a dedicated generator get their row re-drawn from it,
    # but only if they actually have draft tokens. The selection is done on
    # device with `torch.where`, not with a Python branch.
    for i, generator in sampling_metadata.generators.items():
        temp_q = torch.empty_like(q[i])
        temp_q.exponential_(generator=generator)
        q[i] = torch.where(has_draft_mask[i], temp_q, q[i])

    recovered_token_ids = torch.empty_like(draft_token_ids)
    if HAS_TRITON:
        # NOTE(review): `multibuffer` is not a parameter of the kernel;
        # presumably a backend launch/compile option - confirm.
        sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
            recovered_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            triton.next_power_of_2(vocab_size),
            NO_DRAFT_PROBS=draft_probs is None,
            SUB_BLOCK=4 * 1024,
            # TODO: enable multibuffer when accuracy problem is solved.
            multibuffer=False,
        )
    else:
        sample_recovered_tokens_pytorch(
            recovered_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            IS_NGRAM=draft_probs is None,
        )
    return recovered_token_ids


def rejection_greedy_sample_spec_len_1_pytorch(
        output_token_ids,  # [batch_size, 2]
        draft_token_ids,  # [num_tokens]
        target_argmax,  # [num_tokens]
        bonus_token_ids,  # [batch_size]
):
    """Greedy rejection sampling when every request has exactly one draft.

    Column 0 always receives the target argmax; column 1 receives the bonus
    token only where the draft token equals the target argmax. Writes into
    ``output_token_ids`` in place.
    """
    batch_size = output_token_ids.size(0)
    num_tokens = draft_token_ids.size(0)
    assert batch_size == num_tokens
    accept_req_mask = draft_token_ids == target_argmax
    output_token_ids[:, 0] = target_argmax
    bonus_token_ids = bonus_token_ids.squeeze(1)
    output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
                                         output_token_ids[:, 1])


def rejection_greedy_sample_pytorch(
        output_token_ids,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens,  # [batch_size]
        draft_token_ids,  # [num_tokens]
        target_argmax,  # [num_tokens]
        bonus_token_ids,  # [batch_size]
        draft_tokens_per_req,  # [batch_size], list
        max_spec_len,
        is_greedy=None,  # [batch_size] or None
):
    """Vectorized greedy rejection sampling for variable draft lengths.

    For each greedy request, the target argmax tokens are copied into the
    output up to and including the first position where the draft token
    differs from the target argmax. If no position differs, the bonus token
    is appended after the last draft position. ``is_greedy=None`` means all
    requests are greedy. Writes into ``output_token_ids`` in place.
    """
    batch_size = output_token_ids.size(0)
    num_tokens = draft_token_ids.size(0)
    device = output_token_ids.device
    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
        device, non_blocking=True)
    if is_greedy is None:
        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

    # Map every flattened token to its (request, position-in-request) pair.
    start_indices = cu_num_draft_tokens - draft_tokens_per_req
    req_ids = torch.arange(batch_size, device=device)
    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
    token_positions = torch.arange(
        num_tokens, device=device) - start_indices[token_req_ids]

    # Find the first mismatch position of each request.
    mismatch_global = (draft_token_ids != target_argmax)
    if max_spec_len == 0:
        first_mismatch_pos_per_req = torch.zeros(batch_size,
                                                 dtype=torch.long,
                                                 device=device)
    else:
        # [bs, max_spec_len]
        pos_matrix = torch.full((batch_size, max_spec_len),
                                -1,
                                dtype=torch.long,
                                device=device)
        pos_matrix[token_req_ids, token_positions] = token_positions
        mismatch_matrix = torch.full((batch_size, max_spec_len),
                                     False,
                                     dtype=torch.bool,
                                     device=device)
        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
        # `max_spec_len * 2` is a sentinel larger than any real position.
        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
                                         max_spec_len * 2)
        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
            no_mismatch_mask]

    # Copy matched target tokens into output.
    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
                             draft_tokens_per_req)
    copy_indices = torch.arange(max_spec_len + 1,
                                device=device).expand(batch_size, -1)
    copy_mask = copy_indices < copy_len.unsqueeze(1)
    greedy_mask = is_greedy.unsqueeze(1)
    final_copy_mask = copy_mask & greedy_mask
    global_idx = start_indices.unsqueeze(1) + copy_indices
    output_token_ids[final_copy_mask] = target_argmax[
        global_idx[final_copy_mask]].to(output_token_ids.dtype)

    # Fill bonus token.
    needs_bonus = is_greedy & (first_mismatch_pos_per_req
                               >= draft_tokens_per_req)
    if torch.any(needs_bonus):
        bonus_rows = torch.where(needs_bonus)[0]
        bonus_cols = draft_tokens_per_req[bonus_rows]
        bonus_token_ids = bonus_token_ids.squeeze(1)
        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


def rejection_random_sample_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    bonus_token_ids,  # [batch_size]
    recovered_token_ids,  # [num_tokens]
    uniform_probs,  # [num_tokens]
    is_greedy,  # [batch_size]
    max_spec_len,
    vocab_size,
    IS_NGRAM=False,
):
    """
    This function implements the Speculative Decoding rejection sampling step.
    Instead of looping through each request and each token (which causes high
    overhead), it uses a fully vectorized approach:

    1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
       [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
       variable-length sequences in the batch.
    2. **Parallel Validation**: Calculates the acceptance condition
       (target_prob / draft_prob >= uniform_sample) for ALL draft tokens
       simultaneously across the entire batch.
    3. **Short-circuit Simulation**: In the loop version, once a token is
       rejected, subsequent tokens are ignored. Here, we simulate this by
       finding the 'first_reject_pos' using argmax on the rejection mask and
       creating a 'should_skip' mask for all indices after the first failure.
    4. **Token Selection**: Uses 'torch.where' to select:
       - Draft tokens (if accepted)
       - Recovered tokens (at the point of first rejection)
       - Bonus tokens (if all tokens in a sequence were accepted)
    5. **Masking**: Ensures operations only apply to non-greedy requests and
       within valid sequence lengths.

    Writes into 'output_token_ids' in place. 'vocab_size' is accepted for
    signature parity with the Triton kernel and is not used here.
    """
    batch_size = output_token_ids.shape[0]
    device = output_token_ids.device

    zero_cpu = torch.tensor([0], pin_memory=True)
    zero_device = zero_cpu.to(device, non_blocking=True)

    cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
    cu_end = cu_num_draft_tokens
    num_draft_per_batch = cu_end - cu_start

    max_draft_len = max_spec_len
    pos_indices_cpu = torch.arange(max_draft_len, pin_memory=True)
    pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]

    valid_mask = pos_indices < num_draft_per_batch[:, None]
    global_token_indices = cu_start[:, None] + pos_indices
    # Padding slots (beyond a request's draft count) are clamped to a valid
    # index so the gathers below stay in bounds; `valid_mask` discards them.
    global_token_indices = global_token_indices.clamp(
        0, draft_token_ids.shape[0] - 1)
    draft_tokens = draft_token_ids[
        global_token_indices]  # [batch_size, max_draft_len]

    flat_indices = global_token_indices.flatten()
    flat_draft_tokens = draft_tokens.flatten()

    if IS_NGRAM:
        # No draft distribution available: treat the draft prob as 1.
        ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
        draft_token_probs = ones_cpu.to(
            device, non_blocking=True).expand_as(draft_tokens)
    else:
        flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
        draft_token_probs = flat_draft_probs.view(batch_size, max_draft_len)

    flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
    target_token_probs = flat_target_probs.view(batch_size, max_draft_len)

    uniform_token_probs = uniform_probs[global_token_indices]
    recovered_tokens = recovered_token_ids[global_token_indices]

    zero_threshold_cpu = torch.tensor([0.0],
                                      pin_memory=True,
                                      dtype=torch.float32)
    zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)

    acceptance_condition = (draft_token_probs > zero_threshold) & (
        target_token_probs / draft_token_probs >= uniform_token_probs)

    first_rejection = (~acceptance_condition) & valid_mask

    # `max_draft_len` marks "no rejection in this request".
    default_pos_cpu = torch.full([batch_size, 1],
                                 max_draft_len,
                                 pin_memory=True)
    default_pos = default_pos_cpu.to(device, non_blocking=True)

    first_reject_pos = torch.where(
        first_rejection.any(dim=1, keepdim=True),
        first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
    pos_mask = pos_indices >= first_reject_pos
    should_skip = pos_mask & valid_mask

    final_acceptance = acceptance_condition & (~should_skip)
    non_greedy_mask = ~is_greedy

    update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
    first_reject_mask = (pos_indices == first_reject_pos
                         ) & valid_mask & non_greedy_mask[:, None]
    final_update_mask = update_mask | first_reject_mask
    final_tokens = torch.where(
        first_reject_mask, recovered_tokens,
        torch.where(final_acceptance, draft_tokens,
                    output_token_ids[:, :max_draft_len]))

    output_token_ids[:, :max_draft_len] = torch.where(
        final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])

    no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
    should_add_bonus = non_greedy_mask & no_rejection

    bonus_positions = num_draft_per_batch  # [batch_size]

    seq_len = output_token_ids.shape[1]
    all_positions_cpu = torch.arange(seq_len, pin_memory=True)
    all_positions = all_positions_cpu.to(
        device, non_blocking=True)[None, :]  # [1, seq_len]

    batch_bonus_positions = bonus_positions[:, None]  # [batch_size, 1]

    max_spec_len_cpu = torch.tensor([max_spec_len], pin_memory=True)
    max_spec_len_device = max_spec_len_cpu.to(device, non_blocking=True)

    valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
    final_bonus_mask = should_add_bonus & valid_bonus_pos

    bonus_pos_match = (all_positions == batch_bonus_positions)
    bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]

    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
    output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
                                      output_token_ids)


def expand_pytorch(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS,
):
    """
    This function broadcasts batch-level values (input_ptr) to token-level
    positions (output_ptr) based on cumulative token offsets. It acts like a
    "repeat_interleave" operation but with custom logic:

    1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
       [num_tokens, batch_size] that identifies which batch index each token
       belongs to by checking if the token index falls between cu_start and
       cu_end.
    2. **Conditional Replacement**: Before expansion, values equal to
       'replace_from' in the input are replaced with 'replace_to'.
    3. **Gather-based Mapping**: Each token's batch index is obtained with an
       argmax over 'in_range' and used to gather the batch value. Unlike a
       float matrix product, this keeps the input dtype exact (no float32
       rounding of large integers) and cannot leak a NaN/inf from one
       request into the tokens of another.

    Tokens that fall in no range keep their existing value in 'output_ptr'.
    'MAX_NUM_TOKENS' is accepted for signature parity with the Triton kernel
    and is not used here.
    """
    device = cu_num_tokens_ptr.device
    batch_size = input_ptr.shape[0]
    num_tokens = output_ptr.shape[0]

    if batch_size == 0 or num_tokens == 0:
        return

    cu_start = torch.cat([
        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
        cu_num_tokens_ptr[:-1]
    ])
    cu_end = cu_num_tokens_ptr

    token_indices = torch.arange(num_tokens,
                                 device=device)[:, None]  # [num_tokens, 1]
    cu_start_exp = cu_start[None, :]  # [1, batch_size]
    cu_end_exp = cu_end[None, :]  # [1, batch_size]

    in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)

    replaced_input = torch.where(input_ptr == replace_from, replace_to,
                                 input_ptr)

    # The ranges are disjoint, so each row of `in_range` has at most one
    # True; argmax returns its column. Rows with no True are masked out by
    # `needs_update` below.
    token_to_batch = torch.argmax(in_range.int(), dim=1)
    token_values = replaced_input[token_to_batch]

    needs_update = in_range.any(dim=1)

    output_ptr[:] = torch.where(needs_update, token_values, output_ptr)


def sample_recovered_tokens_pytorch(
    output_token_ids,  # [num_tokens]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    q,  # [batch_size, vocab_size]
    vocab_size,
    IS_NGRAM=False,
):
    """
    When a draft token is rejected, we must sample a "recovered" token from a
    modified distribution. This function calculates that token for every
    position of the flattened batch.

    1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
       determines which request in the batch each token belongs to. This is
       necessary because 'q' (the per-request noise drawn by the caller) is
       stored per-request.
    2. **Probability Adjustment**:
       - If N-GRAM: It zeroes out the draft token's probability in a copy of
         the target distribution.
       - If Probabilistic: It calculates max(0, target_probs - draft_probs)
         as per the standard speculative decoding algorithm.
    3. **Division by q**: It divides the adjusted probabilities by 'q',
       broadcast from [batch_size, vocab] to [num_tokens, vocab]. Entries
       where 'q' is 0 or inf are excluded from selection.
    4. **Argmax Selection**: It selects the recovered token for every
       position in one pass using torch.argmax.

    Writes into 'output_token_ids' in place. 'vocab_size' is accepted for
    signature parity with the Triton kernel and is not used here.
    """
    device = output_token_ids.device
    num_tokens = output_token_ids.shape[0]

    if num_tokens == 0:
        return

    cu_start = torch.cat([
        torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
        cu_num_draft_tokens[:-1],
    ])
    cu_end = cu_num_draft_tokens

    token_indices = torch.arange(num_tokens, device=device)  # [num_tokens]

    token_indices_expanded = token_indices[:, None]  # [num_tokens, 1]
    cu_start_expanded = cu_start[None, :]  # [1, batch_size]
    cu_end_expanded = cu_end[None, :]  # [1, batch_size]

    in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
        token_indices_expanded < cu_end_expanded)

    token_to_batch = torch.argmax(in_range_mask.int(), dim=1)

    # Tokens that belong to no request fall back to request 0.
    has_match = in_range_mask.any(dim=1)
    token_to_batch = torch.where(has_match, token_to_batch, 0)

    if IS_NGRAM:
        # Work on a copy so the caller's `target_probs` is left untouched.
        modified_target_probs = target_probs.clone()
        modified_target_probs[token_indices, draft_token_ids] = 0
        prob = modified_target_probs
    else:
        prob = torch.maximum(
            target_probs - draft_probs,
            torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True),
        )

    q_values = q[token_to_batch]  # [num_tokens, vocab_size]

    # Divide by a harmless epsilon where q is degenerate, then force those
    # entries to a large negative value so argmax never picks them.
    epsilon = 1e-10
    q_values_safe = torch.where(q_values == 0, epsilon, q_values)
    q_values_safe = torch.where(torch.isinf(q_values), epsilon, q_values_safe)

    prob_over_q = prob / q_values_safe

    prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
                              prob_over_q)

    recovered_ids = torch.argmax(prob_over_q, dim=1)

    output_token_ids[:] = recovered_ids


@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew_1(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
):
    """Write request `position`'s bonus token into column 1 of a 2-wide row."""
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
    output_token_ids_ptr,  # [batch_size, 2]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,
    vec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling when every request has exactly one draft.

    Each program handles BLOCK_SIZE consecutive requests: column 0 gets the
    target argmax, and column 1 gets the bonus token when the draft token
    equals the target argmax.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # `other=0` gives lanes past `vec_len` a defined value.
    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask=mask, other=0)
    target_argmax_id = tl.load(target_argmax_ptr + offset, mask=mask, other=0)
    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

    for pos in tl.range(0, BLOCK_SIZE):
        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
        position = block_idx * BLOCK_SIZE + pos
        # Lanes past `vec_len` must not touch memory: without this guard
        # their (equal) fill values would trigger an out-of-bounds bonus
        # read and write.
        if position < vec_len:
            if draft_token_id1 == target_argmax1:
                bonus_renew_1(
                    bonus_token_ids_ptr,
                    position,
                    output_token_ids_ptr,
                )


@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
    max_spec_len,
    num_tokens1,
):
    """Write request `position`'s bonus token at column `num_tokens1`."""
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(
        output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
        bonus_token_id)


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    vec_len,
    max_spec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling for variable draft lengths.

    Each program handles BLOCK_SIZE consecutive requests. For every greedy
    request it copies target argmax tokens up to and including the first
    mismatch with the draft, and appends the bonus token if nothing
    mismatched.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    if is_greedy_ptr is None:
        is_greedy_mask = mask
    else:
        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
        is_greedy_mask = mask & (is_greedy != 0)

    # Request 0 starts at token 0; masking `offset > 0` avoids reading
    # index -1. `other=0` makes inactive lanes (out of range or non-greedy)
    # see zero draft tokens, so the loops below do nothing for them.
    start_idx = tl.load(cu_num_draft_tokens_ptr + offset - 1,
                        mask=is_greedy_mask & (offset > 0),
                        other=0)
    end_idx = tl.load(cu_num_draft_tokens_ptr + offset,
                      mask=is_greedy_mask,
                      other=0)
    num_draft_tokens = end_idx - start_idx

    for pos in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
        rejected = False
        start_idx1 = tl.get_element(start_idx, (pos, ))
        is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
        position = block_idx * BLOCK_SIZE + pos
        for i in range(num_tokens1):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
                tl.store(
                    output_token_ids_ptr + position * (max_spec_len + 1) + i,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True
        if not rejected and is_greedy_mask1:
            bonus_renew(
                bonus_token_ids_ptr,
                position,
                output_token_ids_ptr,
                max_spec_len,
                num_tokens1,
            )


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS: tl.constexpr,
):
    """Rejection sampling for one random-sampling request per program.

    Draft tokens are accepted while
    `draft_prob > 0 and target_prob / draft_prob >= uniform_prob`; the first
    rejected position receives the recovered token and later positions are
    left untouched. If nothing is rejected the bonus token is appended.
    """
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests.
        return

    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                               req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept.
                token_id = draft_token_id
            else:
                # Reject. Use recovered token.
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens,
            bonus_token_id,
        )


@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    vec_len,
    MAX_NUM_TOKENS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Broadcast each request's value to all of its token positions.

    Each program handles BLOCK_SIZE consecutive requests. A request may own
    at most MAX_NUM_TOKENS tokens, since the store covers only that many
    offsets.
    """
    req_idx = tl.program_id(0)
    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    len_mask = offset < vec_len

    # Request 0 starts at token 0; masking `offset > 0` avoids reading
    # index -1. `other=0` makes lanes past `vec_len` see zero tokens, so
    # they store nothing below.
    start_idx = tl.load(cu_num_tokens_ptr + offset - 1,
                        mask=len_mask & (offset > 0),
                        other=0)
    end_idx = tl.load(cu_num_tokens_ptr + offset, mask=len_mask, other=0)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + offset, mask=len_mask, other=0)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)

    for i in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_tokens, (i, ))
        start_idx1 = tl.get_element(start_idx, (i, ))
        src_val1 = tl.get_element(src_val, (i, ))
        offset1 = tl.arange(0, MAX_NUM_TOKENS)
        tl.store(output_ptr + start_idx1 + offset1,
                 src_val1,
                 mask=offset1 < num_tokens1)


@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
    SUB_BLOCK: tl.constexpr,
):
    """Compute the recovered token for one (request, draft position) pair.

    The vocabulary is scanned in chunks of SUB_BLOCK entries, keeping a
    running argmax of `prob / q` across chunks. PADDED_VOCAB_SIZE is not
    used by the body.
    """
    req_idx = tl.program_id(0)
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                               req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
    global_recovered_id = -1
    global_max_p = -1.0
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                            draft_token_id)
        # Temporarily zero out the probability of the draft token.
        # This is essentially the same as target_prob - draft_prob, except that
        # n-gram does not have draft_prob. We regard it as 1.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            0)
        for loop_i in range(loop):
            vocab_start = loop_i * SUB_BLOCK
            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
            prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                           vocab_offset,
                           mask=vocab_offset < vocab_size,
                           other=0)
            # Padding lanes get q = -inf so that prob / q is never the max.
            q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
                        mask=vocab_offset < vocab_size,
                        other=float("-inf"))
            new_p = prob / q
            recovered_id = tl.argmax(new_p, axis=-1)
            max_p = tl.get_element(new_p, (recovered_id, ))
            if max_p > global_max_p:
                global_max_p = max_p
                global_recovered_id = vocab_start + recovered_id
    else:
        for loop_i in range(loop):
            vocab_start = loop_i * SUB_BLOCK
            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size + vocab_offset,
                                 mask=vocab_offset < vocab_size,
                                 other=0)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  vocab_offset,
                                  mask=vocab_offset < vocab_size,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)
            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here
            # because `tl.argmax` will select the maximum value.

            # Padding lanes get q = -inf so that prob / q is never the max.
            q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
                        mask=vocab_offset < vocab_size,
                        other=float("-inf"))
            new_p = prob / q
            recovered_id = tl.argmax(new_p, axis=-1)
            max_p = tl.get_element(new_p, (recovered_id, ))
            if max_p > global_max_p:
                global_max_p = max_p
                global_recovered_id = vocab_start + recovered_id

    tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)

    if NO_DRAFT_PROBS:
        # Restore the original probability.
        tl.store(
            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
            orig_prob)