Crash the CI jobs on model import errors (#2072)
This commit is contained in:
@@ -8,7 +8,7 @@ from torch import nn
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.sampling import (
|
||||
@@ -19,10 +19,6 @@ if is_flashinfer_available():
|
||||
)
|
||||
|
||||
|
||||
# Crash on warning if we are running CI tests
|
||||
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
|
||||
logits = torch.where(
|
||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
if crash_on_warnings():
|
||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
|
||||
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
broadcast_pyobj,
|
||||
configure_logger,
|
||||
crash_on_warnings,
|
||||
get_zmq_socket,
|
||||
kill_parent_process,
|
||||
set_random_seed,
|
||||
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Crash on warning if we are running CI tests
|
||||
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
# Test retract decode
|
||||
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
|
||||
|
||||
@@ -662,21 +659,23 @@ class Scheduler:
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
)
|
||||
if available_size != self.max_total_num_tokens:
|
||||
warnings.warn(
|
||||
"Warning: "
|
||||
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
|
||||
msg = (
|
||||
"KV cache pool leak detected!"
|
||||
f"{available_size=}, {self.max_total_num_tokens=}\n"
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
warnings.warn(msg)
|
||||
if crash_on_warnings():
|
||||
raise ValueError(msg)
|
||||
|
||||
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
||||
warnings.warn(
|
||||
"Warning: "
|
||||
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
|
||||
f"total slots={self.req_to_token_pool.size}\n"
|
||||
msg = (
|
||||
"Memory pool leak detected!"
|
||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
||||
f"total_size={self.req_to_token_pool.size}\n"
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
warnings.warn(msg)
|
||||
if crash_on_warnings():
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_next_batch_to_run(self):
|
||||
# Merge the prefill batch into the running batch
|
||||
|
||||
@@ -20,6 +20,7 @@ import importlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
@@ -665,7 +667,9 @@ def import_model_classes():
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
|
||||
logger.warning(f"Ignore import error when loading {name}. {e}")
|
||||
if crash_on_warnings():
|
||||
raise ValueError(f"Ignore import error when loading {name}. {e}")
|
||||
continue
|
||||
if hasattr(module, "EntryClass"):
|
||||
entry = module.EntryClass
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import math
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Phi3Config
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import make_layers, maybe_prefix
|
||||
from vllm.model_executor.models.utils import make_layers
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Phi3Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
cache_config=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
self.model = Phi3SmallModel(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
prefix="model",
|
||||
)
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
|
||||
raise RuntimeError(
|
||||
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
|
||||
)
|
||||
|
||||
|
||||
def crash_on_warnings():
    """Return True when warnings should be escalated to hard failures.

    The decision is driven by the ``SGLANG_IS_IN_CI`` environment variable,
    which is read on every call so that changes made after import time are
    honoured.

    The value is compared case-insensitively and with surrounding whitespace
    removed; ``"true"`` and ``"1"`` enable crashing. Any other value, or an
    unset variable, leaves it disabled.

    Returns:
        bool: True if running CI tests and warnings must crash the job.
    """
    # Crash on warning if we are running CI tests
    flag = os.getenv("SGLANG_IS_IN_CI", "false")
    # Normalize so that "True", "TRUE" or " true\n" are not silently treated
    # as "disabled", which would hide failures in CI.
    return flag.strip().lower() in ("true", "1")
|
||||
|
||||
Reference in New Issue
Block a user