diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 789cd12db..264d89bb9 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) return req_pool_indices - def allocate_for_eagle_v2(self): - from sglang.srt.speculative.eagle_info import EagleDraftInput - from sglang.srt.speculative.spec_utils import assign_req_to_token_pool - - bs = self.batch_size() - - assert self.spec_info.is_draft_input() - draft_input: EagleDraftInput = self.spec_info - - # FIXME(lsyin): now implementation does not enable over-allocation - # Now seq_lens and allocate_lens are correct - self.maybe_wait_verify_done() - - new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE - num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item() - out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens) - - assign_req_to_token_pool[(bs,)]( - self.req_pool_indices, - self.req_to_token_pool.req_to_token, - draft_input.allocate_lens, - new_allocate_lens, - out_cache_loc, - self.req_to_token_pool.req_to_token.shape[1], - next_power_of_2(bs), - ) - draft_input.allocate_lens = new_allocate_lens - - # FIXME(lsyin): remove seq_lens_sum calculation - self.seq_lens_cpu = self.seq_lens.cpu() - self.seq_lens_sum = self.seq_lens_cpu.sum().item() - def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): self.encoder_lens_cpu = [] self.encoder_cached = [] @@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): bs = len(self.reqs) if self.is_v2_eagle: - # FIXME(lsyin): make this sync optional - self.allocate_for_eagle_v2() + # TODO(spec-v2): all v2 spec should go through this path + from sglang.srt.speculative.eagle_info import EagleDraftInput + + draft_input: EagleDraftInput = self.spec_info + draft_input.prepare_for_decode(self) if not self.spec_algorithm.is_none(): # if spec decoding is used, the decode batch is prepared inside diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cec6af433..682a23586 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -215,10 +215,10 @@ class GenerationBatchResult: delay_sample_func: Optional[callable] = None future_indices: Optional[FutureIndices] = None - # FIXME(lsyin): maybe move to ? + # FIXME(lsyin): maybe move to a better place? 
# sync path: forward stream -> output processor accept_lens: Optional[torch.Tensor] = None - last_batch_allocate_lens: Optional[torch.Tensor] = None + allocate_lens: Optional[torch.Tensor] = None # relay path: forward stream -> next step forward next_draft_input: Optional[EagleDraftInput] = None @@ -246,10 +246,8 @@ class GenerationBatchResult: if self.accept_lens is not None: self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) - if self.last_batch_allocate_lens is not None: - self.last_batch_allocate_lens = self.last_batch_allocate_lens.to( - "cpu", non_blocking=True - ) + if self.allocate_lens is not None: + self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True) self.copy_done.record() diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index bde62957e..64d34bf03 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin: skip_stream_req = None if self.is_generation: + if result.copy_done is not None: + result.copy_done.synchronize() + ( logits_output, next_token_ids, extend_input_len_per_req, extend_logprob_start_len_per_req, - copy_done, ) = ( result.logits_output, result.next_token_ids, result.extend_input_len_per_req, result.extend_logprob_start_len_per_req, - result.copy_done, ) - if copy_done is not None: - copy_done.synchronize() - # Move next_token_ids and logprobs to cpu next_token_ids = next_token_ids.tolist() if batch.return_logprob: @@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin: self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) - def hacky_process_eagle_overlap_result( + def _resolve_spec_overlap_token_ids( self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch - ): - # TODO(lsyin): try use a copy stream to share SMs with forward - # FIXME(lsyin): better organize this token free logic in eagle-overlap - last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist() - accept_lens_cpu = result.accept_lens.tolist() + ) -> List[List[int]]: + """Resolve the padding next token ids for speculative decoding with overlap.""" + assert result.next_token_ids.is_cpu + assert result.accept_lens.is_cpu + assert result.allocate_lens.is_cpu + next_token_ids = result.next_token_ids.tolist() + accept_lens = result.accept_lens.tolist() + result.num_accepted_tokens = sum(accept_lens) predict_tokens = [] - num_draft_tokens = self.draft_worker.speculative_num_draft_tokens + stride = self.draft_worker.speculative_num_draft_tokens for i, req in enumerate(batch.reqs): predict_tokens.append( - next_token_ids[ - i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i] - ] + next_token_ids[i * stride : i * stride + accept_lens[i]] ) - # FIXME(lsyin): move this update elsewhere req.spec_verify_ct += 1 - return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens + return predict_tokens def process_batch_result_decode( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult, ): - logits_output, next_token_ids, can_run_cuda_graph, copy_done = ( + if result.copy_done is not None: + result.copy_done.synchronize() + + logits_output, next_token_ids, can_run_cuda_graph = ( result.logits_output, result.next_token_ids, result.can_run_cuda_graph, - result.copy_done, ) - self.num_generated_tokens += len(batch.reqs) - - if copy_done is not None: - copy_done.synchronize() if 
batch.spec_algorithm.is_none(): next_token_ids = next_token_ids.tolist() if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs.tolist() elif batch.is_v2_eagle: - ( - last_batch_allocate_lens_cpu, - accept_lens_cpu, - next_token_ids, - ) = self.hacky_process_eagle_overlap_result(result, batch) - result.num_accepted_tokens = sum(accept_lens_cpu) + next_token_ids = self._resolve_spec_overlap_token_ids(result, batch) + allocate_lens_list = result.allocate_lens.tolist() + accept_lens_list = result.accept_lens.tolist() - # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result + self.num_generated_tokens += len(batch.reqs) if not self.spec_algorithm.is_none(): self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens) @@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin: continue if self.enable_overlap and req.finished(): + indices_to_free = None if self.page_size == 1: if batch.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_worker_v2 import ( - free_spec_dec_tokens_page_size_1, - ) + from sglang.srt.speculative.eagle_info import EagleDraftInput - free_spec_dec_tokens_page_size_1( - self.req_to_token_pool, - self.token_to_kv_pool_allocator, - req, - last_batch_allocate_lens_cpu[i], - None, - ) + end_p = allocate_lens_list[i] + start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE + indices_to_free = self.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][start_p:end_p] else: # Free the one extra delayed token - self.token_to_kv_pool_allocator.free( - batch.out_cache_loc[i : i + 1] - ) + indices_to_free = batch.out_cache_loc[i : i + 1] else: if batch.spec_algorithm.is_eagle(): - # TODO(lsyin): support eagle with page_size > 1 + # TODO(spec-v2): support eagle with page_size > 1 raise NotImplementedError() else: if ( len(req.origin_input_ids) + len(req.output_ids) - 1 ) % self.page_size == 0: # Only free when the extra token is in a new page - self.token_to_kv_pool_allocator.free( - batch.out_cache_loc[i : i + 1] - ) + indices_to_free = batch.out_cache_loc[i : i + 1] + + if indices_to_free is not None: + self.token_to_kv_pool_allocator.free(indices_to_free) continue if batch.spec_algorithm.is_none(): req.output_ids.append(next_token_id) elif batch.is_v2_eagle: - # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding - # !!!unify the logic here!!! + # Only v2 eagle's output_ids are updated here. req.output_ids.extend(next_token_id) req.check_finished() @@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin: if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend(): # FIXME(lsyin): fix the messy logic here # 1) when not overlap (v2 impl), we free the extra tokens in the req - # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch - from sglang.srt.speculative.eagle_worker_v2 import ( - free_spec_dec_tokens_page_size_1, - ) - - new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1 - # FIXME(lsyin): remove this assert - assert new_seq_len == int( - batch.seq_lens_cpu[i] + accept_lens_cpu[i] - ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}" - - free_spec_dec_tokens_page_size_1( - self.req_to_token_pool, - self.token_to_kv_pool_allocator, - req, - last_batch_allocate_lens_cpu[i], - new_seq_len, - ) + # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration. 
+ start_p = batch.seq_lens_cpu[i] + accept_lens_list[i] + end_p = allocate_lens_list[i] + indices_to_free = self.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][start_p:end_p] + self.token_to_kv_pool_allocator.free(indices_to_free) if self.server_args.disaggregation_decode_enable_offload_kvcache: # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 6546781de..0a623d4a2 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional import torch @@ -54,7 +55,140 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TpModelWorker: +class BaseTpWorker(ABC): + @abstractmethod + def forward_batch_generation(self, forward_batch: ForwardBatch): + pass + + @property + @abstractmethod + def model_runner(self) -> ModelRunner: + pass + + @property + def sliding_window_size(self) -> Optional[int]: + return self.model_runner.sliding_window_size + + @property + def is_hybrid(self) -> bool: + return self.model_runner.is_hybrid is not None + + def get_tokens_per_layer_info(self): + return ( + self.model_runner.full_max_total_num_tokens, + self.model_runner.swa_max_total_num_tokens, + ) + + def get_pad_input_ids_func(self): + return getattr(self.model_runner.model, "pad_input_ids", None) + + def get_tp_group(self): + return self.model_runner.tp_group + + def get_attention_tp_group(self): + return self.model_runner.attention_tp_group + + def get_attention_tp_cpu_group(self): + return getattr(self.model_runner.attention_tp_group, "cpu_group", None) + + def get_memory_pool(self): + return ( + self.model_runner.req_to_token_pool, + self.model_runner.token_to_kv_pool_allocator, + ) + + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + success, message = self.model_runner.update_weights_from_disk( + recv_req.model_path, recv_req.load_format + ) + return success, message + + def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): + success, message = self.model_runner.init_weights_update_group( + recv_req.master_address, + recv_req.master_port, + recv_req.rank_offset, + recv_req.world_size, + recv_req.group_name, + recv_req.backend, + ) + return success, message + + def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput): + success, message = self.model_runner.destroy_weights_update_group( + recv_req.group_name, + ) + return success, message + + def init_weights_send_group_for_remote_instance( + self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput + ): + success, message = ( + self.model_runner.init_weights_send_group_for_remote_instance( + recv_req.master_address, + recv_req.ports, + recv_req.group_rank, + recv_req.world_size, + recv_req.group_name, + recv_req.backend, + ) + ) + return success, message + + def send_weights_to_remote_instance( + self, recv_req: SendWeightsToRemoteInstanceReqInput + ): + success, message = self.model_runner.send_weights_to_remote_instance( + recv_req.master_address, + recv_req.ports, + recv_req.group_name, + ) + return success, message + + def update_weights_from_distributed( + self, recv_req: UpdateWeightsFromDistributedReqInput + ): + success, message = self.model_runner.update_weights_from_distributed( + recv_req.names, recv_req.dtypes, 
recv_req.shapes, recv_req.group_name + ) + return success, message + + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + + monkey_patch_torch_reductions() + success, message = self.model_runner.update_weights_from_tensor( + named_tensors=MultiprocessingSerializer.deserialize( + recv_req.serialized_named_tensors[self.tp_rank] + ), + load_format=recv_req.load_format, + ) + return success, message + + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): + parameter = self.model_runner.get_weights_by_name( + recv_req.name, recv_req.truncate_size + ) + return parameter + + def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): + result = self.model_runner.load_lora_adapter(recv_req.to_ref()) + return result + + def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): + result = self.model_runner.unload_lora_adapter(recv_req.to_ref()) + return result + + def can_run_lora_batch(self, lora_ids: list[str]) -> bool: + return self.model_runner.lora_manager.validate_lora_batch(lora_ids) + + def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output, _ = self.model_runner.forward(forward_batch) + embeddings = logits_output.embeddings + return embeddings + + +class TpModelWorker(BaseTpWorker): """A tensor parallel model worker.""" def __init__( @@ -92,7 +226,7 @@ class TpModelWorker: is_draft_model=is_draft_worker, ) - self.model_runner = ModelRunner( + self._model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, gpu_id=gpu_id, @@ -171,6 +305,10 @@ class TpModelWorker: self.enable_overlap = not server_args.disable_overlap_schedule self.hicache_layer_transfer_counter = None + @property + def model_runner(self) -> ModelRunner: + return self._model_runner + def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter): self.hicache_layer_transfer_counter = counter @@ -193,38 +331,6 @@ class TpModelWorker: self.model_runner.token_to_kv_pool.size, ) - @property - def sliding_window_size(self) -> Optional[int]: - return self.model_runner.sliding_window_size - - @property - def is_hybrid(self) -> bool: - return self.model_runner.is_hybrid is not None - - def get_tokens_per_layer_info(self): - return ( - self.model_runner.full_max_total_num_tokens, - self.model_runner.swa_max_total_num_tokens, - ) - - def get_pad_input_ids_func(self): - return getattr(self.model_runner.model, "pad_input_ids", None) - - def get_tp_group(self): - return self.model_runner.tp_group - - def get_attention_tp_group(self): - return self.model_runner.attention_tp_group - - def get_attention_tp_cpu_group(self): - return getattr(self.model_runner.attention_tp_group, "cpu_group", None) - - def get_memory_pool(self): - return ( - self.model_runner.req_to_token_pool, - self.model_runner.token_to_kv_pool_allocator, - ) - def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, @@ -313,93 +419,3 @@ class TpModelWorker: pp_hidden_states_proxy_tensors=pp_proxy_tensors, can_run_cuda_graph=can_run_cuda_graph, ) - - def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output, _ = self.model_runner.forward(forward_batch) - embeddings = logits_output.embeddings - return embeddings - - def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): - success, message = 
self.model_runner.update_weights_from_disk( - recv_req.model_path, recv_req.load_format - ) - return success, message - - def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): - success, message = self.model_runner.init_weights_update_group( - recv_req.master_address, - recv_req.master_port, - recv_req.rank_offset, - recv_req.world_size, - recv_req.group_name, - recv_req.backend, - ) - return success, message - - def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput): - success, message = self.model_runner.destroy_weights_update_group( - recv_req.group_name, - ) - return success, message - - def init_weights_send_group_for_remote_instance( - self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput - ): - success, message = ( - self.model_runner.init_weights_send_group_for_remote_instance( - recv_req.master_address, - recv_req.ports, - recv_req.group_rank, - recv_req.world_size, - recv_req.group_name, - recv_req.backend, - ) - ) - return success, message - - def send_weights_to_remote_instance( - self, recv_req: SendWeightsToRemoteInstanceReqInput - ): - success, message = self.model_runner.send_weights_to_remote_instance( - recv_req.master_address, - recv_req.ports, - recv_req.group_name, - ) - return success, message - - def update_weights_from_distributed( - self, recv_req: UpdateWeightsFromDistributedReqInput - ): - success, message = self.model_runner.update_weights_from_distributed( - recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name - ) - return success, message - - def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): - - monkey_patch_torch_reductions() - success, message = self.model_runner.update_weights_from_tensor( - named_tensors=MultiprocessingSerializer.deserialize( - recv_req.serialized_named_tensors[self.tp_rank] - ), - load_format=recv_req.load_format, - ) - return success, message - - def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): - parameter = self.model_runner.get_weights_by_name( - recv_req.name, recv_req.truncate_size - ) - return parameter - - def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): - result = self.model_runner.load_lora_adapter(recv_req.to_ref()) - return result - - def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): - result = self.model_runner.unload_lora_adapter(recv_req.to_ref()) - return result - - def can_run_lora_batch(self, lora_ids: list[str]) -> bool: - return self.model_runner.lora_manager.validate_lora_batch(lora_ids) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d24ce8ae3..90635c776 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -53,7 +53,6 @@ from sglang.srt.utils import ( empty_context, get_available_gpu_memory, get_bool_env_var, - get_device_memory_capacity, is_hip, log_info_on_rank0, require_attn_tp_gather, @@ -274,7 +273,6 @@ class CudaGraphRunner: self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 self.seq_lens_cpu = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 diff --git a/python/sglang/srt/speculative/base_spec_worker.py b/python/sglang/srt/speculative/base_spec_worker.py new file mode 100644 index 000000000..c77d9b86b --- /dev/null +++ 
b/python/sglang/srt/speculative/base_spec_worker.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sglang.srt.managers.tp_worker import TpModelWorker + + +class BaseDraftWorker(ABC): + @abstractmethod + def draft(): + pass + + @abstractmethod + def draft_extend(): + pass + + +class BaseSpecWorker(ABC): + @property + @abstractmethod + def target_worker(self) -> TpModelWorker: + pass + + @property + @abstractmethod + def draft_worker(self) -> BaseDraftWorker: + pass diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 7a04b5c12..a2ce4614b 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner: def __init__(self, eagle_worker: EAGLEWorker): # Parse args self.eagle_worker = eagle_worker - self.model_runner = model_runner = eagle_worker.model_runner + if not hasattr(eagle_worker, "model_runner"): + # V2: EagleDraftWorker + self.model_runner = model_runner = eagle_worker.draft_runner + else: + self.model_runner = model_runner = eagle_worker.model_runner self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 72f182ed9..9612a8da2 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner: def __init__(self, eagle_worker: EAGLEWorker): # Parse args self.eagle_worker = eagle_worker - self.model_runner = model_runner = eagle_worker.model_runner + if not hasattr(eagle_worker, "model_runner"): + # V2: EagleDraftWorker + self.model_runner = model_runner = eagle_worker.draft_runner + else: + self.model_runner = model_runner = eagle_worker.model_runner + self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile @@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner: output_cache_loc_backup = forward_batch.out_cache_loc hidden_states_backup = forward_batch.spec_info.hidden_states - ret = self.eagle_worker.draft_model_runner.model.forward( + ret = self.model_runner.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 4373aa3f3..083814c91 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): @dataclass class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + # Constant: alloc length per decode step + ALLOC_LEN_PER_DECODE: ClassVar[int] = None + # The inputs for decode # shape: (b, topk) topk_p: torch.Tensor = None @@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): new_seq_lens: Optional[torch.Tensor] = None verify_done: Optional[torch.cuda.Event] = None - # FIXME(lsyin): remove this hack - ALLOC_LEN_PER_DECODE: ClassVar[int] = None - def __post_init__(self): super().__init__(SpecInputType.EAGLE_DRAFT) diff --git 
a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 339576965..6ba42f326 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -9,7 +9,8 @@ import triton import triton.language as tl from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.mem_cache.common import alloc_token_slots from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1( @dataclass class EagleDraftInputV2Mixin: + def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): + from sglang.srt.speculative.spec_utils import assign_req_to_token_pool + + bs = batch.batch_size() + + # TODO(lsyin): implement over-allocation + # Now seq_lens and allocate_lens are correct + batch.maybe_wait_verify_done() + + new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE + num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item() + out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens) + + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + self.allocate_lens, + new_allocate_lens, + out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + next_power_of_2(bs), + ) + self.allocate_lens = new_allocate_lens + + # FIXME(lsyin): make this sync optional + batch.seq_lens_cpu = batch.seq_lens.cpu() + batch.seq_lens_sum = batch.seq_lens_cpu.sum().item() + def prepare_for_v2_draft( self: EagleDraftInput, req_to_token_pool: ReqToTokenPool, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index ce01187c2..e141a0238 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,17 +1,10 @@ import logging -import os import time -from contextlib import contextmanager from typing import List, Optional, Tuple import torch -from huggingface_hub import snapshot_download -from sglang.srt.distributed import ( - GroupCoordinator, - get_tp_group, - patch_tensor_parallel_group, -) +from sglang.srt.distributed import get_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import ( from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import ( assign_draft_cache_locs, + detect_nan, + draft_tp_context, fast_topk, generate_token_bitmask, + load_token_map, select_top_k_tokens, ) from sglang.srt.utils import ( empty_context, get_available_gpu_memory, get_bool_env_var, - is_blackwell, is_cuda, next_power_of_2, ) @@ -67,14 +62,6 @@ logger = logging.getLogger(__name__) SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB") -@contextmanager -def draft_tp_context(tp_group: GroupCoordinator): - # Draft model doesn't use dp and has its own tp group. - # We disable mscclpp now because it doesn't support 2 comm groups. 
- with patch_tensor_parallel_group(tp_group): - yield - - class EAGLEWorker(TpModelWorker): def __init__( @@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker): self.speculative_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) - self.padded_static_len = -1 # Override the context length of the draft model to be the same as the target model. server_args.context_length = target_worker.model_runner.model_config.context_len @@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker): logits_output, _ = self.draft_model_runner.forward( forward_batch, skip_attn_backend_init=True ) - self._detect_nan_if_needed(logits_output) + if self.server_args.enable_nan_detection: + detect_nan(logits_output) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) if self.hot_token_id is not None: @@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker): # and will be applied to produce wrong results batch.sampling_info.vocab_mask = None - self._detect_nan_if_needed(logits_output) + if self.enable_nan_detection: + detect_nan(logits_output) + spec_info.hidden_states = logits_output.hidden_states res: EagleVerifyOutput = spec_info.verify( batch, @@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker): ) forward_batch.return_logprob = False logits_output, _ = self.draft_model_runner.forward(forward_batch) - self._detect_nan_if_needed(logits_output) + if self.enable_nan_detection: + detect_nan(logits_output) assert isinstance(forward_batch.spec_info, EagleDraftInput) assert forward_batch.spec_info is batch.spec_info self.capture_for_decode(logits_output, forward_batch.spec_info) @@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker): ) self.capture_for_decode(logits_output, forward_batch.spec_info) - self._detect_nan_if_needed(logits_output) + if self.enable_nan_detection: + detect_nan(logits_output) # Restore backup. # This is because `seq_lens` can be modified in `prepare_extend_after_decode` @@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker): draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1) draft_input.hidden_states = logits_output.hidden_states - def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput): - if self.enable_nan_detection: - logits = logits_output.next_token_logits - if torch.any(torch.isnan(logits)): - logger.error("Detected errors during sampling! NaN in the logits.") - raise ValueError("Detected errors during sampling! 
NaN in the logits.") - - -def load_token_map(token_map_path: str) -> List[int]: - if not os.path.exists(token_map_path): - cache_dir = snapshot_download( - os.path.dirname(token_map_path), - ignore_patterns=["*.bin", "*.safetensors"], - ) - token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) - hot_token_id = torch.load(token_map_path, weights_only=True) - return torch.tensor(hot_token_id, dtype=torch.int64) - @torch.compile(dynamic=True) def get_last_loc_large_page_size_top_k_1( diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 0104a370d..832f6b5a8 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -1,19 +1,25 @@ import contextlib import logging -from typing import List, Optional +import time +from typing import List, Optional, Tuple import torch from torch.cuda import Stream as CudaStream from sglang.srt.environ import envs -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req +from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker +from sglang.srt.speculative.draft_utils import DraftBackendFactory +from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( + EAGLEDraftCudaGraphRunner, +) +from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import ( + EAGLEDraftExtendCudaGraphRunner, +) from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_info_v2 import ( assign_extend_cache_locs, @@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import ( select_top_k_tokens_tmp, ) from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient -from sglang.srt.speculative.eagle_worker import EAGLEWorker -from sglang.srt.utils.common import fast_topk, next_power_of_2 +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import ( + detect_nan, + draft_tp_context, + load_token_map, +) +from sglang.srt.utils.common import ( + empty_context, + fast_topk, + get_available_gpu_memory, + next_power_of_2, +) logger = logging.getLogger(__name__) -class EAGLEWorkerV2(EAGLEWorker): +def _get_plan_stream( + device: str, +) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]: + if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): + plan_stream: CudaStream = torch.get_device_module(device).Stream() + plan_stream_ctx = torch.cuda.stream(plan_stream) + return plan_stream, plan_stream_ctx + else: + return None, contextlib.nullcontext() + + +class EagleDraftWorker(BaseDraftWorker): def __init__( self, server_args: ServerArgs, gpu_id: int, tp_rank: int, - dp_rank: Optional[int], + dp_rank: int, moe_ep_rank: int, nccl_port: int, target_worker: TpModelWorker, ): - super().__init__( - server_args, - gpu_id, - tp_rank, - dp_rank, - moe_ep_rank, - nccl_port, - target_worker, + # copy args + self.server_args = server_args + self.gpu_id = 
gpu_id + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moe_ep_rank = moe_ep_rank + self.nccl_port = nccl_port + self.target_worker = target_worker + + # Args for easy access + self.device = server_args.device + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens + self.speculative_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm ) + + # Set constant EagleDraftInput.ALLOC_LEN_PER_DECODE = max( self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens ) + + # Do not capture cuda graph in `TpModelWorker` init, + # will capture later with init_cuda_graphs() + backup_disable_cuda_graph = server_args.disable_cuda_graph + server_args.disable_cuda_graph = True + + # Share the allocator with a target worker. + # Draft and target worker own their own KV cache pools. + self.req_to_token_pool, self.token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + with empty_context(): + # Init draft worker + self.draft_worker = TpModelWorker( + server_args=server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + pp_rank=0, # FIXME + dp_rank=dp_rank, + moe_ep_rank=moe_ep_rank, + nccl_port=nccl_port, + is_draft_worker=True, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + ) + + # Alias for better readability + self.draft_runner = self.draft_worker.model_runner + + self.init_token_map() + self.init_lm_head() + + # Init attention backend and cuda graphs + self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph + self.draft_tp_context = ( + draft_tp_context if server_args.enable_dp_attention else empty_context + ) + with self.draft_tp_context(self.draft_runner.tp_group): + self.init_attention_backend() + self.init_cuda_graphs() + self.tree_mask_mode = TreeMaskMode.FULL_MASK - if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): - self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream() - self.plan_stream_ctx = torch.cuda.stream(self.plan_stream) - else: - self.plan_stream = None - self.plan_stream_ctx = contextlib.nullcontext() + self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device) - def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): - if model_worker_batch.forward_mode.is_decode(): - # FIXME(lsyin): why shall we use spec_info for both draft and verify? - draft_input: EagleDraftInput = model_worker_batch.spec_info - assert draft_input.is_draft_input() - verify_input: EagleVerifyInput = self.draft(model_worker_batch) - assert verify_input.is_verify_input() - model_worker_batch.spec_info = verify_input - batch_output = self.verify(model_worker_batch, draft_input.allocate_lens) - return batch_output + def init_token_map(self): + # Load hot token ids + if self.speculative_algorithm.is_eagle3(): + if self.server_args.speculative_token_map is not None: + logger.warning( + "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map." 
+ ) + self.hot_token_id = None + elif self.server_args.speculative_token_map is not None: + self.hot_token_id = load_token_map(self.server_args.speculative_token_map) + self.server_args.json_model_override_args = ( + f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' + ) else: - # Target prefill - model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - batch_output = self.target_worker.forward_batch_generation( - model_worker_batch + self.hot_token_id = None + + def init_lm_head(self): + embed, head = self.target_worker.model_runner.model.get_embed_and_head() + if self.speculative_algorithm.is_eagle3(): + # most cases EAGLE3 models don't share lm_head + # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares + if ( + hasattr(self.draft_runner.model, "load_lm_head_from_target") + and self.draft_runner.model.load_lm_head_from_target + ): + self.draft_runner.model.set_embed_and_head(embed, head) + else: + self.draft_runner.model.set_embed(embed) + + # grab hot token ids + if self.draft_runner.model.hot_token_id is not None: + self.hot_token_id = self.draft_runner.model.hot_token_id.to( + embed.device + ) + + else: + if self.hot_token_id is not None: + head = head.clone() + self.hot_token_id = self.hot_token_id.to(head.device) + head.data = head.data[self.hot_token_id] + + # Share the embedding and lm_head + self.draft_runner.model.set_embed_and_head(embed, head) + + def init_attention_backend(self): + # Create multi-step attn backends and cuda graph runners + + self.has_prefill_wrapper_verify = False + self.draft_extend_attn_backend = None + + draft_backend_factory = DraftBackendFactory( + self.server_args, + self.draft_runner, + self.topk, + self.speculative_num_steps, + ) + + # Initialize decode attention backend + self.draft_attn_backend = draft_backend_factory.create_decode_backend() + + # Initialize draft extend attention backend (respects speculative_attention_mode setting) + self.draft_extend_attn_backend = ( + draft_backend_factory.create_draft_extend_backend() + ) + + self.draft_runner.draft_attn_backend = self.draft_attn_backend + self.tree_mask_mode = TreeMaskMode.FULL_MASK + + def init_cuda_graphs(self): + """Capture cuda graphs.""" + self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None + + if self.server_args.disable_cuda_graph: + return + + # Capture draft + if self.speculative_num_steps > 1: + tic = time.perf_counter() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" + ) + self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) - # Draft prefill - model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST - batch_output.next_draft_input = self.forward_draft_extend( - model_worker_batch, - batch_output.logits_output.hidden_states, - batch_output.next_token_ids, + # Capture extend + if self.draft_extend_attn_backend: + tic = time.perf_counter() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft extend cuda graph begin. This can take up to several minutes. 
avail mem={before_mem:.2f} GB" + ) + self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner( + self + ) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) - return batch_output def draft(self, model_worker_batch: ModelWorkerBatch): draft_input: EagleDraftInput = model_worker_batch.spec_info @@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker): self.req_to_token_pool, model_worker_batch, self.cuda_graph_runner, - self.draft_model_runner, + self.draft_runner, self.topk, self.speculative_num_steps, ) @@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker): spec_info.hidden_states = hidden_states # Run forward - logits_output = self.draft_model_runner.model.forward( + logits_output = self.draft_runner.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) - self._detect_nan_if_needed(logits_output) + if self.server_args.enable_nan_detection: + detect_nan(logits_output) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) if self.hot_token_id is not None: @@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker): return parent_list, top_scores_index, draft_tokens + def draft_extend(self): + pass + + def _draft_extend_for_prefill( + self, + batch: ModelWorkerBatch, + target_hidden_states: torch.Tensor, + next_token_ids: torch.Tensor, + ): + """ + Run draft model extend to correctly fill the KV cache. + + Args: + batch: The batch to run. + target_hidden_states: Hidden states from the target model forward + next_token_ids: Next token ids generated from the target forward. 
+ """ + # Construct input_ids + pt = 0 + for i, extend_len in enumerate(batch.extend_seq_lens): + input_ids = batch.input_ids[pt : pt + extend_len] + batch.input_ids[pt : pt + extend_len] = torch.cat( + (input_ids[1:], next_token_ids[i].reshape(1)) + ) + pt += extend_len + + # Construct spec_info + next_draft_input = EagleDraftInput( + hidden_states=target_hidden_states, + verified_id=next_token_ids, + new_seq_lens=batch.seq_lens, + allocate_lens=batch.seq_lens, + ) + batch.spec_info = next_draft_input + + # Run forward + forward_batch = ForwardBatch.init_new(batch, self.draft_runner) + logits_output, _ = self.draft_runner.forward(forward_batch) + + # Update spec_info for the next draft step + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + next_draft_input.topk_p, next_draft_input.topk_index = fast_topk( + probs, self.topk, dim=-1 + ) + next_draft_input.hidden_states = logits_output.hidden_states + return next_draft_input + + def _draft_extend_for_decode( + self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult + ): + # Batch 2: Draft extend + draft_input = EagleDraftInput( + hidden_states=batch_result.logits_output.hidden_states, + ) + select_index = ( + torch.arange(len(batch.seq_lens), device=self.device) + * self.speculative_num_draft_tokens + + batch_result.accept_lens + - 1 + ) + + # Prepare for draft extend in a separate stream + with self.plan_stream_ctx: + forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache( + batch, + batch_result.next_token_ids, + self.speculative_num_draft_tokens, + self.draft_runner, + ) + + if self.plan_stream: + torch.cuda.current_stream().wait_stream(self.plan_stream) + + # Run draft extend batch in the main compute stream + draft_logits_output = self.draft_runner.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + + # Reorganize the spec info for the next batch + draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[ + select_index + ] + draft_logits_output.hidden_states = draft_logits_output.hidden_states[ + select_index + ] + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + ret_hidden_states = draft_logits_output.hidden_states + + # Construct the return values + next_draft_input = batch_result.next_draft_input + ( + next_draft_input.topk_p, + next_draft_input.topk_index, + next_draft_input.hidden_states, + ) = ( + ret_topk_p, + ret_topk_index, + ret_hidden_states, + ) + + +class EAGLEWorkerV2(BaseSpecWorker): + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + # Parse arguments + self.server_args = server_args + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens + self.enable_nan_detection = server_args.enable_nan_detection + self.gpu_id = gpu_id + self.device = server_args.device + self._target_worker = target_worker + self.page_size = server_args.page_size + self.speculative_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) + + self.req_to_token_pool, self.token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + + # Override the context length of the draft model to be the same as the target model. 
+ server_args.context_length = target_worker.model_runner.model_config.context_len + + self._draft_worker = EagleDraftWorker( + server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker + ) + + # Some dummy tensors + self.num_new_pages_per_topk = torch.empty( + (), dtype=torch.int64, device=self.device + ) + self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device) + + self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device) + + @property + def target_worker(self): + return self._target_worker + + @property + def draft_worker(self): + return self._draft_worker + + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + if model_worker_batch.forward_mode.is_decode(): + draft_input: EagleDraftInput = model_worker_batch.spec_info + assert draft_input.is_draft_input() + verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch) + assert verify_input.is_verify_input() + model_worker_batch.spec_info = verify_input + batch_output = self.verify(model_worker_batch, draft_input.allocate_lens) + self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output) + return batch_output + else: + # Target prefill + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch_output = self.target_worker.forward_batch_generation( + model_worker_batch + ) + + # Draft prefill + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST + batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill( + model_worker_batch, + batch_output.logits_output.hidden_states, + batch_output.next_token_ids, + ) + return batch_output + def verify( self, batch: ModelWorkerBatch, - pre_draft_allocate_lens: torch.Tensor, + cur_allocate_lens: torch.Tensor, ): # Since batch.seq_lens is allocated in another stream, we need # record_stream() to prevent pytorch gc and reuse the gpu memory @@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker): logits_output = forward_batch_output.logits_output # Sample - self._detect_nan_if_needed(logits_output) + if self.enable_nan_detection: + detect_nan(logits_output) ( predict, accept_length, @@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker): self.speculative_num_draft_tokens, ) - # Batch 2: Draft extend - draft_input = EagleDraftInput( - hidden_states=logits_output.hidden_states, - ) - select_index = ( - torch.arange(len(batch.seq_lens), device=self.device) - * self.speculative_num_draft_tokens - + accept_length - - 1 - ) - - # Prepare for draft extend in a separate stream - with self.plan_stream_ctx: - forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache( - batch, - predict, - self.speculative_num_draft_tokens, - self.draft_model_runner, - ) - - if self.plan_stream: - torch.cuda.current_stream().wait_stream(self.plan_stream) - - # Run draft extend batch in the main compute stream - draft_logits_output = self.draft_model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - - # Reorganize the spec info for the next batch - draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[ - select_index - ] - draft_logits_output.hidden_states = draft_logits_output.hidden_states[ - select_index - ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) - ret_hidden_states = draft_logits_output.hidden_states - - # Construct the return values + # Construct the next draft input next_draft_input = EagleDraftInput( - 
topk_p=ret_topk_p, - topk_index=ret_topk_index, - hidden_states=ret_hidden_states, verified_id=verified_id, new_seq_lens=new_seq_lens, - allocate_lens=pre_draft_allocate_lens, + allocate_lens=cur_allocate_lens, verify_done=verify_done, ) @@ -359,53 +650,9 @@ class EAGLEWorkerV2(EAGLEWorker): can_run_cuda_graph=can_run_cuda_graph, next_draft_input=next_draft_input, accept_lens=accept_length, - last_batch_allocate_lens=pre_draft_allocate_lens, + allocate_lens=cur_allocate_lens, ) - def forward_draft_extend( - self, - batch: ModelWorkerBatch, - target_hidden_states: torch.Tensor, - next_token_ids: torch.Tensor, - ): - """ - Run draft model extend to correctly fill the KV cache. - - Args: - batch: The batch to run. - target_hidden_states: Hidden states from the target model forward - next_token_ids: Next token ids generated from the target forward. - """ - # Construct input_ids - pt = 0 - for i, extend_len in enumerate(batch.extend_seq_lens): - input_ids = batch.input_ids[pt : pt + extend_len] - batch.input_ids[pt : pt + extend_len] = torch.cat( - (input_ids[1:], next_token_ids[i].reshape(1)) - ) - pt += extend_len - - # Construct spec_info - next_draft_input = EagleDraftInput( - hidden_states=target_hidden_states, - verified_id=next_token_ids, - new_seq_lens=batch.seq_lens, - allocate_lens=batch.seq_lens, - ) - batch.spec_info = next_draft_input - - # Run forward - forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner) - logits_output, _ = self.draft_model_runner.forward(forward_batch) - - # Update spec_info for the next draft step - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - next_draft_input.topk_p, next_draft_input.topk_index = fast_topk( - probs, self.topk, dim=-1 - ) - next_draft_input.hidden_states = logits_output.hidden_states - return next_draft_input - def move_accepted_tokens_to_target_kvcache( self, batch: ModelWorkerBatch, @@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker): self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache( tgt_cache_loc, accepted_out_cache_loc ) - - def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput): - if self.enable_nan_detection: - logits = logits_output.next_token_logits - if torch.any(torch.isnan(logits)): - logger.error("Detected errors during sampling! NaN in the logits.") - raise ValueError("Detected errors during sampling! NaN in the logits.") - - -def free_spec_dec_tokens_page_size_1( - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, - req: Req, - allocate_len: int, - new_seq_len: int, -): - # FIXME(lsyin): move this function elsewhere - - # free extra allocated tokens - if new_seq_len is None: - # True only for overlap eagle and the current batch is decode. This seq will be part of the decode, so the final iteration's allocation is not used (i.e. this case). - start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE - else: - # True for 1) non-overlap; 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration, so start_lens is passed in. 
- start_len = new_seq_len - indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][ - start_len:allocate_len - ] - token_to_kv_pool_allocator.free(indices_to_free) diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py index 4c3c8a070..d89236dbe 100644 --- a/python/sglang/srt/speculative/spec_utils.py +++ b/python/sglang/srt/speculative/spec_utils.py @@ -1,25 +1,37 @@ from __future__ import annotations import logging +import os import time +from contextlib import contextmanager from typing import TYPE_CHECKING, List import torch import triton import triton.language as tl +from huggingface_hub import snapshot_download from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject +from sglang.srt.distributed.parallel_state import ( + GroupCoordinator, + patch_tensor_parallel_group, +) from sglang.srt.environ import envs +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import Req from sglang.srt.utils import is_cuda, is_hip +if TYPE_CHECKING: + from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.speculative.eagle_info import EagleVerifyInput + + if is_cuda(): from sgl_kernel import fast_topk elif is_hip(): from sgl_kernel import fast_topk -if TYPE_CHECKING: - from sglang.srt.speculative.eagle_info import EagleVerifyInput logger = logging.getLogger(__name__) @@ -603,3 +615,29 @@ def generate_token_bitmask( verify_input.grammar = grammar return allocate_token_bitmask + + +def load_token_map(token_map_path: str) -> List[int]: + if not os.path.exists(token_map_path): + cache_dir = snapshot_download( + os.path.dirname(token_map_path), + ignore_patterns=["*.bin", "*.safetensors"], + ) + token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) + hot_token_id = torch.load(token_map_path, weights_only=True) + return torch.tensor(hot_token_id, dtype=torch.int64) + + +@contextmanager +def draft_tp_context(tp_group: GroupCoordinator): + # Draft model doesn't use dp and has its own tp group. + # We disable mscclpp now because it doesn't support 2 comm groups. + with patch_tensor_parallel_group(tp_group): + yield + + +def detect_nan(logits_output: LogitsProcessorOutput): + logits = logits_output.next_token_logits + if torch.any(torch.isnan(logits)): + logger.error("Detected errors during sampling! NaN in the logits.") + raise ValueError("Detected errors during sampling! 
NaN in the logits.") diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py index ca0490e39..23f9b9dd2 100644 --- a/python/sglang/srt/speculative/standalone_worker.py +++ b/python/sglang/srt/speculative/standalone_worker.py @@ -1,14 +1,13 @@ import logging -from contextlib import contextmanager from typing import Optional import torch -from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map +from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda if is_cuda(): @@ -18,14 +17,6 @@ logger = logging.getLogger(__name__) SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB") -@contextmanager -def draft_tp_context(tp_group: GroupCoordinator): - # Draft model doesn't use dp and has its own tp group. - # We disable mscclpp now because it doesn't support 2 comm groups. - with patch_tensor_parallel_group(tp_group): - yield - - class StandaloneWorker(EAGLEWorker): def __init__( @@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker): self.speculative_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) - self.padded_static_len = -1 # Override the context length of the draft model to be the same as the target model. server_args.context_length = target_worker.model_runner.model_config.context_len
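
The following sketch is not part of the patch; it only mirrors, with plain tensors, the KV-slot accounting that EagleDraftInputV2Mixin.prepare_for_decode and the overlap free path in scheduler_output_processor_mixin.py rely on. All variable names and example values are illustrative stand-ins, and the req_to_token pool / allocator interactions are omitted.

# Standalone sketch (not part of the patch) of the v2 EAGLE KV-slot bookkeeping.
# Assumption: per-request slots [0, allocate_len) are reserved in req_to_token,
# of which only [0, seq_len + accept_len) hold verified tokens after a decode step.
import torch

speculative_num_steps = 3
topk = 4
speculative_num_draft_tokens = 8

# Mirrors EagleDraftInput.ALLOC_LEN_PER_DECODE as set in EagleDraftWorker.__init__.
alloc_len_per_decode = max(speculative_num_steps * topk, speculative_num_draft_tokens)

seq_lens = torch.tensor([17, 42])       # verified lengths before this decode step
allocate_lens = torch.tensor([20, 45])  # slots already reserved per request

# prepare_for_decode: grow each reservation to seq_len + ALLOC_LEN_PER_DECODE and
# request only the difference from the allocator.
new_allocate_lens = seq_lens + alloc_len_per_decode
num_needed_tokens = int((new_allocate_lens - allocate_lens).sum().item())
assert num_needed_tokens == 18  # (29 - 20) + (54 - 45)

# After verification accepts accept_lens[i] tokens, the output processor frees the
# unused tail of the reservation (page_size == 1 case):
accept_lens = torch.tensor([3, 1])
new_seq_lens = seq_lens + accept_lens
# a) request finished during an overlapped decode: drop the whole last window
freed_decode = [(int(a) - alloc_len_per_decode, int(a)) for a in new_allocate_lens]
# b) overlap and the current batch is a prefill: drop everything past the verified length
freed_extend = [(int(s), int(a)) for s, a in zip(new_seq_lens, new_allocate_lens)]
print(freed_decode)  # [(17, 29), (42, 54)]
print(freed_extend)  # [(20, 29), (43, 54)]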