From ebaa2f31996e80e4128b832d70f29f288b59944e Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 17 Nov 2024 16:53:44 -0800
Subject: [PATCH] Rename arguments `--disable-nan-detection` to `--enable-nan-detection` (#2066)

---
 python/sglang/srt/layers/sampler.py            |  2 +-
 python/sglang/srt/managers/schedule_batch.py   |  2 +-
 .../sglang/srt/model_executor/model_runner.py  |  6 ++++-
 python/sglang/srt/models/gemma2.py             |  1 +
 python/sglang/srt/server_args.py               | 26 +++++++------------
 5 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index a5afcab51..5f6ed3fb7 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
 
     def forward(
         self,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 054d1fcf8..6171c93c0 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -57,7 +57,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 35d81050a..02750d5df 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -139,7 +139,7 @@ class ModelRunner:
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "disable_penalizer": server_args.disable_penalizer,
-                "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
@@ -276,6 +276,10 @@ class ModelRunner:
             else None
         )
         self.dtype = self.vllm_model_config.dtype
+        if self.sliding_window_size:
+            assert (
+                self.server_args.attention_backend == "flashinfer"
+            ), "Only flashinfer supports window attention."
 
         logger.info(
             f"Load weight end. "
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index b295c7bbc..8dc5effb4 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -332,6 +332,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True
 
     def __init__(
         self,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8508d15d8..394bb519b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -124,7 +124,6 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
-    disable_nan_detection: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -132,6 +131,7 @@
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int = 160
     torchao_config: str = ""
+    enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
@@ -171,11 +171,11 @@
         else:
             gpu_mem = get_nvgpu_memory_capacity()
         if gpu_mem < 25000:
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
             logger.warning(
                 "Automatically adjust --chunked-prefill-size for small GPUs."
             )
-            self.chunked_prefill_size //= 4  # make it 2048
-            self.cuda_graph_max_bs = 4
 
         if not is_flashinfer_available():
             self.attention_backend = "triton"
@@ -194,7 +194,7 @@
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
-                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size."
             )
@@ -204,21 +204,8 @@
                 "Overlap scheduler mode is enabled. This is an experimental feature. "
                 "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                 "and embedding APIs are not supported and will lead to wrong results. "
-                "The NaN detection is also disabled."
             )
             self.disable_penalizer = True
-            self.disable_nan_detection = True
-
-        # Model-specific patches
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.attention_backend = "flashinfer"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -683,6 +670,11 @@
             default=ServerArgs.torchao_config,
             help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo",
         )
+        parser.add_argument(
+            "--enable-nan-detection",
+            action="store_true",
+            help="Enable the NaN detection for debugging purposes.",
+        )
         parser.add_argument(
             "--enable-p2p-check",
            action="store_true",
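
For readers who want to see what the renamed flag gates at runtime, a minimal sketch follows. It is illustrative only and not the actual Sampler code touched by this patch; the helper name guard_probs, the warning text, and the uniform-distribution fallback are assumptions.

import logging

import torch

logger = logging.getLogger(__name__)

# Stand-in for the "enable_nan_detection" entry of sglang's
# global_server_args_dict; the patch fills it from the new
# --enable-nan-detection CLI flag (default: False).
enable_nan_detection = True


def guard_probs(probs: torch.Tensor) -> torch.Tensor:
    """Illustrative NaN guard, not the real Sampler.forward implementation."""
    if enable_nan_detection and torch.isnan(probs).any():
        logger.warning("NaN detected in sampling probabilities.")
        # Fall back to a uniform distribution over the vocabulary so that
        # sampling can proceed instead of crashing.
        probs = torch.full_like(probs, 1.0 / probs.shape[-1])
    return probs

After this patch the check is opt-in: it stays off by default and can be turned on at launch time, e.g. python -m sglang.launch_server --model-path <model> --enable-nan-detection.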