# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project # SPDX-License-Identifier: Apache-2.0 from typing import Optional import math import torch import triton import triton.language as tl import vllm from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample import rejection_sampler from vllm.v1.sample.rejection_sampler import sample_recovered_tokens from vllm_mlu.mlu_hijack_utils import MluHijackObject from vllm_mlu._mlu_utils import * from vllm_mlu import _mlu_ops as mlu_ops PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = 0 # Maximum number of speculative draft tokens allowed per request in a single # step. This value is chosen to be large enough to handle typical use cases. MAX_SPEC_LEN = 128 ''' ============================= Modify by vllm_mlu ============================= @brief: - Limit maximum batch size due to NRAM memory constraints - Add generate_recovered_uniform_probs function for tmo rejection sampler ''' MAX_BATCH_SIZE = 65536 def generate_recovered_uniform_probs( num_tokens: int, vocab_size: int, num_draft_tokens: list[int], sampling_metadata: SamplingMetadata, device: torch.device, ) -> torch.Tensor: q = torch.empty( (num_tokens, vocab_size), dtype=torch.float32, device=device, ) q.exponential_() for i, generator in sampling_metadata.generators.items(): # Do not generate random numbers for requests with no draft tokens. # This can be important for reproducibility. if num_draft_tokens[i] > 0: q[i].exponential_(generator=generator) return q ''' ============================= End of MLU Hijack ============================= ''' def vllm__v1__sample__rejection_sampler__expand_batch_to_tokens( x: torch.Tensor, # [batch_size] cu_num_tokens: torch.Tensor, # [batch_size] num_tokens: int, replace_from: int = 0, replace_to: int = 0, ) -> torch.Tensor: """Expand [batch_size] tensor to [num_tokens] tensor based on the number of tokens per batch in cu_num_tokens. For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. Args: x: [batch_size] tensor to expand. cu_num_tokens: [batch_size] tensor containing the cumulative number of tokens per batch. Each element represents the total number of tokens up to and including that batch. num_tokens: Total number of tokens. replace_from: int = 0 Value to be replaced if it is found in x. replace_to: int = 0 Value to replace with when replace_from is found. Returns: expanded_x: [num_tokens] tensor. """ batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size ''' ============================= Modify by vllm_mlu ============================= ''' if batch_size > MAX_BATCH_SIZE: raise ValueError(f"Rejection Sampler Not Supported: " f"Batch size exceeds the maximum allowed value of {MAX_BATCH_SIZE}") ''' ================== End of MLU Hijack ================== ''' expanded_x = x.new_empty(num_tokens) vllm__v1__sample__rejection_sampler__expand_kernel[(batch_size, )]( expanded_x, x, cu_num_tokens, replace_from, replace_to, MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. ) return expanded_x # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["replace_from", "replace_to"]) def vllm__v1__sample__rejection_sampler__expand_kernel( output_ptr, # [num_tokens] input_ptr, # [batch_size] cu_num_tokens_ptr, # [batch_size] replace_from, replace_to, MAX_NUM_TOKENS: tl.constexpr, ): req_idx = tl.program_id(0) if req_idx == 0: # noqa: SIM108 ''' ============================= Modify by vllm_mlu ============================= ''' # Ensure data types are consistent start_idx = tl.full((), 0, tl.int64) ''' ================== End of MLU Hijack ================== ''' else: ''' ============================= Modify by vllm_mlu ============================= ''' start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1).to(tl.int64) ''' ================== End of MLU Hijack ================== ''' end_idx = tl.load(cu_num_tokens_ptr + req_idx) num_tokens = end_idx - start_idx src_val = tl.load(input_ptr + req_idx) src_val = tl.where(src_val == replace_from, replace_to, src_val) offset = tl.arange(0, MAX_NUM_TOKENS) tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens) @triton.jit def vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel( output_token_ids_ptr, # [num_tokens] cu_num_draft_tokens_ptr, # [batch_size] draft_token_ids_ptr, # [num_tokens] draft_probs_ptr, # [num_tokens, vocab_size] or None target_probs_ptr, # [num_tokens, vocab_size] q_ptr, # [batch_size, vocab_size] vocab_size, PADDED_VOCAB_SIZE: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr, BLOCK_VOCAB: tl.constexpr = 2048, ): req_idx = tl.program_id(0) start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx # Early exit for out-of-range positions. pos = tl.program_id(1) if pos >= num_draft_tokens: return ''' ============================= Modify by vllm_mlu ============================= ''' max_score = -float("inf") max_index = 0 ''' ================== End of MLU Hijack ================== ''' if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) # Temporarily zero out the probability of the draft token. # This is essentially the same as target_prob - draft_prob, except that # n-gram does not have draft_prob. We regard it as 1. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0) ''' ============================= Modify by vllm_mlu ============================= @brief: Replace with block loop due to ngram limitations ''' num_blocks = tl.cdiv(PADDED_VOCAB_SIZE, BLOCK_VOCAB) for i in tl.range(0, num_blocks): offset = i * BLOCK_VOCAB + tl.arange(0, BLOCK_VOCAB) mask = offset < vocab_size if NO_DRAFT_PROBS: prob = tl.load( target_probs_ptr + (start_idx + pos) * vocab_size + offset, mask=mask, other=0 ) else: draft_prob = tl.load( draft_probs_ptr + (start_idx + pos) * vocab_size + offset, mask=mask, other=0 ) target_prob = tl.load( target_probs_ptr + (start_idx + pos) * vocab_size + offset, mask=mask, other=0 ) prob = tl.maximum(target_prob - draft_prob, 0) # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # `tl.argmax` will select the maximum value. q = tl.load(q_ptr + req_idx * vocab_size + offset, mask=mask, other=float("-inf")) score = prob / q # Broadcasting elementwise cur_max = tl.argmax(score, axis=0) cur_score = score[cur_max] cur_index = offset[cur_max] # Manually maintain argmax. if cur_score > max_score: max_score = cur_score max_index = cur_index tl.store(output_token_ids_ptr + start_idx + pos, max_index) ''' ================== End of MLU Hijack ================== ''' if NO_DRAFT_PROBS: # Restore the original probability. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) """ ============================= Modify by vllm_mlu ============================= """ def filter_with_acceptance_rate(output_token_ids, # [batch_size, max_spec_len + 1] fixed_acceptance_rate): """ Filter speculative tokens based on a fixed acceptance rate using batch-level accept/reject decisions. This function implements an adaptive acceptance rate control mechanism that maintains a target acceptance rate over time through error compensation and PID-style adjustments. Args: output_token_ids (torch.Tensor): Input tensor of shape [batch_size, max_spec_len + 1] where the first column contains base tokens and remaining columns contain speculative tokens fixed_acceptance_rate (float or None): Target acceptance rate between 0.0 and 1.0 If None, returns input tensor unchanged Returns: torch.Tensor: Modified tensor where rejected batches have all speculative tokens (columns 1 to max_spec_len) set to PLACEHOLDER_TOKEN_ID Algorithm Flow: 1. **Initialization Phase**: - Extract batch dimensions and device information - Initialize static variables for tracking acceptance statistics: * cumulative_error: Long-term error accumulation * total_batches/accepted_batches: Global acceptance tracking * acceptance_history: Sliding window for recent performance * precision_adjustment: PID controller adjustment factor * recent_adjustments: Error history for PID calculation 2. **Statistics Calculation**: - Calculate global acceptance rate from all historical data - Calculate sliding window acceptance rate from recent batches - Compute combined error using weighted average of global and window errors - Weight transitions from global-focused (early) to window-focused (later) 3. **PID Controller Adjustment** (after 50+ batches): - Proportional term: Current error magnitude - Integral term: Accumulated error over recent history - Derivative term: Rate of error change - Combines P, I, D terms to compute precision adjustment factor - Limits adjustment range to prevent over-correction 4. **Error Correction**: - Applies smooth nonlinear correction based on combined error magnitude - Uses exponential decay mapping for gradual adjustment strength - Handles boundary cases (0.0, 1.0, very low rates) specially 5. **Gap-based Adjustment**: - Calculates difference between target and actual accepted batches - Applies adaptive threshold-based corrections - Uses exponential smoothing for adjustment strength - Adjustment strength decreases as total batch count increases 6. **Random Perturbation** (after 100+ batches): - Adds small random noise to prevent local minima - Noise amplitude decreases over time for stability 7. **Batch Decision**: - Generates random value and compares with adjusted acceptance rate - Makes binary accept/reject decision for entire batch 8. **Token Modification**: - If accepted: Leave all tokens unchanged - If rejected: Set all speculative tokens (columns 1:) to PLACEHOLDER_TOKEN_ID - This ensures token-level acceptance rate matches batch-level rate 9. **State Updates**: - Update acceptance counters and history - Update cumulative error using exponential moving average - Prepare state for next function call Key Features: - **Batch-level consistency**: All samples in a batch share the same accept/reject fate - **Adaptive control**: Uses multiple feedback mechanisms (global, windowed, PID) - **Error compensation**: Corrects for deviations from target rate over time - **Stability mechanisms**: Includes smoothing, limits, and perturbation for robustness - **Token-level alignment**: Ensures token acceptance rate matches batch acceptance rate Note: This function maintains internal state across calls through static variables, so it will converge to the target acceptance rate over multiple invocations. """ if fixed_acceptance_rate is None: return output_token_ids else: # Apply accept/reject decisions for the entire batch based on fixed_acceptance_rate batch_size = output_token_ids.shape[0] max_spec_len = output_token_ids.shape[1] - 1 # Get max_spec_len device = output_token_ids.device assert fixed_acceptance_rate >= 0 and fixed_acceptance_rate <= 1 # Use error compensation method to track global acceptance rate # These are static variables that persist between calls if not hasattr(filter_with_acceptance_rate, "cumulative_error"): filter_with_acceptance_rate.cumulative_error = 0.0 if not hasattr(filter_with_acceptance_rate, "total_batches"): filter_with_acceptance_rate.total_batches = 0 if not hasattr(filter_with_acceptance_rate, "accepted_batches"): filter_with_acceptance_rate.accepted_batches = 0 if not hasattr(filter_with_acceptance_rate, "window_size"): filter_with_acceptance_rate.window_size = 1000 # Sliding window size if not hasattr(filter_with_acceptance_rate, "acceptance_history"): filter_with_acceptance_rate.acceptance_history = [] # Track recent accept/reject history if not hasattr(filter_with_acceptance_rate, "precision_adjustment"): filter_with_acceptance_rate.precision_adjustment = 0.0 # Precision adjustment factor if not hasattr(filter_with_acceptance_rate, "recent_adjustments"): filter_with_acceptance_rate.recent_adjustments = [] # Recent adjustment history if not hasattr(filter_with_acceptance_rate, "target_rate"): filter_with_acceptance_rate.target_rate = fixed_acceptance_rate # Record target acceptance rate else: # If target acceptance rate changes, reset adjustment state if filter_with_acceptance_rate.target_rate != fixed_acceptance_rate: filter_with_acceptance_rate.precision_adjustment = 0.0 filter_with_acceptance_rate.recent_adjustments = [] filter_with_acceptance_rate.target_rate = fixed_acceptance_rate # Update batch count filter_with_acceptance_rate.total_batches += 1 # Calculate current global acceptance rate global_rate = (filter_with_acceptance_rate.accepted_batches / filter_with_acceptance_rate.total_batches if filter_with_acceptance_rate.total_batches > 0 else 0.0) # Calculate sliding window acceptance rate (focusing on recent performance) filter_with_acceptance_rate.acceptance_history.append(0) # Default to reject if len(filter_with_acceptance_rate.acceptance_history) > filter_with_acceptance_rate.window_size: filter_with_acceptance_rate.acceptance_history.pop(0) # Remove oldest record window_rate = sum(filter_with_acceptance_rate.acceptance_history) / len(filter_with_acceptance_rate.acceptance_history) # Enhance precision for small batches - use smoother weight function batch_weight_factor = 1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 30.0) # Exponential smooth transition # Dynamically adjust error weights: rely more on global error for fewer batches, # more on sliding window error as batch count increases window_size = len(filter_with_acceptance_rate.acceptance_history) window_significance = min(window_size / 100.0, 0.9) # Window significance depends on historical data volume window_weight = window_significance * batch_weight_factor global_weight = 1.0 - window_weight # Consider both global error and window error combined_error = (global_weight * (global_rate - fixed_acceptance_rate) + window_weight * (window_rate - fixed_acceptance_rate)) # Update precision adjustment factor - use PID controller style adjustment if filter_with_acceptance_rate.total_batches > 50: # Only perform precision adjustment when there's enough data current_error = global_rate - fixed_acceptance_rate # Save recent adjustment history filter_with_acceptance_rate.recent_adjustments.append(current_error) if len(filter_with_acceptance_rate.recent_adjustments) > 20: # Keep recent 20 errors filter_with_acceptance_rate.recent_adjustments.pop(0) # PID controller parameters kp = 0.05 # Proportional coefficient ki = 0.001 # Integral coefficient kd = 0.01 # Derivative coefficient # Proportional term - current error p_term = current_error # Integral term - accumulated error i_term = sum(filter_with_acceptance_rate.recent_adjustments) # Derivative term - error change rate d_term = 0.0 if len(filter_with_acceptance_rate.recent_adjustments) >= 2: d_term = filter_with_acceptance_rate.recent_adjustments[-1] - filter_with_acceptance_rate.recent_adjustments[-2] # Calculate PID adjustment pid_adjustment = kp * p_term + ki * i_term + kd * d_term # Update precision adjustment factor filter_with_acceptance_rate.precision_adjustment = pid_adjustment # Limit adjustment factor range to prevent over-adjustment max_adjustment = 0.02 + 0.03 * (1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 500.0)) filter_with_acceptance_rate.precision_adjustment = max(-max_adjustment, min(max_adjustment, filter_with_acceptance_rate.precision_adjustment)) # Calculate acceptance rate correction factor error_magnitude = abs(combined_error) correction_factor = 1.0 # Use more refined error correction logic - use smooth nonlinear correction function if error_magnitude > 0.0005: # Correct even smaller errors # Use smooth correction function instead of piecewise function base_strength = 2.0 error_scale = 1.0 - math.exp(-error_magnitude * 50.0) # Exponential decay mapping to [0,1] correction_strength = base_strength + error_scale * 1.5 # Range from 2.0 to 3.5 # Smooth correction sign = 1 if combined_error > 0 else -1 correction_factor = 1.0 + (correction_strength * error_magnitude * sign) # Handle boundary cases to avoid division by zero if correction_factor == 0.0: correction_factor = 1.0 # Apply correction factor adjusted_rate = max(0.0, min(1.0, fixed_acceptance_rate * (1.0 / correction_factor))) # Apply precision adjustment factor adjusted_rate = max(0.0, min(1.0, adjusted_rate - filter_with_acceptance_rate.precision_adjustment)) # More precise boundary case handling if fixed_acceptance_rate > 0 and fixed_acceptance_rate < 0.05: if filter_with_acceptance_rate.total_batches % int(1/fixed_acceptance_rate) == 0: adjusted_rate = 1.0 # Periodically force accept to ensure accuracy in low acceptance rate scenarios # If fixed_acceptance_rate is 0, directly reject elif fixed_acceptance_rate == 0.0: adjusted_rate = 0.0 # If fixed_acceptance_rate is 1, directly accept elif fixed_acceptance_rate == 1.0: adjusted_rate = 1.0 # Make precise adjustments for cases with large remaining errors target_accepted = int(filter_with_acceptance_rate.total_batches * fixed_acceptance_rate + 0.5) # Round to nearest actual_accepted = filter_with_acceptance_rate.accepted_batches acceptance_gap = target_accepted - actual_accepted # More aggressive gap adjustment strategy - use adaptive threshold and smooth adjustment gap_relative = abs(acceptance_gap) / max(1, filter_with_acceptance_rate.total_batches) gap_threshold = max(1, int(filter_with_acceptance_rate.total_batches * 0.005)) # Smaller dynamic threshold, at least 1 # Dynamically adjust acceptance rate based on the gap if abs(acceptance_gap) >= gap_threshold: # Use dynamic threshold # Use smooth adjustment strategy if acceptance_gap > 0: # Need to accept more # Use exponential function for smooth adjustment gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1] # Adjustment strength decreases as total batch count increases strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0) boost_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count adjusted_rate = min(1.0, adjusted_rate + (1.0 - adjusted_rate) * boost_factor) else: # Accepted too many, need to reject # Use exponential function for smooth adjustment gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1] # Adjustment strength decreases as total batch count increases strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0) reduction_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count adjusted_rate = max(0.0, adjusted_rate * (1.0 - reduction_factor)) # Add small random perturbation in fixed intervals to enhance convergence if 0.01 < adjusted_rate < 0.99 and filter_with_acceptance_rate.total_batches > 100: # Random perturbation amplitude decreases as batch count increases noise_amplitude = 0.01 * math.exp(-filter_with_acceptance_rate.total_batches / 500.0) noise = (torch.rand(1, device=device).item() * 2 - 1) * noise_amplitude adjusted_rate = max(0.0, min(1.0, adjusted_rate + noise)) # Generate a random number to decide whether to accept the current batch random_value = torch.rand(1, device=device).item() accept_batch = random_value < adjusted_rate # Set some tokens to PLACEHOLDER_TOKEN_ID to achieve specified acceptance rate # Support max_spec_len > 1 cases if accept_batch: # Accept batch - don't modify token_ids filter_with_acceptance_rate.accepted_batches += 1 filter_with_acceptance_rate.acceptance_history[-1] = 1 # Update the most recent acceptance status else: # Reject batch - set all speculative tokens (except first column) to PLACEHOLDER_TOKEN_ID # This ensures token-level acceptance rate matches batch-level acceptance rate output_token_ids[:, 1:] = PLACEHOLDER_TOKEN_ID # Note: acceptance rate calculation is still based on entire batch accept/reject, no modification needed # But we can add a comment explaining how actual token-level acceptance rate is calculated # Actual token-level acceptance rate = 1 - (number of PLACEHOLDER_TOKEN_ID in output_token_ids / max_spec_len) # Update cumulative error - use exponential moving average for smoother error adjustment actual_rate = filter_with_acceptance_rate.accepted_batches / filter_with_acceptance_rate.total_batches # Use EMA to smooth error updates - use adaptive EMA coefficient alpha = 0.05 * math.exp(-filter_with_acceptance_rate.total_batches / 200.0) + 0.01 # EMA coefficient gradually decreases over time filter_with_acceptance_rate.cumulative_error = (alpha * (actual_rate - fixed_acceptance_rate) + (1 - alpha) * filter_with_acceptance_rate.cumulative_error) return output_token_ids """ ============================= End of MLU Hijack ============================= """ def vllm__v1__sample__rejection_sampler__rejection_sample( # [num_tokens] draft_token_ids: torch.Tensor, # [batch_size] num_draft_tokens: list[int], max_spec_len: int, # [batch_size] cu_num_draft_tokens: torch.Tensor, # [num_tokens, vocab_size] draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] target_probs: torch.Tensor, # [batch_size, 1] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: assert draft_token_ids.ndim == 1 assert draft_probs is None or draft_probs.ndim == 2 assert cu_num_draft_tokens.ndim == 1 assert target_probs.ndim == 2 batch_size = len(num_draft_tokens) num_tokens = draft_token_ids.shape[0] vocab_size = target_probs.shape[-1] device = target_probs.device assert draft_token_ids.is_contiguous() assert draft_probs is None or draft_probs.is_contiguous() assert target_probs.is_contiguous() assert bonus_token_ids.is_contiguous() assert target_probs.shape == (num_tokens, vocab_size) ''' ============================= Modify by vllm_mlu ============================= @brief: use tmo rejection_sample for all random sampling requests ''' fixed_acceptance_rate = VLLM_MTP_FIXED_ACCEPTANCE_RATE use_fusion_kernel = (sampling_metadata.all_random and max_spec_len == 1 and (num_draft_tokens is not None and 0 not in num_draft_tokens)) if use_fusion_kernel: # All data is random, use tmo rejection_sample # Generate uniform probabilities for rejection sampling. # [num_tokens] uniform_rand = vllm__v1__sample__rejection_sampler__generate_uniform_probs( num_tokens, num_draft_tokens, sampling_metadata.generators, device, ) # generate random probs for recovered tokens uniform_probs = generate_recovered_uniform_probs( num_tokens, vocab_size, num_draft_tokens, sampling_metadata, device, ) # num_draft_tokens need to be a tensor num_draft_tokens_tensor = torch.tensor(num_draft_tokens, dtype=torch.int32, device=device) # tmo rejection_sample dtype need to be int32 bonus_token_ids = bonus_token_ids.to(torch.int32) draft_token_ids = draft_token_ids.to(torch.int32) # use tmo rejection_sample output_token_ids = mlu_ops.rejection_sample( draft_token_ids, num_draft_tokens_tensor, cu_num_draft_tokens, draft_probs, target_probs, bonus_token_ids, uniform_rand, uniform_probs, max_spec_len, high_acc=True # for now, only support high_acc ).view(batch_size, max_spec_len + 1) if fixed_acceptance_rate is not None: # set all speculative tokens to placeholder token output_token_ids[:, 1:] = 0 output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) return output_token_ids ''' ============================= End of MLU Hijack ============================= ''' # Create output buffer. output_token_ids = torch.full( (batch_size, max_spec_len + 1), PLACEHOLDER_TOKEN_ID, dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. device=device, ) if sampling_metadata.all_greedy: is_greedy = None else: is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel[(batch_size, )]( output_token_ids, cu_num_draft_tokens, draft_token_ids, target_argmax, bonus_token_ids, is_greedy, max_spec_len, has_acceptance_rate=fixed_acceptance_rate is not None, ) if sampling_metadata.all_greedy: output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) return output_token_ids # Generate uniform probabilities for rejection sampling. # [num_tokens] uniform_probs = vllm__v1__sample__rejection_sampler__generate_uniform_probs( num_tokens, num_draft_tokens, sampling_metadata.generators, device, ) # Sample recovered tokens for each position. # [num_tokens] recovered_token_ids = sample_recovered_tokens( max_spec_len, num_draft_tokens, cu_num_draft_tokens, draft_token_ids, draft_probs, target_probs, sampling_metadata, device, ) ''' ============================= Modify by vllm_mlu ============================= @brief: Add fixed acceptance rate check ''' # Rejection sampling for random sampling requests. vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel[(batch_size, )]( output_token_ids, cu_num_draft_tokens, draft_token_ids, draft_probs, target_probs, bonus_token_ids, recovered_token_ids, uniform_probs, is_greedy, max_spec_len, vocab_size, NO_DRAFT_PROBS=draft_probs is None, has_acceptance_rate=fixed_acceptance_rate is not None, ) output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate) ''' ================== End of MLU Hijack ================== ''' return output_token_ids # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] draft_token_ids_ptr, # [num_tokens] draft_probs_ptr, # [num_tokens, vocab_size] or None target_probs_ptr, # [num_tokens, vocab_size] bonus_token_ids_ptr, # [batch_size] recovered_token_ids_ptr, # [num_tokens] uniform_probs_ptr, # [num_tokens] is_greedy_ptr, # [batch_size] max_spec_len, vocab_size, NO_DRAFT_PROBS: tl.constexpr, has_acceptance_rate: tl.constexpr, ): req_idx = tl.program_id(0) is_greedy = tl.load(is_greedy_ptr + req_idx) if is_greedy: # Early exit for greedy sampling requests. return start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx rejected = False for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) if NO_DRAFT_PROBS: draft_prob = 1 else: draft_prob = tl.load( draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id ) target_prob = tl.load( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id ) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) ''' ============================= Modify by vllm_mlu ============================= @brief: add accept rate check, always accept if has_acceptance_rate is True ''' # NOTE(woosuk): While the draft probability should never be 0, # we check it to avoid NaNs. If it happens to be 0, we reject. if draft_prob > 0 and target_prob / draft_prob >= uniform_prob or has_acceptance_rate: # Accept. token_id = draft_token_id else: # Reject. Use recovered token. rejected = True token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) ''' ============================= End of MLU Hijack ============================= ''' tl.store( output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id ) ''' ============================= Modify by vllm_mlu ============================= @brief: Check whether to accept bonus token through acceptance_rate_ptr ''' # If has acceptance rate, all tokens are accepted if has_acceptance_rate: rejected = False if not rejected: # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id, ) ''' ================== End of MLU Hijack ================== ''' # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] draft_token_ids_ptr, # [num_tokens] target_argmax_ptr, # [num_tokens] bonus_token_ids_ptr, # [batch_size] is_greedy_ptr, # [batch_size] or None max_spec_len, has_acceptance_rate: tl.constexpr, ): req_idx = tl.program_id(0) # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, # re-compilation may happen during runtime when is_greedy_ptr is None. is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx) if not is_greedy: # Early exit for non-greedy sampling requests. return start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx rejected = False for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) tl.store( output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, target_argmax_id, ) if draft_token_id != target_argmax_id: # Reject. rejected = True if has_acceptance_rate: rejected = False if not rejected: # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id, ) def vllm__v1__sample__rejection_sampler__generate_uniform_probs( num_tokens: int, num_draft_tokens: list[int], generators: dict[int, torch.Generator], device: torch.device, ) -> torch.Tensor: """ Generates a batch of uniform random samples, with optional seeding if available. This method creates a tensor of shape `(num_tokens, )` filled with uniform random values in the range [0, 1). If `generators` is provided, the requests with their own seeds will use the provided `torch.Generator` for reproducibility. The samples for the other requests will be generated without a seed. Args: num_tokens: int Total number of tokens. num_draft_tokens: List[List[int]] Number of draft tokens per request. generators: Optional[Dict[int, torch.Generator]] A dictionary mapping indices in the batch to `torch.Generator` objects. device: torch.device The device on which to allocate the tensor. Returns: uniform_rand: torch.Tensor A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ # NOTE(woosuk): We deliberately use float64 instead of float32 here # because when using float32, there's a non-negligible chance that # uniform_prob is sampled to be exact 0.0 as reported in # https://github.com/pytorch/pytorch/issues/16706. Using float64 # mitigates the issue. ''' ============================= Modify by vllm_mlu ============================= @brief: Changed torch.float64 to torch.float32 ''' uniform_probs = torch.rand( (num_tokens,), dtype=torch.float32, device=device, ) ''' ================== End of MLU Hijack ================== ''' start_idx = 0 for req_idx, n in enumerate(num_draft_tokens): # Do not generate random numbers for requests with no draft tokens. # This can be important for reproducibility. if n == 0: continue end_idx = start_idx + n generator = generators.get(req_idx) if generator is not None: uniform_probs[start_idx:end_idx].uniform_(generator=generator) start_idx = end_idx return uniform_probs MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.generate_uniform_probs, vllm__v1__sample__rejection_sampler__generate_uniform_probs) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.expand_batch_to_tokens, vllm__v1__sample__rejection_sampler__expand_batch_to_tokens) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.expand_kernel, vllm__v1__sample__rejection_sampler__expand_kernel) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.sample_recovered_tokens_kernel, vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.rejection_sample, vllm__v1__sample__rejection_sampler__rejection_sample) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.rejection_random_sample_kernel, vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel) MluHijackObject.apply_hijack(rejection_sampler, rejection_sampler.rejection_greedy_sample_kernel, vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel)