diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1e09712ab..70a145b5a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,7 +27,9 @@ repos:
     rev: v0.11.7
     hooks:
       - id: ruff
-        args: [--select=F401,F821, --fixable=F401]
+        args:
+          - --select=F401,F821
+          - --fix
         files: ^(benchmark/|docs/|examples/|python/sglang/)
         exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
 - repo: https://github.com/psf/black
diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
index 78d81499e..734a0314b 100644
--- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
+++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
@@ -167,6 +167,7 @@ class MiniMaxText01LightningAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         slope_rate: Optional[torch.Tensor] = None,
+        do_eval: bool = False,
         **kwargs,
     ):
         if (not self.training) and (not do_eval):
diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py
index 3bf9054bd..9f11ac904 100644
--- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py
+++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py
@@ -1,4 +1,5 @@
 import itertools
+import logging
 import math
 import os
 from typing import Optional, Tuple
@@ -10,6 +11,8 @@ import triton
 import triton.language as tl
 from einops import rearrange
 
+logger = logging.getLogger(__name__)
+
 
 # Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
 @triton.jit
@@ -302,6 +305,7 @@ class MiniMaxText01LightningAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         slope_rate: Optional[torch.Tensor] = None,
+        do_eval: bool = False,
         **kwargs,
     ):
         if (not self.training) and (not do_eval):
diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py
index 920286e33..8f6640536 100644
--- a/python/sglang/bench_one_batch_server.py
+++ b/python/sglang/bench_one_batch_server.py
@@ -16,6 +16,7 @@ import argparse
 import dataclasses
 import itertools
 import json
+import logging
 import multiprocessing
 import os
 import random
@@ -39,6 +40,8 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary
 
+logger = logging.getLogger(__name__)
+
 
 class ProfileLinks(BaseModel):
     """Pydantic model for profile trace links."""
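Note: the pre-commit change above is what drives the rest of this patch. `--fixable=F401` only restricts which rules *may* be auto-fixed; without `--fix`, ruff never rewrites anything. With `--fix` enabled, F401 (unused import) is cleaned up automatically, while F821 (undefined name) has no autofix, so every F821 hit, like the missing `do_eval` parameter in the benchmarks above, needs a hand-written change. A minimal sketch of what each rule flags (hypothetical module, not from this repo):

```python
import os  # F401: imported but unused -- `ruff --fix` deletes this line


def forward(x, **kwargs):
    # F821: `do_eval` is never defined in this scope, so calling this
    # raises NameError at runtime; ruff can only report it, not fix it.
    if not do_eval:
        return x
```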
diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py
index 5d0fd19c1..1a66c83ae 100644
--- a/python/sglang/srt/disaggregation/common/conn.py
+++ b/python/sglang/srt/disaggregation/common/conn.py
@@ -77,8 +77,8 @@ class CommonKVManager(BaseKVManager):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._register_to_bootstrap()
-            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
-            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
+            self.transfer_infos = {}
+            self.decode_kv_args_table = {}
             self.pp_group = get_pp_group()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index af2c75d83..2326a647a 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -9,7 +9,7 @@ import struct
 import threading
 import time
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import numpy.typing as npt
diff --git a/python/sglang/srt/entrypoints/openai/serving_responses.py b/python/sglang/srt/entrypoints/openai/serving_responses.py
index 52d0b1104..b2f22c1b5 100644
--- a/python/sglang/srt/entrypoints/openai/serving_responses.py
+++ b/python/sglang/srt/entrypoints/openai/serving_responses.py
@@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
 
 import jinja2
 import openai.types.responses as openai_responses_types
+import orjson
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 from openai.types.responses import (
@@ -1063,7 +1064,7 @@ class OpenAIServingResponses(OpenAIServingChat):
         ):
             function_name = previous_item.recipient[len("browser.") :]
             action = None
-            parsed_args = ororjson.loads(previous_item.content[0].text)
+            parsed_args = orjson.loads(previous_item.content[0].text)
             if function_name == "search":
                 action = openai_responses_types.response_function_web_search.ActionSearch(
                     type="search",
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 33ff82ca6..587b06af0 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -194,7 +194,7 @@ class FlashInferAttnBackend(AttentionBackend):
         )
         if init_new_workspace:
             self.workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index ad9cbfd44..041263867 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -38,6 +38,9 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
@@ -66,7 +69,7 @@ global_workspace_buffer = None
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
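Note: the `flashinfer_mla_backend.py` change above (and `memory_pool.py` further down) satisfies F821 by importing the annotated name under `typing.TYPE_CHECKING`, so the quoted forward reference can become a real name without paying for a runtime import or creating an import cycle. A minimal sketch of the pattern, with hypothetical module names, assuming the module enables postponed annotation evaluation (fpgemm_fp8.py below shows its `from __future__ import annotations` line explicitly):

```python
from __future__ import annotations  # annotations stay unevaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static analyzers (ruff, mypy); never executed at
    # runtime, so a circular or heavyweight import costs nothing here.
    from mypkg.backend import MlaAttnBackend


class ChunkKVRunner:
    # The unquoted annotation is safe: a type checker resolves it, the
    # interpreter never does.
    def __init__(self, attn_backend: MlaAttnBackend) -> None:
        self.attn_backend = attn_backend
```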
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
index 5d39b8bbc..d4be7ae05 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -13,7 +13,8 @@ from triton_kernels.matmul_ogs import (
     PrecisionConfig,
     matmul_ogs,
 )
-from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics import InFlexData, MicroscalingCtx
+from triton_kernels.quantization import downcast_to_mxfp
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn
 
@@ -119,14 +120,14 @@ def triton_kernel_fused_experts(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -143,7 +144,7 @@ def triton_kernel_fused_experts(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     M, K = hidden_states.shape
     E, _, N = w1.shape
@@ -264,14 +265,14 @@ def triton_kernel_fused_experts_with_bias(
     gemm1_alpha: Optional[float] = None,
     gemm1_clamp_limit: Optional[float] = None,
 ) -> torch.Tensor:
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -290,7 +291,7 @@ def triton_kernel_fused_experts_with_bias(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     E, _, _ = w1.shape
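Note: the `== None` / `== False` to `is None` / `is False` rewrites above are more than style (E711/E712): several of these arguments are optional array-likes, where `==` is elementwise and the assert misbehaves. A quick numpy illustration of the failure mode:

```python
import numpy as np

w1_scale = np.ones(4)  # stand-in for an optional scale tensor

print(w1_scale == None)  # noqa: E711  -> elementwise array of False values

try:
    # The old form truth-tests an elementwise result, which is ambiguous.
    assert w1_scale == None, "w1_scale is not supported"  # noqa: E711
except ValueError as err:
    print(err)  # "The truth value of an array with more than one element..."

assert w1_scale is not None  # identity check: always a plain bool
```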
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index fde541e19..274b6184c 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -44,6 +44,13 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
+        CompressedTensors24,
+    )
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
+        W4A16SPARSE24_SUPPORTED_BITS,
+        CompressedTensorsW4A16Sparse24,
+    )
     from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
         WNA16_SUPPORTED_BITS,
         CompressedTensorsWNA16,
diff --git a/python/sglang/srt/layers/quantization/fpgemm_fp8.py b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
index 0c7030101..352d74628 100644
--- a/python/sglang/srt/layers/quantization/fpgemm_fp8.py
+++ b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from torch.nn import Module
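Note: the new vllm scheme imports above live inside the module's existing `try:` block, so environments without (or with an older) vllm degrade gracefully instead of failing at import time. For F821 to stay green, the names must still be bound on the failure path; the repo's actual `except` arm is not shown in this hunk, so the sketch below assumes a placeholder-binding fallback with a made-up dependency:

```python
try:
    from fancy_quant.schemes import FancyScheme  # hypothetical package

    HAS_FANCY_QUANT = True
except ImportError:
    FancyScheme = None  # keeps the name bound for linters and callers
    HAS_FANCY_QUANT = False


def build_scheme():
    # Fail at use time with a clear message, not at import time.
    if not HAS_FANCY_QUANT:
        raise NotImplementedError("fancy_quant is not installed")
    return FancyScheme()
```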
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index c468269f3..9d3b9eb5e 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -47,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
+    from sglang.srt.managers.schedule_batch import Req
+
 
 logger = logging.getLogger(__name__)
 
@@ -341,7 +343,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
-        self, need_size: int, reqs: Optional[List["Req"]] = None
+        self, need_size: int, reqs: Optional[List[Req]] = None
    ) -> Optional[List[int]]:
         select_index = super().alloc(need_size)
         if select_index == None:
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 900a37074..d4585bbb3 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -110,6 +110,9 @@ def convert_bin_to_safetensor_file(
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
+
+    from safetensors.torch import save_file
+
     save_file(loaded, sf_filename, metadata={"format": "pt"})
 
     # check file size
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 5d4dd5325..27b7bfbd3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
 from transformers import PretrainedConfig
 
@@ -3499,7 +3500,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         # temporarily only support DeepSeek V3/R1
         weight_block_size = [128, 128]
 
-        for layer_id in trange(
+        for layer_id in tqdm.trange(
             self.config.num_hidden_layers + int(is_nextn),
             desc="quant attn to fp8 ue8m0",
         ):
diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py
index 953a86c73..d7b0e9618 100644
--- a/python/sglang/srt/models/glm4v.py
+++ b/python/sglang/srt/models/glm4v.py
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
diff --git a/python/sglang/srt/models/opt.py b/python/sglang/srt/models/opt.py
index bf989f6e8..0b2f0edb7 100644
--- a/python/sglang/srt/models/opt.py
+++ b/python/sglang/srt/models/opt.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Inference-only OPT model compatible with HuggingFace weights."""
 
+import logging
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -46,6 +47,9 @@ from sglang.srt.model_loader.weight_utils import (
     kv_cache_scales_loader,
 )
 from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
 
 
 def get_activation(name="relu"):
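Note: the `utils/common.py` hunk below is the clearest payoff of enabling F821: the `except RuntimeError as error:` clause bound `error`, but the body inspected `e`, so the handler raised `NameError` at exactly the moment it was supposed to swallow a duplicate-registration error. A standalone reproduction sketch (simplified registry, not the real torch.library API):

```python
# Simplified stand-in for duplicate custom-op registration; this only
# reproduces the control flow of the handler, not the library calls.
_registry: set[str] = set()


def register(op_name: str) -> None:
    try:
        if op_name in _registry:
            raise RuntimeError(
                f"Tried to register an operator ({op_name}) multiple times"
            )
        _registry.add(op_name)
    except RuntimeError as error:
        # With the old `str(e)`, this line itself raised NameError.
        if "Tried to register an operator" in str(error) and (
            "multiple times" in str(error)
        ):
            pass  # silently ignore duplicate registrations
        else:
            raise


register("my_op")
register("my_op")  # duplicate: now correctly ignored instead of crashing
```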
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 14637d672..547150059 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -42,6 +42,7 @@ import tempfile
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 from collections import OrderedDict, defaultdict
@@ -55,6 +56,7 @@ from json import JSONDecodeError
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -62,6 +64,7 @@
     List,
     Optional,
     Protocol,
+    Sequence,
     Set,
     Tuple,
     TypeVar,
@@ -91,6 +94,9 @@ from typing_extensions import Literal
 from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
+
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
@@ -1076,7 +1082,7 @@ def monkey_patch_vllm_gguf_config():
     def get_quant_method_with_embedding_replaced(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
@@ -1956,7 +1962,9 @@ def direct_register_custom_op(
         if fake_impl is not None:
             my_lib._register_fake(op_name, fake_impl)
     except RuntimeError as error:
-        if "Tried to register an operator" in str(e) and "multiple times" in str(e):
+        if "Tried to register an operator" in str(error) and "multiple times" in str(
+            error
+        ):
             # Silently ignore duplicate registration errors
             # This can happen in multi-engine scenarios
             pass
diff --git a/python/sglang/test/few_shot_gsm8k_engine.py b/python/sglang/test/few_shot_gsm8k_engine.py
index 567816cfc..07eda86e2 100644
--- a/python/sglang/test/few_shot_gsm8k_engine.py
+++ b/python/sglang/test/few_shot_gsm8k_engine.py
@@ -3,6 +3,7 @@ import ast
 import asyncio
 import re
 import time
+from typing import Optional
 
 import numpy as np