From 9b8ebb2798e2cdfe9659de230aea79baf25c0a34 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 9 Oct 2025 16:46:15 -0700 Subject: [PATCH] move more files under srt/utils (#11285) --- .github/workflows/pr-test.yml | 4 +- python/sglang/srt/disaggregation/decode.py | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/grpc_server.py | 2 +- .../srt/layers/attention/nsa/nsa_indexer.py | 5 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 +- python/sglang/srt/layers/moe/router.py | 66 ++++++++++++++----- .../srt/layers/quantization/fp8_utils.py | 4 +- python/sglang/srt/lora/lora_registry.py | 2 +- .../srt/managers/data_parallel_controller.py | 2 +- python/sglang/srt/managers/io_struct.py | 3 +- python/sglang/srt/managers/scheduler.py | 2 +- .../sglang/srt/managers/tokenizer_manager.py | 2 +- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 12 ++-- python/sglang/srt/models/kimi_vl.py | 8 +-- python/sglang/srt/models/kimi_vl_moonvit.py | 4 +- python/sglang/srt/server_args.py | 7 ++ .../srt/tokenizer/tiktoken_tokenizer.py | 4 +- python/sglang/srt/utils/__init__.py | 2 +- python/sglang/srt/{ => utils}/aio_rwlock.py | 0 python/sglang/srt/{ => utils}/bench_utils.py | 0 python/sglang/srt/utils/common.py | 2 +- .../srt/{ => utils}/host_shared_memory.py | 0 python/sglang/srt/{ => utils}/offloader.py | 8 +-- .../{ => utils}/torch_memory_saver_adapter.py | 0 .../bench_per_token_group_quant_8bit.py | 2 +- .../benchmark/bench_rotary_embedding.py | 2 +- 28 files changed, 96 insertions(+), 55 deletions(-) rename python/sglang/srt/{ => utils}/aio_rwlock.py (100%) rename python/sglang/srt/{ => utils}/bench_utils.py (100%) rename python/sglang/srt/{ => utils}/host_shared_memory.py (100%) rename python/sglang/srt/{ => utils}/offloader.py (99%) rename python/sglang/srt/{ => utils}/torch_memory_saver_adapter.py (100%) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 
eb2e754f0..4f31d34bb 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -292,7 +292,7 @@ jobs: needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels] if: always() && !failure() && !cancelled() && ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true')) - runs-on: 4-gpu-runner + runs-on: 4-gpu-h100 strategy: fail-fast: false matrix: @@ -614,7 +614,7 @@ jobs: needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels] if: always() && !failure() && !cancelled() && ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true')) - runs-on: 4-gpu-runner + runs-on: 4-gpu-h100 steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index fa3b2bc1f..7fb2365ca 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -51,8 +51,8 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import get_int_env_var, require_mlp_sync +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a9f88dbf8..d754f1f95 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -68,7 +68,6 @@ from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import 
TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, assert_pkg_version, @@ -82,6 +81,7 @@ from sglang.srt.utils import ( set_prometheus_multiproc_dir, set_ulimit, ) +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.version import __version__ logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py index c3c813a3a..f9c7c72fd 100644 --- a/python/sglang/srt/entrypoints/grpc_server.py +++ b/python/sglang/srt/entrypoints/grpc_server.py @@ -35,8 +35,8 @@ from sglang.srt.managers.io_struct import ( from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 2bc6771ab..b37e5ffac 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -12,7 +12,10 @@ from sglang.srt.custom_op import CustomOp from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu if is_cuda(): - import deep_gemm + try: + import deep_gemm + except ImportError as e: + deep_gemm = e from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER from sglang.srt.layers.dp_attention import get_attention_tp_group diff --git 
a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 30e3faab3..bc7251989 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -30,9 +30,9 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptNvFp4FusedMoEMethod, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.offloader import get_offloader from sglang.srt.single_batch_overlap import DownGemmOverlapArgs from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu +from sglang.srt.utils.offloader import get_offloader if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import ( diff --git a/python/sglang/srt/layers/moe/router.py b/python/sglang/srt/layers/moe/router.py index 0138dcdad..5c0b86e58 100644 --- a/python/sglang/srt/layers/moe/router.py +++ b/python/sglang/srt/layers/moe/router.py @@ -11,7 +11,7 @@ _is_hip = is_hip() @triton.jit -def fused_moe_router_kernel( +def fused_moe_router_cudacore_kernel( input_ptr, # input (bs, hidden_dim) moe_router_weight_ptr, # input (num_experts, hidden_dim) topk_weights_ptr, # output (bs, topk) @@ -114,7 +114,7 @@ def fused_moe_router_kernel( # assert not moe_renormalize, "moe weight renormalization not implemented" -def fused_moe_router_impl( +def fused_moe_router_cudacore( x: torch.Tensor, router_weight: torch.Tensor, topk: int, @@ -138,7 +138,7 @@ def fused_moe_router_impl( ), } - fused_moe_router_kernel[(bs,)]( + fused_moe_router_cudacore_kernel[(bs,)]( x, router_weight, topk_weights, @@ -157,7 +157,7 @@ def fused_moe_router_impl( @triton.jit -def fused_moe_router_large_bs_kernel( +def fused_moe_router_tensorcore_kernel( a_ptr, # input (bs, hidden_dim) b_ptr, # input (num_experts, hidden_dim) topk_weights_ptr, # output (bs, topk) @@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel( topk: tl.constexpr, # only support topk <= 2 moe_softcapping: tl.constexpr, 
moe_renormalize: tl.constexpr, # not supported + correction_bias_ptr, + is_correction_bias: tl.constexpr, K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, stride_am: tl.constexpr, stride_bn: tl.constexpr, + dp_attn_workaround_flag: tl.constexpr, ): # 1. get block id @@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel( exped = tl.exp(2 * logits_scaled) logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping + # Add bias after softcapping + if is_correction_bias: + bias = tl.load( + correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :], + mask=expert_mask.T, + other=0.0, + ) + logits_softcapped = logits_softcapped + bias + + if dp_attn_workaround_flag: + logits_softcapped = tl.where( + logits_softcapped != logits_softcapped, -1e9, logits_softcapped + ) + # 5. top1 arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :] cond_top1 = arange_block_size_n < num_experts @@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel( ) -def fused_moe_router_large_bs_impl( +def fused_moe_router_tensorcore( x: torch.Tensor, router_weight: torch.Tensor, topk: int, @@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl( BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + correction_bias: Optional[torch.Tensor] = None, ): assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1] bs, hidden_dim = x.shape @@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl( topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device) topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device) + is_correction_bias = correction_bias is not None grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),) - fused_moe_router_large_bs_kernel[grid]( + # TODO(ch-wan): temporary workaround for dp attention. We should support masked + # router to skip padded tokens. 
+ from sglang.srt.layers.dp_attention import is_dp_attention_enabled + + dp_attn_workaround_flag = is_dp_attention_enabled() + + fused_moe_router_tensorcore_kernel[grid]( a_ptr=x, b_ptr=router_weight, topk_weights_ptr=topk_weights, @@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl( moe_softcapping=moe_softcapping, moe_renormalize=False, K=hidden_dim, + correction_bias_ptr=correction_bias, + is_correction_bias=is_correction_bias, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, stride_am=hidden_dim, stride_bn=hidden_dim, + dp_attn_workaround_flag=dp_attn_workaround_flag, ) return topk_weights, topk_ids @@ -316,6 +344,7 @@ def fused_moe_router_shim( topk, renormalize, correction_bias: Optional[torch.Tensor] = None, + enable_deterministic_inference: bool = False, ): assert not renormalize assert ( @@ -324,16 +353,22 @@ def fused_moe_router_shim( ) bs, hidden_dim = hidden_states.shape num_experts = gating_output.shape[0] + BLOCK_SIZE_M = 32 - BLOCK_SIZE_N = 16 - BLOCK_SIZE_K = 256 + + BLOCK_SIZE_N = max(num_experts, 16) + BLOCK_SIZE_K = ( + 256 if num_experts < 256 else 64 + ) # if experts are large, need to use smaller k block or shared memory OOM + if ( - bs >= 512 - and topk <= 2 - and num_experts <= BLOCK_SIZE_N + (bs >= 512 or num_experts > 8) and hidden_dim % BLOCK_SIZE_K == 0 + # we keep using single kernel to avoid non-deterministic behavior + and not enable_deterministic_inference ): - return fused_moe_router_large_bs_impl( + # if large batch size or large expert, use kernel that uses tensorcore in matmul + return fused_moe_router_tensorcore( x=hidden_states, router_weight=gating_output, topk=topk, @@ -341,9 +376,11 @@ def fused_moe_router_shim( BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + correction_bias=correction_bias, ) else: - return fused_moe_router_impl( + # if smaller, use kernel that does not use tensorcore in matmul + return fused_moe_router_cudacore( x=hidden_states, 
router_weight=gating_output, topk=topk, @@ -380,11 +417,10 @@ class FusedMoeRouter: renormalize=False, ) - def forward_vllm( + def forward_torch( self, x: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # g, _ = self.router_linear.forward(x) g = x.float() @ self.router_linear.weight.T.float() g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3066842f0..fc50c1f54 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple import torch -from sglang.srt import offloader from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil -from sglang.srt.utils import is_sm100_supported +from sglang.srt.utils import is_sm100_supported, offloader try: from vllm import _custom_ops as ops @@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ) from sglang.srt.utils import ( align, - ceil_div, get_bool_env_var, get_cuda_version, get_device_capability, diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py index 51d2b0e66..5b4b538ac 100644 --- a/python/sglang/srt/lora/lora_registry.py +++ b/python/sglang/srt/lora/lora_registry.py @@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields from typing import Dict, List, Optional, Union from uuid import uuid4 -from sglang.srt.aio_rwlock import RWLock from sglang.srt.utils import ConcurrentCounter +from sglang.srt.utils.aio_rwlock import RWLock @dataclass(frozen=True) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 70f41e5ce..ee6b0f0d2 100644 --- 
a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -37,13 +37,13 @@ from sglang.srt.managers.io_struct import ( from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( bind_port, configure_logger, get_zmq_socket, kill_itself_when_parent_died, ) +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 32733c277..e6dfa35c4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -36,7 +36,6 @@ else: Image = Any -# Parameters for a session @dataclass class BaseReq(ABC): rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) @@ -60,9 +59,11 @@ class BaseBatchReq(ABC): return self.rids +# Parameters for a session @dataclass class SessionParams: id: Optional[str] = None + rid: Optional[str] = None offset: Optional[int] = None replace: Optional[bool] = None drop_previous_output: Optional[bool] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 655b39a4e..f279092eb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -156,7 +156,6 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.tracing.trace import ( 
process_tracing_init, trace_set_proc_propagate_context, @@ -192,6 +191,7 @@ from sglang.srt.utils.hf_transformers_utils import ( get_tokenizer, get_tokenizer_from_processor, ) +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9d6bf9fc5..be9e5699a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -40,7 +40,6 @@ import zmq import zmq.asyncio from fastapi import BackgroundTasks -from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.lora.lora_registry import LoRARegistry @@ -94,6 +93,7 @@ from sglang.srt.utils import ( get_zmq_socket, kill_process_tree, ) +from sglang.srt.utils.aio_rwlock import RWLock from sglang.srt.utils.hf_transformers_utils import ( get_processor, get_tokenizer, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index eb863f4c8..f948ed636 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter """ Memory pool. 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a74f85d71..5b1b9d22a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -117,15 +117,9 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( ) from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.offloader import ( - create_offloader_from_server_args, - get_offloader, - set_offloader, -) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, cpu_has_amx_support, @@ -148,7 +142,13 @@ from sglang.srt.utils import ( set_cuda_arch, slow_rank_detector, ) +from sglang.srt.utils.offloader import ( + create_offloader_from_server_args, + get_offloader, + set_offloader, +) from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.weight_sync.tensor_bucket import ( FlattenedTensorBucket, FlattenedTensorMetadata, diff --git a/python/sglang/srt/models/kimi_vl.py b/python/sglang/srt/models/kimi_vl.py index 68ed47b2e..03ce44653 100644 --- a/python/sglang/srt/models/kimi_vl.py +++ b/python/sglang/srt/models/kimi_vl.py @@ -43,10 +43,8 @@ import copy import logging -import math -from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -56,10 +54,6 @@ from sglang.srt.configs import KimiVLConfig from sglang.srt.configs.deepseekvl2 import DeepseekV2Config from 
sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig -from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig diff --git a/python/sglang/srt/models/kimi_vl_moonvit.py b/python/sglang/srt/models/kimi_vl_moonvit.py index f86d5c0e8..286e85772 100644 --- a/python/sglang/srt/models/kimi_vl_moonvit.py +++ b/python/sglang/srt/models/kimi_vl_moonvit.py @@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from transformers.activations import ACT2FN, GELUTanh +from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel try: @@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel): _supports_sdpa = True def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + from transformers.activations import GELUTanh + super().__init__(config, *inputs, **kwargs) config = deepcopy(config) self.merge_kernel_size = config.merge_kernel_size diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 408d18dda..3b205ef8f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -238,6 +238,7 @@ class ServerArgs: log_requests: bool = False log_requests_level: int = 2 crash_dump_folder: Optional[str] = None + crash_on_nan: bool = False show_time_cost: bool = False enable_metrics: bool = False enable_metrics_for_all_schedulers: bool = False @@ -1733,6 +1734,12 @@ class ServerArgs: default=ServerArgs.crash_dump_folder, help="Folder path to dump requests from the last 5 min before a crash (if any). 
If not specified, crash dumping is disabled.", ) + parser.add_argument( + "--crash-on-nan", + action="store_true", + default=ServerArgs.crash_on_nan, + help="Crash the server on nan logprobs.", + ) parser.add_argument( "--show-time-cost", action="store_true", diff --git a/python/sglang/srt/tokenizer/tiktoken_tokenizer.py b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py index 98df443e5..c1f2a91b0 100644 --- a/python/sglang/srt/tokenizer/tiktoken_tokenizer.py +++ b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py @@ -133,9 +133,9 @@ class TiktokenTokenizer: ) return self.encode(ret) if tokenize else ret - def __call__(self, text, **kwargs): + def __call__(self, text: List[str], **kwargs): return { - "input_ids": self.encode(text), + "input_ids": [self.encode(x) for x in text], } def init_xgrammar(self): diff --git a/python/sglang/srt/utils/__init__.py b/python/sglang/srt/utils/__init__.py index 5fb724e1a..40f7bdfb4 100644 --- a/python/sglang/srt/utils/__init__.py +++ b/python/sglang/srt/utils/__init__.py @@ -1,2 +1,2 @@ # Temporarily do this to avoid changing all imports in the repo -from .common import * +from sglang.srt.utils.common import * diff --git a/python/sglang/srt/aio_rwlock.py b/python/sglang/srt/utils/aio_rwlock.py similarity index 100% rename from python/sglang/srt/aio_rwlock.py rename to python/sglang/srt/utils/aio_rwlock.py diff --git a/python/sglang/srt/bench_utils.py b/python/sglang/srt/utils/bench_utils.py similarity index 100% rename from python/sglang/srt/bench_utils.py rename to python/sglang/srt/utils/bench_utils.py diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0e6828c7e..7ac6b20c5 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -487,7 +487,7 @@ def make_layers( # circula imports from sglang.srt.distributed import get_pp_indices from sglang.srt.layers.utils import PPMissingLayer - from sglang.srt.offloader import get_offloader + from sglang.srt.utils.offloader import 
get_offloader assert not pp_size or num_hidden_layers >= pp_size start_layer, end_layer = ( diff --git a/python/sglang/srt/host_shared_memory.py b/python/sglang/srt/utils/host_shared_memory.py similarity index 100% rename from python/sglang/srt/host_shared_memory.py rename to python/sglang/srt/utils/host_shared_memory.py diff --git a/python/sglang/srt/offloader.py b/python/sglang/srt/utils/offloader.py similarity index 99% rename from python/sglang/srt/offloader.py rename to python/sglang/srt/utils/offloader.py index 0adddf5a6..58ab19c1f 100644 --- a/python/sglang/srt/offloader.py +++ b/python/sglang/srt/utils/offloader.py @@ -11,14 +11,14 @@ from sglang.srt.distributed.naive_distributed import ( get_naive_distributed, set_naive_distributed, ) -from sglang.srt.host_shared_memory import ( +from sglang.srt.layers.parameter import ModelWeightParameter +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available +from sglang.srt.utils.host_shared_memory import ( HostSharedMemoryManager, get_host_shared_memory_manager, set_host_shared_memory_manager, ) -from sglang.srt.layers.parameter import ModelWeightParameter -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/torch_memory_saver_adapter.py b/python/sglang/srt/utils/torch_memory_saver_adapter.py similarity index 100% rename from python/sglang/srt/torch_memory_saver_adapter.py rename to python/sglang/srt/utils/torch_memory_saver_adapter.py diff --git a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py index 558b7486e..1e8c985d5 100644 --- a/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py +++ b/sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py @@ -7,7 +7,6 @@ from pathlib import Path import torch import triton -from 
sglang.srt.bench_utils import bench_kineto from sglang.srt.layers.quantization.fp8_kernel import ( create_per_token_group_quant_fp8_output_scale, ) @@ -16,6 +15,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ) from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit from sglang.srt.utils import is_hip +from sglang.srt.utils.bench_utils import bench_kineto # CI environment detection IS_CI = ( diff --git a/sgl-kernel/benchmark/bench_rotary_embedding.py b/sgl-kernel/benchmark/bench_rotary_embedding.py index 418fcd7dd..0cab8e653 100644 --- a/sgl-kernel/benchmark/bench_rotary_embedding.py +++ b/sgl-kernel/benchmark/bench_rotary_embedding.py @@ -11,7 +11,7 @@ from sgl_kernel.testing.rotary_embedding import ( create_inputs, ) -from sglang.srt.bench_utils import bench_kineto +from sglang.srt.utils.bench_utils import bench_kineto # CI environment detection IS_CI = (