from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
    SIMULATE_ACC_LEN,
    generate_simulated_accept_index,
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker
    from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
        EAGLEDraftCudaGraphRunner,
    )
    from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput

if is_cuda():
    from sgl_kernel import (
        top_k_renorm_prob,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
        verify_tree_greedy,
    )

    # NOTE(review): this shadows the `fast_topk` imported from
    # sglang.srt.utils.common above on CUDA builds — presumably intentional
    # (kernel-backed implementation); confirm both share a signature.
    from sgl_kernel.top_k import fast_topk
elif is_hip():
    from sgl_kernel import verify_tree_greedy


@triton.jit
def assign_draft_cache_locs_page_size_1(
    req_pool_indices,
    req_to_token,
    seq_lens,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
):
    """Copy draft-token KV-cache slots out of the request-to-token pool.

    One Triton program per request (pid == request index in the batch).
    For request `pid`, copies `topk * speculative_num_steps` token-pool
    entries starting at its current sequence length (`seq_lens[pid]`) into
    the contiguous region of `out_cache_loc` reserved for that request.
    As the name says, this assumes page size 1 (one slot per token).
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)
    copy_len = topk * speculative_num_steps
    # Each request owns a contiguous `topk * speculative_num_steps` slice
    # of out_cache_loc.
    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

    # Copy from req_to_token to out_cache_loc
    kv_start = tl.load(seq_lens + pid)
    # Base of this request's row in the (num_reqs, pool_len) req_to_token table.
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(token_pool + kv_start + copy_offset, mask=mask)
        tl.store(out_cache_ptr + copy_offset, data, mask=mask)


@dataclass
class EagleDraftInputV2Mixin:
    """Methods mixed into ``EagleDraftInput`` for the v2 draft code path."""

    def prepare_for_v2_draft(
        self: EagleDraftInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        cuda_graph_runner: EAGLEDraftCudaGraphRunner,
        draft_model_runner: ModelRunner,
        topk: int,
        num_steps: int,
    ):
        """Prepare `batch` for a v2 draft forward pass.

        Mutates `batch` in place (out_cache_loc, capture_hidden_mode) and
        sets `self.positions`. Returns ``(forward_batch, can_cuda_graph)``
        where `can_cuda_graph` is truthy iff a graph runner exists and can
        run the constructed forward batch.
        """
        bs = len(batch.seq_lens)

        # Assign cache locations: one slot per (request, topk, step).
        batch.out_cache_loc = torch.empty(
            (bs * topk * num_steps,),
            dtype=torch.int64,
            device=batch.input_ids.device,
        )

        # FIXME(lsyin): align with the default code path
        assign_draft_cache_locs_page_size_1[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            topk,
            num_steps,
        )

        # Get a forward batch. Each request contributes `topk` positions,
        # all at its current sequence length.
        batch.capture_hidden_mode = CaptureHiddenMode.LAST
        self.positions = batch.seq_lens.repeat_interleave(topk, dim=0)
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
        return forward_batch, can_cuda_graph

    def prepare_for_extend_to_fill_draft_kvcache(
        self,
        batch: ModelWorkerBatch,
        predict: torch.Tensor,
        num_draft_tokens: int,
        draft_model_runner: Any,
    ) -> ForwardBatch:
        """Build a DRAFT_EXTEND_V2 forward batch that appends `num_draft_tokens`
        predicted tokens per request, so the draft model fills their KV cache.

        Mutates `batch` extensively in place (input_ids, seq_lens, extend_*
        fields, forward_mode). The pre-mutation `seq_lens_cpu` becomes the
        extend prefix lengths. Also runs the attention backend's metadata
        initialization before returning.
        """
        seq_lens_cpu_backup = batch.seq_lens_cpu
        extend_num_tokens = len(batch.seq_lens) * num_draft_tokens

        batch.spec_info = self
        batch.input_ids = predict
        batch.seq_lens = batch.seq_lens + num_draft_tokens
        batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
        batch.seq_lens_sum += extend_num_tokens
        batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
        # Prefix = the sequence lengths before the draft tokens were appended.
        batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
        batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
        batch.extend_num_tokens = extend_num_tokens
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2

        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
        return forward_batch


@dataclass
class EagleVerifyInputV2Mixin:
    """Methods mixed into ``EagleVerifyInput`` for the v2 verify code path."""

    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        target_worker: TpModelWorker,
    ):
        """Prepare `batch` for a TARGET_VERIFY forward pass on the target model.

        Assigns `draft_token_num` cache slots per request, then either
        prepares a CUDA-graph replay or initializes attention-backend
        metadata. Returns ``(verify_forward_batch, can_run_cuda_graph)``.
        """
        # Assign cache locations for the draft tokens being verified.
        bs = len(batch.req_pool_indices)
        batch.input_ids = self.draft_token
        device = batch.input_ids.device
        batch.out_cache_loc = torch.empty(
            (bs * self.draft_token_num,),
            dtype=torch.int64,
            device=device,
        )
        assign_extend_cache_locs[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.draft_token_num,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner)

        # Run attention backend plan and cuda graph preparation
        can_run_cuda_graph = bool(
            target_worker.model_runner.graph_runner
            and target_worker.model_runner.graph_runner.can_run(verify_forward_batch)
        )
        if can_run_cuda_graph:
            target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch)
        else:
            target_worker.model_runner.attn_backend.init_forward_metadata(
                verify_forward_batch
            )
        return verify_forward_batch, can_run_cuda_graph

    def sample(
        self: EagleVerifyInput,
        batch: ModelWorkerBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        Greedy path uses `verify_tree_greedy`; otherwise applies
        temperature/top-k/top-p renormalization and runs tree rejection
        sampling. Returns ``(predict, accept_length, accept_index)`` where
        `accept_length` already includes the bonus token (the trailing
        ``add_(1)``).
        """
        bs = len(batch.seq_lens)
        sampling_info = batch.sampling_info
        next_token_logits = logits_output.next_token_logits
        device = batch.input_ids.device

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        # Output buffers; the sgl_kernel calls below write into these.
        predict = torch.zeros(
            (bs * (self.spec_steps + 1),), dtype=torch.int32, device=device
        )
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

        # Sample tokens
        if sampling_info.is_all_greedy:
            target_predict = torch.argmax(next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)
            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # Apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * num_draft_tokens, 1)
            target_probs = F.softmax(
                next_token_logits / expanded_temperature, dim=-1
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_p_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ps, self.draft_token_num, dim=0
                ),
            )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            # This is currently not used
            draft_probs = torch.empty_like(target_probs)

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device=device)
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device=device
            )
            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=global_server_args_dict[
                    "speculative_accept_threshold_single"
                ],
                threshold_acc=global_server_args_dict[
                    "speculative_accept_threshold_acc"
                ],
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0:
            # Do simulation (debug/benchmark path: fake the accept length).
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                simulate_acc_len=SIMULATE_ACC_LEN,
                bs=bs,
                # NOTE(review): `draft_token_num` is passed as `spec_steps`
                # here, while the buffers above are sized with
                # `self.spec_steps + 1` — confirm this is intentional.
                spec_steps=self.draft_token_num,
            )

        # Include the bonus token
        accept_length.add_(1)
        return predict, accept_length, accept_index


