diff --git a/python/sglang/srt/model_executor/compilation/backend.py b/python/sglang/srt/compilation/backend.py similarity index 96% rename from python/sglang/srt/model_executor/compilation/backend.py rename to python/sglang/srt/compilation/backend.py index 031e40fd4..88171a124 100644 --- a/python/sglang/srt/model_executor/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -15,15 +15,11 @@ import torch import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher -from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig -from sglang.srt.model_executor.compilation.compilation_counter import ( - compilation_counter, -) -from sglang.srt.model_executor.compilation.compiler_interface import InductorAdaptor -from sglang.srt.model_executor.compilation.cuda_piecewise_backend import ( - CUDAPiecewiseBackend, -) -from sglang.srt.model_executor.compilation.pass_manager import PostGradPassManager +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.compiler_interface import InductorAdaptor +from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend +from sglang.srt.compilation.pass_manager import PostGradPassManager logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/compilation_config.py rename to python/sglang/srt/compilation/compilation_config.py diff --git a/python/sglang/srt/model_executor/compilation/compilation_counter.py b/python/sglang/srt/compilation/compilation_counter.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/compilation_counter.py rename to python/sglang/srt/compilation/compilation_counter.py diff --git a/python/sglang/srt/model_executor/compilation/compile.py b/python/sglang/srt/compilation/compile.py similarity index 97% rename from python/sglang/srt/model_executor/compilation/compile.py rename to python/sglang/srt/compilation/compile.py index dee7f0169..a77f5aee7 100644 --- a/python/sglang/srt/model_executor/compilation/compile.py +++ b/python/sglang/srt/compilation/compile.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Optional, Union import torch -from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compilation_config import CompilationConfig logger = logging.getLogger(__name__) @@ -134,7 +134,7 @@ def install_torch_compiled( dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd) if backend_factory is None: - from sglang.srt.model_executor.compilation.backend import SGLangBackend + from sglang.srt.compilation.backend import SGLangBackend backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)( gm, ex diff --git a/python/sglang/srt/model_executor/compilation/compiler_interface.py b/python/sglang/srt/compilation/compiler_interface.py similarity index 99% rename from python/sglang/srt/model_executor/compilation/compiler_interface.py rename to python/sglang/srt/compilation/compiler_interface.py index 016703022..0c58a0dea 100644 --- a/python/sglang/srt/model_executor/compilation/compiler_interface.py +++ b/python/sglang/srt/compilation/compiler_interface.py @@ -12,10 +12,8 @@ import torch import torch._inductor.compile_fx import torch.fx as fx -from sglang.srt.model_executor.compilation.compilation_counter import ( - compilation_counter, -) -from sglang.srt.model_executor.compilation.inductor_pass import pass_context +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.inductor_pass import pass_context class CompilerInterface: diff --git a/python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py similarity index 97% rename from python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py rename to python/sglang/srt/compilation/cuda_piecewise_backend.py index 22f35b3bc..9f4b8cc8e 100644 --- a/python/sglang/srt/model_executor/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -9,11 +9,9 @@ from unittest.mock import patch import torch import torch.fx as fx -import sglang.srt.model_executor.compilation.weak_ref_tensor_jit -from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig -from sglang.srt.model_executor.compilation.compilation_counter import ( - compilation_counter, -) +import sglang.srt.compilation.weak_ref_tensor_jit +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compilation_counter import compilation_counter logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/compilation/fix_functionalization.py b/python/sglang/srt/compilation/fix_functionalization.py similarity index 97% rename from python/sglang/srt/model_executor/compilation/fix_functionalization.py rename to python/sglang/srt/compilation/fix_functionalization.py index bd18173ae..8673e3576 100644 --- a/python/sglang/srt/model_executor/compilation/fix_functionalization.py +++ b/python/sglang/srt/compilation/fix_functionalization.py @@ -8,8 +8,8 @@ from typing import Optional, Union import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized -from sglang.srt.model_executor.compilation.fx_utils import is_func -from sglang.srt.model_executor.compilation.inductor_pass import SGLangInductorPass +from sglang.srt.compilation.fx_utils import is_func +from sglang.srt.compilation.inductor_pass import SGLangInductorPass logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/compilation/fx_utils.py b/python/sglang/srt/compilation/fx_utils.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/fx_utils.py rename to python/sglang/srt/compilation/fx_utils.py diff --git a/python/sglang/srt/model_executor/compilation/inductor_pass.py b/python/sglang/srt/compilation/inductor_pass.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/inductor_pass.py rename to python/sglang/srt/compilation/inductor_pass.py diff --git a/python/sglang/srt/model_executor/compilation/pass_manager.py b/python/sglang/srt/compilation/pass_manager.py similarity index 92% rename from python/sglang/srt/model_executor/compilation/pass_manager.py rename to python/sglang/srt/compilation/pass_manager.py index bc06a49ea..9173976f1 100644 --- a/python/sglang/srt/model_executor/compilation/pass_manager.py +++ b/python/sglang/srt/compilation/pass_manager.py @@ -4,10 +4,8 @@ import logging from torch import fx as fx -from sglang.srt.model_executor.compilation.fix_functionalization import ( - FixFunctionalizationPass, -) -from sglang.srt.model_executor.compilation.inductor_pass import ( +from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass +from sglang.srt.compilation.inductor_pass import ( CustomGraphPass, InductorPass, SGLangInductorPass, diff --git a/python/sglang/srt/model_executor/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/piecewise_context_manager.py rename to python/sglang/srt/compilation/piecewise_context_manager.py diff --git a/python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp b/python/sglang/srt/compilation/weak_ref_tensor.cpp similarity index 100% rename from python/sglang/srt/model_executor/compilation/weak_ref_tensor.cpp rename to python/sglang/srt/compilation/weak_ref_tensor.cpp diff --git a/python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py b/python/sglang/srt/compilation/weak_ref_tensor_jit.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/weak_ref_tensor_jit.py rename to python/sglang/srt/compilation/weak_ref_tensor_jit.py diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 775e98ca0..9a7ddd825 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -337,7 +337,6 @@ class GroupCoordinator: else: ca_max_size = 8 * 1024 * 1024 try: - # print(f"ca_max_size: {ca_max_size}") self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index bd5866137..9d6536520 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -20,15 +20,13 @@ from typing import TYPE_CHECKING, Optional import torch from torch import nn +from sglang.srt.compilation.piecewise_context_manager import get_forward_context +from sglang.srt.utils import direct_register_custom_op + if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.compilation.piecewise_context_manager import ( - get_forward_context, -) -from sglang.srt.utils import direct_register_custom_op - class AttentionType(Enum): """ @@ -112,7 +110,7 @@ class RadixAttention(nn.Module): k = k.view(-1, self.tp_k_head_num, self.v_head_dim) if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: - output = torch.zeros_like(q) + output = torch.empty_like(q) torch.ops.sglang.unified_attention_with_output( q, k, v, output, save_kv_cache, self.layer_id ) diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index a5f3b1d54..e4f9002b7 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -24,6 +24,9 @@ from typing import TYPE_CHECKING, Union import torch import tqdm +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compile import install_torch_compiled, set_compiled +from sglang.srt.compilation.piecewise_context_manager import set_forward_context from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.device_communicators.pynccl_allocator import ( @@ -38,14 +41,6 @@ from sglang.srt.layers.dp_attention import ( ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.model_executor.compilation.compilation_config import CompilationConfig -from sglang.srt.model_executor.compilation.compile import ( - install_torch_compiled, - set_compiled, -) -from sglang.srt.model_executor.compilation.piecewise_context_manager import ( - set_forward_context, -) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch,