diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 66d0b5c53..797a23040 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -41,6 +41,9 @@ if TYPE_CHECKING: def _to_torch(model: torch.nn.Module, reverse: bool = False): for sub in model._modules.values(): if isinstance(sub, CustomOp): + # NOTE: FusedMoE torch native implementation is not efficient + if "FusedMoE" in sub.__class__.__name__: + continue if reverse: sub._forward_method = sub.forward_cuda setattr(sub, "is_torch_compile", False) @@ -105,7 +108,15 @@ class CudaGraphRunner: self.capture_bs = list(range(1, 32)) + [64, 128] else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else [] + self.compile_bs = ( + [ + bs + for bs in self.capture_bs + if bs <= self.model_runner.server_args.max_torch_compile_bs + ] + if self.use_torch_compile + else [] + ) # Common inputs self.max_bs = max(self.capture_bs) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7bc106abc..30fa465d8 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module): ) self.logits_processor = LogitsProcessor(config) + @torch.no_grad() def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 35b99b6af..33536ec16 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -110,6 +110,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False + max_torch_compile_bs: int = 32 torchao_config: str = "" enable_p2p_check: bool = False enable_mla: bool = False @@ -523,6 +524,12 @@ class ServerArgs: 
action="store_true", help="Optimize the model with torch.compile. Experimental feature.", ) + parser.add_argument( + "--max-torch-compile-bs", + type=int, + default=ServerArgs.max_torch_compile_bs, + help="Set the maximum batch size when using torch compile.", + ) parser.add_argument( "--torchao-config", type=str,