def build_tree_kernel_efficient_tmp(
    verified_id: torch.Tensor,
    parent_list: List[torch.Tensor],
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    """Build the verification tree structures via `sgl_kernel`.

    Allocates (or reuses, when `tree_mask_buf`/`position_buf` are given)
    the tree mask and position buffers whose layout depends on
    `tree_mask_mode`, then calls `build_tree_kernel_efficient` to fill
    the mask, positions, and the three `retrive_*` tensors.

    Returns ``(tree_mask, positions, retrive_index, retrive_next_token,
    retrive_next_sibling, draft_tokens)`` where `draft_tokens` has the
    verified root token prepended per request and is flattened.
    """
    # TODO(lsyin): make it compatible with default code path
    # TODO(lsyin): support cuda graph graph padding for eagle
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device

    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        # Caller-provided buffer (e.g. for CUDA-graph reuse): reset it to the
        # mode's neutral fill value instead of allocating.
        tree_mask = tree_mask_buf
        if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        elif tree_mask_mode == TreeMaskMode.FULL_MASK:
            tree_mask.fill_(True)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        # Pick the narrowest unsigned dtype able to hold num_verify_tokens bits.
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf

    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )

    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )


@torch.compile(dynamic=True)
def select_top_k_tokens_tmp(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the next top-k draft tokens at step `i` of tree drafting.

    Step 0 (right after extend) fans each request out to `topk` branches;
    later steps score all `topk * topk` expansions, keep the global top-k
    per request, and gather the matching input ids and hidden states.
    Returns ``(input_ids, hidden_states, scores, tree_info)``; shapes of
    the `tree_info` entries are noted inline.
    """
    # FIXME(lsyin): remove this duplicate code
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)
        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            # -1 marks the root's parent slot.
            torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        # Map each selected flat index back to its parent branch's row in
        # the (b * topk) hidden_states.
        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
            0, hidden_states.shape[0], step=topk, device=hidden_states.device
        ).repeat_interleave(topk)
        hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            # Offset into the global tree-node numbering for step i.
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info


@triton.jit
def fill_new_verified_id(
    verified_id,
    accept_lens,
    new_verified_id,
    num_draft_tokens: tl.constexpr,
):
    """Per request, store its last accepted token into `new_verified_id`.

    One program per request; reads `accept_lens[pid]` and picks
    `verified_id[num_draft_tokens * pid + accept_lens[pid] - 1]`.
    """
    # NOTE: we cannot fuse any in-place operations of `accept_lens` inside this kernel
    # because this kernel reads accept_lens
    pid = tl.program_id(axis=0)
    accept_length = tl.load(accept_lens + pid)

    verified_id_idx = num_draft_tokens * pid + accept_length - 1
    verified_id_data = tl.load(verified_id + verified_id_idx)
    tl.store(new_verified_id + pid, verified_id_data)


@triton.jit
def fill_accepted_out_cache_loc(
    accept_index,
    out_cache_loc,
    accepted_out_cache_loc,
    size_upper: tl.constexpr,
):
    """Compact `out_cache_loc` down to the accepted entries.

    One program per position in `accept_index` (-1 means "not accepted").
    The destination slot is the count of accepted entries before `pid`
    (a prefix count computed via a masked load + sum), so accepted cache
    locations end up densely packed in `accepted_out_cache_loc`.
    """
    pid = tl.program_id(axis=0)
    offset = tl.arange(0, size_upper)
    # Count accepted (!= -1) entries strictly before this position.
    masks = (tl.load(accept_index + offset, offset < pid, other=-1) != -1).to(tl.int64)
    dst = tl.sum(masks)
    src = tl.load(accept_index + pid)

    if src > -1:
        value = tl.load(out_cache_loc + src)
        tl.store(accepted_out_cache_loc + dst, value)


@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Pack each request's [start, end) token-pool slots into `out_cache_loc`.

    One program per request. The output offset for request `pid` is the sum
    of (end - start) over all preceding requests, so the per-request slices
    are laid out contiguously. `bs_upper` must be >= batch size (power of 2
    at the call site) so the masked prefix loads cover every earlier request.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    # Base of this request's row in the (num_reqs, pool_len) req_to_token table.
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Prefix-sum of the lengths of all requests before this one.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    load_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    save_offset = tl.arange(0, BLOCK_SIZE)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = load_offset < kv_end
        data = tl.load(token_pool + load_offset, mask=mask)
        tl.store(out_cache_ptr + save_offset, data, mask=mask)
        load_offset += BLOCK_SIZE
        save_offset += BLOCK_SIZE