Crash the CI jobs on model import errors (#2072)

This commit is contained in:
Lianmin Zheng
2024-11-17 22:18:11 -08:00
committed by GitHub
parent a7164b620f
commit df7fe4521a
5 changed files with 30 additions and 25 deletions

View File

@@ -8,7 +8,7 @@ from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
if is_flashinfer_available(): if is_flashinfer_available():
from flashinfer.sampling import ( from flashinfer.sampling import (
@@ -19,10 +19,6 @@ if is_flashinfer_available():
) )
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
logits = torch.where( logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits torch.isnan(logits), torch.full_like(logits, -1e5), logits
) )
exit(1) if crash_on_warning else None if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
if sampling_info.is_all_greedy: if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling # Use torch.argmax if all requests use greedy sampling

View File

@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
broadcast_pyobj, broadcast_pyobj,
configure_logger, configure_logger,
crash_on_warnings,
get_zmq_socket, get_zmq_socket,
kill_parent_process, kill_parent_process,
set_random_seed, set_random_seed,
@@ -76,10 +77,6 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
# Test retract decode # Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true" test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
@@ -662,21 +659,23 @@ class Scheduler:
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
) )
if available_size != self.max_total_num_tokens: if available_size != self.max_total_num_tokens:
warnings.warn( msg = (
"Warning: "
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
"KV cache pool leak detected!" "KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n"
) )
exit(1) if crash_on_warning else None warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size: if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
warnings.warn( msg = (
"Warning: "
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
f"total slots={self.req_to_token_pool.size}\n"
"Memory pool leak detected!" "Memory pool leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
) )
exit(1) if crash_on_warning else None warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
def get_next_batch_to_run(self): def get_next_batch_to_run(self):
# Merge the prefill batch into the running batch # Merge the prefill batch into the running batch

View File

@@ -20,6 +20,7 @@ import importlib
import importlib.resources import importlib.resources
import json import json
import logging import logging
import os
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from typing import Optional, Type from typing import Optional, Type
@@ -56,6 +57,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost, enable_show_time_cost,
get_available_gpu_memory, get_available_gpu_memory,
monkey_patch_vllm_p2p_access_check, monkey_patch_vllm_p2p_access_check,
@@ -665,7 +667,9 @@ def import_model_classes():
try: try:
module = importlib.import_module(name) module = importlib.import_module(name)
except Exception as e: except Exception as e:
logger.warning(f"Ignore import error when loading {name}. " f"{e}") logger.warning(f"Ignore import error when loading {name}. {e}")
if crash_on_warnings():
raise ValueError(f"Ignore import error when loading {name}. {e}")
continue continue
if hasattr(module, "EntryClass"): if hasattr(module, "EntryClass"):
entry = module.EntryClass entry = module.EntryClass

View File

@@ -1,14 +1,14 @@
import math import math
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Iterable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Phi3Config from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import make_layers, maybe_prefix from vllm.model_executor.models.utils import make_layers
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
@@ -339,7 +339,7 @@ class Phi3SmallForCausalLM(nn.Module):
self, self,
config: Phi3Config, config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", cache_config=None,
): ):
super().__init__() super().__init__()
@@ -349,7 +349,7 @@ class Phi3SmallForCausalLM(nn.Module):
self.model = Phi3SmallModel( self.model = Phi3SmallModel(
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "model"), prefix="model",
) )
self.torchao_config = global_server_args_dict["torchao_config"] self.torchao_config = global_server_args_dict["torchao_config"]
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size

View File

@@ -816,3 +816,8 @@ def get_nvgpu_memory_capacity():
raise RuntimeError( raise RuntimeError(
"nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible." "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
) )
def crash_on_warnings():
    """Return True when running inside CI, where warnings should be fatal.

    CI sets the environment variable SGLANG_IS_IN_CI to the literal string
    "true"; any other value (or an unset variable) means "not in CI".
    """
    in_ci_flag = os.getenv("SGLANG_IS_IN_CI", "false")
    return in_ci_flag == "true"