diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 5bf4faa21..842f59a3b 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -427,7 +427,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch):
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
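Note: the `skip_attn_backend_init` flag added here is threaded through `ModelRunner.forward` and `forward_extend` in the hunks below, so a caller that has already planned the attention backend's forward metadata can replay a captured graph without re-planning it. A minimal sketch of the intended call pattern, assuming a hypothetical multi-step draft loop (the loop and names are illustrative, not part of this PR):

```python
def run_draft_round(model_runner, batches):
    """Illustrative multi-step loop; only the first step pays for
    attention-backend metadata planning."""
    output = None
    for i, forward_batch in enumerate(batches):
        # Steps after the first reuse the planned metadata, since the
        # batch layout does not change within the round.
        output = model_runner.forward(
            forward_batch, skip_attn_backend_init=(i > 0)
        )
    return output
```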
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a931cb15a..666b97e2b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -122,34 +122,115 @@ class ModelRunner:
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
+        self.model_specific_adjustment()
+
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+
+        if server_args.disable_outlines_disk_cache:
+            from outlines.caching import disable_cache
+
+            disable_cache()
+
+        # Global vars
+        global_server_args_dict.update(
+            {
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "disable_mla": server_args.disable_mla,
+                "torchao_config": server_args.torchao_config,
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+            }
+        )
+
+        # CPU offload
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Get memory before model loading
+        min_per_gpu_memory = self.init_torch_distributed()
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
+        # Load the model
+        self.sampler = Sampler()
+        self.load_model()
+
+        # Apply torchao quantization
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
+
+        # Apply torch TP if the model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
+        # Init lora
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
+
+        # Init memory pool and attention backends
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_running_requests,
+            server_args.max_total_tokens,
+        )
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.cuda_graph_runner = None
+            self.init_attention_backend()
+
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
+            and not server_args.disable_mla
         ):
             # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
+            if server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
                         "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "flashinfer_mla"
+                    server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
+                    server_args.attention_backend = "triton"
 
-        if self.server_args.enable_double_sparsity:
+        if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
             )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
                 raise ValueError(
                     "Please specify the heavy channel type for double sparsity optimization."
                 )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
@@ -174,96 +255,10 @@ class ModelRunner:
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True
 
-        # Global vars
-        if server_args.show_time_cost:
-            enable_show_time_cost()
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
-        global_server_args_dict.update(
-            {
-                "attention_backend": server_args.attention_backend,
-                "sampling_backend": server_args.sampling_backend,
-                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
-                "torchao_config": server_args.torchao_config,
-                "enable_nan_detection": server_args.enable_nan_detection,
-                "enable_dp_attention": server_args.enable_dp_attention,
-                "enable_ep_moe": server_args.enable_ep_moe,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
-                "disable_radix_cache": server_args.disable_radix_cache,
-                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
-            }
-        )
-
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
-        # Get memory before model loading
-        min_per_gpu_memory = self.init_torch_distributed()
-
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=self.server_args.enable_memory_saver
-        )
-
-        # Load the model
-        self.sampler = Sampler()
-        self.load_model()
-
-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
-        # Apply torchao quantization
-        torchao_applied = getattr(self.model, "torchao_applied", False)
-        # In layered loading, torchao may have been applied
-        if not torchao_applied:
-            apply_torchao_config_to_model(
-                self.model, global_server_args_dict["torchao_config"]
-            )
-
-        # Apply torch TP if the model supports it
-        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
-        if self.tp_size > 1 and supports_torch_tp:
-            self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False
-
-        # Init memory pool and attention backends
-        if server_args.lora_paths is not None:
-            self.init_lora_manager()
-        self.init_memory_pool(
-            min_per_gpu_memory,
-            server_args.max_running_requests,
-            server_args.max_total_tokens,
-        )
-        if self.device == "cuda":
-            self.init_cublas()
-            self.init_attention_backend()
-            self.init_cuda_graphs()
-        else:
-            self.cuda_graph_runner = None
-            self.init_attention_backend()
-
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)
+
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
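Note: calling `model_specific_adjustment()` before `global_server_args_dict.update(...)` ensures the snapshot records the adjusted attention backend rather than the CLI default. A toy illustration of that ordering constraint (all names below are simplified stand-ins, not sglang code):

```python
from dataclasses import dataclass

@dataclass
class ToyServerArgs:
    attention_backend: str = "flashinfer"
    enable_flashinfer_mla: bool = False

global_args = {}

def model_specific_adjustment(args: ToyServerArgs):
    # Mirrors the MLA branch: rewrite the backend before it is snapshotted.
    args.attention_backend = (
        "flashinfer_mla" if args.enable_flashinfer_mla else "triton"
    )

args = ToyServerArgs(enable_flashinfer_mla=True)
model_specific_adjustment(args)  # must run first...
global_args.update({"attention_backend": args.attention_backend})
assert global_args["attention_backend"] == "flashinfer_mla"  # ...so the snapshot is correct
```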
@@ -400,6 +395,18 @@ class ModelRunner:
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} finished model loading, but other ranks did not. This is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
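Note: the barrier now runs at the end of `load_model` (right after the memory-usage log) rather than inline in `__init__`, so each rank checks in as soon as its own load completes. For reference, a minimal standalone sketch of the `monitored_barrier` pattern, assuming a gloo process group and an arbitrary 300-second timeout:

```python
import datetime
import torch.distributed as dist

def wait_for_all_ranks(step_name: str) -> None:
    try:
        # monitored_barrier requires a gloo (CPU) process group; with
        # wait_all_ranks=True it reports every straggler, not just the first.
        dist.monitored_barrier(
            timeout=datetime.timedelta(seconds=300), wait_all_ranks=True
        )
    except RuntimeError as e:
        raise RuntimeError(
            f"rank {dist.get_rank()}: not all ranks reached {step_name!r}"
        ) from e

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # e.g., torchrun --nproc_per_node=2
    wait_for_all_ranks("model loading")
    dist.destroy_process_group()
```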
@@ -772,6 +779,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -880,18 +891,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
 
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
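Note: the EAGLE-only `plan_stream_for_flashinfer` stream created above suggests the flashinfer plan step is intended to run off the default stream. A minimal sketch of that side-stream pattern, with `plan_fn`/`compute_fn` as hypothetical stand-ins for the real backend calls:

```python
import torch

def plan_on_side_stream(plan_fn, compute_fn, plan_stream: torch.cuda.Stream):
    # Kick off metadata planning on the side stream so it can overlap
    # with independent work on the default stream.
    with torch.cuda.stream(plan_stream):
        plan_fn()
    # Make the default stream wait for the plan before launching the
    # kernels that consume the planned metadata.
    torch.cuda.current_stream().wait_stream(plan_stream)
    return compute_fn()
```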