Crash the CI jobs on model import errors (#2072)

This commit is contained in:
Lianmin Zheng
2024-11-17 22:18:11 -08:00
committed by GitHub
parent a7164b620f
commit df7fe4521a
5 changed files with 30 additions and 25 deletions

View File

@@ -8,7 +8,7 @@ from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
if is_flashinfer_available():
from flashinfer.sampling import (
@@ -19,10 +19,6 @@ if is_flashinfer_available():
)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__)
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
exit(1) if crash_on_warning else None
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling

View File

@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
crash_on_warnings,
get_zmq_socket,
kill_parent_process,
set_random_seed,
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
@@ -662,21 +659,23 @@ class Scheduler:
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
msg = (
"KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n"
)
exit(1) if crash_on_warning else None
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
warnings.warn(
"Warning: "
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
f"total slots={self.req_to_token_pool.size}\n"
msg = (
"Memory pool leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
)
exit(1) if crash_on_warning else None
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
def get_next_batch_to_run(self):
# Merge the prefill batch into the running batch

View File

@@ -20,6 +20,7 @@ import importlib
import importlib.resources
import json
import logging
import os
import pkgutil
from functools import lru_cache
from typing import Optional, Type
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
monkey_patch_vllm_p2p_access_check,
@@ -665,7 +667,9 @@ def import_model_classes():
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
logger.warning(f"Ignore import error when loading {name}. {e}")
if crash_on_warnings():
raise ValueError(f"Ignore import error when loading {name}. {e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass

View File

@@ -1,14 +1,14 @@
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import make_layers, maybe_prefix
from vllm.model_executor.models.utils import make_layers
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
self,
config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
cache_config=None,
):
super().__init__()
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
self.model = Phi3SmallModel(
config=config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "model"),
prefix="model",
)
self.torchao_config = global_server_args_dict["torchao_config"]
self.vocab_size = config.vocab_size

View File

@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
raise RuntimeError(
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
)
def crash_on_warnings():
# Crash on warning if we are running CI tests
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"