diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 7421bda18..454078d59 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -21,6 +21,10 @@ logger = logging.getLogger(__name__) class Sampler(nn.Module): + def __init__(self): + super().__init__() + self.use_nan_detection = not global_server_args_dict["disable_nan_detection"] + def forward( self, logits: Union[torch.Tensor, LogitsProcessorOutput], @@ -36,13 +40,13 @@ class Sampler(nn.Module): logits = None del logits - if torch.any(torch.isnan(probs)): + if self.use_nan_detection and torch.any(torch.isnan(probs)): logger.warning("Detected errors during sampling! NaN in the probability.") probs = torch.where( torch.isnan(probs), torch.full_like(probs, 1e-10), probs ) - if sampling_info.top_ks.max().item() <= 1: + if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling batch_next_token_ids = torch.argmax(probs, -1) elif global_server_args_dict["sampling_backend"] == "flashinfer": diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 000f8ecdc..8e55fb1d7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -53,6 +53,7 @@ global_server_args_dict = { "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, "disable_mla": ServerArgs.disable_mla, "torchao_config": ServerArgs.torchao_config, + "disable_nan_detection": ServerArgs.disable_nan_detection, } diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index cdf3a77c9..d3ff3cd1d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -245,10 +245,10 @@ class CudaGraphRunner: self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs] = forward_batch.input_ids - 
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices - self.seq_lens[:raw_bs] = forward_batch.seq_lens - self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc + self.input_ids[:raw_bs].copy_(forward_batch.input_ids) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a60ac1c70..a8e64205b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -137,6 +137,7 @@ class ModelRunner: "disable_mla": server_args.disable_mla, "torchao_config": server_args.torchao_config, "disable_penalizer": server_args.disable_penalizer, + "disable_nan_detection": server_args.disable_nan_detection, } ) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 37dedcd17..457b100e9 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -20,6 +20,9 @@ class SamplingBatchInfo: top_ks: torch.Tensor min_ps: torch.Tensor + # All requests use greedy sampling + is_all_greedy: bool + # Dispatch in CUDA graph need_min_p_sampling: bool @@ -73,6 +76,7 @@ class SamplingBatchInfo: top_ks=top_ks, min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), + is_all_greedy=top_ks.max().item() <= 1, vocab_size=vocab_size, device=batch.input_ids.device, ) @@ -204,6 +208,7 @@ class SamplingBatchInfo: other_val = getattr(other, item, None) setattr(self, item, torch.concat([self_val, other_val])) + self.is_all_greedy = self.is_all_greedy and other.is_all_greedy self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, 
other.logit_bias, len(self), len(other), self.device ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 10f63e697..722e30f6b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -114,6 +114,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False disable_mla: bool = False disable_penalizer: bool = False + disable_nan_detection: bool = False enable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False @@ -577,7 +578,12 @@ class ServerArgs: parser.add_argument( "--disable-penalizer", action="store_true", - help="Disable the logit penalizer (e.g., frequency and repetition penalty).", + help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.", + ) + parser.add_argument( + "--disable-nan-detection", + action="store_true", + help="Disable the NaN detection for better performance.", ) parser.add_argument( "--enable-overlap-schedule",