diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 5f6ed3fb7..eb27c5c25 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -8,7 +8,7 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -19,10 +19,6 @@ if is_flashinfer_available():
     )
 
 
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
-            exit(1) if crash_on_warning else None
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 125abaaf7..46f431adf 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
 
@@ -662,21 +659,23 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                 "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
             )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                 "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
             )
-            exit(1) if crash_on_warning else None
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ea09b3c26..3c76f5ad7 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -20,6 +20,7 @@ import importlib
 import importlib.resources
 import json
 import logging
+import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
     monkey_patch_vllm_p2p_access_check,
@@ -665,7 +667,9 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py
index 68f8cbc45..656115321 100644
--- a/python/sglang/srt/models/phi3_small.py
+++ b/python/sglang/srt/models/phi3_small.py
@@ -1,14 +1,14 @@
 import math
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_layers, maybe_prefix
+from vllm.model_executor.models.utils import make_layers
 
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
+        cache_config=None,
     ):
         super().__init__()
 
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self.model = Phi3SmallModel(
             config=config,
             quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "model"),
+            prefix="model",
         )
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8df6d7b7e..84bf9a2e5 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
     raise RuntimeError(
         "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
     )
+
+
+def crash_on_warnings():
+    # Crash on warning if we are running CI tests
+    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"