From 177320a582eccc9ef1fbf2c0156c7ba782c5bf11 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 16 Apr 2025 15:26:49 -0700 Subject: [PATCH] Clean up imports (#5467) --- python/sglang/__init__.py | 6 +- python/sglang/bench_serving.py | 4 - python/sglang/lang/__init__.py | 0 python/sglang/lang/backend/anthropic.py | 4 - python/sglang/lang/backend/base_backend.py | 2 +- python/sglang/lang/backend/openai.py | 2 +- python/sglang/lang/backend/vertexai.py | 1 - python/sglang/lang/compiler.py | 8 +- python/sglang/lang/tracer.py | 10 +- python/sglang/srt/_custom_ops.py | 2 - python/sglang/srt/custom_op.py | 62 ----- python/sglang/srt/entrypoints/verl_engine.py | 3 +- python/sglang/srt/layers/activation.py | 14 +- python/sglang/srt/layers/layernorm.py | 2 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 40 +-- .../layers/moe/fused_moe_triton/fused_moe.py | 31 +-- python/sglang/srt/layers/moe/topk.py | 30 +-- python/sglang/srt/layers/parameter.py | 2 - .../compressed_tensors/compressed_tensors.py | 3 +- .../compressed_tensors_moe.py | 56 +--- .../schemes/compressed_tensors_w8a8_fp8.py | 3 - python/sglang/srt/layers/quantization/fp8.py | 245 ++++++++---------- .../srt/layers/quantization/fp8_kernel.py | 121 +++++---- .../srt/layers/quantization/fp8_utils.py | 117 ++++----- .../sglang/srt/layers/quantization/utils.py | 16 +- .../srt/layers/quantization/w8a8_int8.py | 12 +- python/sglang/srt/layers/rotary_embedding.py | 5 +- python/sglang/srt/lora/backend/__init__.py | 25 -- .../sglang/srt/lora/backend/base_backend.py | 20 +- .../srt/lora/backend/flashinfer_backend.py | 2 +- .../sglang/srt/lora/backend/triton_backend.py | 2 +- python/sglang/srt/lora/layers.py | 2 +- python/sglang/srt/lora/lora.py | 2 +- python/sglang/srt/lora/lora_manager.py | 2 +- .../srt/managers/detokenizer_manager.py | 1 - python/sglang/srt/managers/mm_utils.py | 7 +- .../srt/managers/multimodal_processor.py | 2 - .../multimodal_processors/base_processor.py | 5 +- .../srt/model_executor/cuda_graph_runner.py | 4 +- .../sglang/srt/model_executor/model_runner.py | 1 - python/sglang/srt/models/deepseek_nextn.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 11 +- python/sglang/srt/reasoning_parser.py | 1 - .../srt/sampling/sampling_batch_info.py | 5 +- python/sglang/srt/server.py | 18 -- python/sglang/srt/server_args.py | 16 +- python/sglang/srt/utils.py | 4 +- python/sglang/test/runners.py | 2 +- python/sglang/test/test_custom_ops.py | 2 +- test/srt/test_fp8_kernel.py | 4 +- .../srt/test_triton_moe_channel_fp8_kernel.py | 6 +- 51 files changed, 376 insertions(+), 573 deletions(-) delete mode 100644 python/sglang/lang/__init__.py delete mode 100644 python/sglang/srt/lora/backend/__init__.py delete mode 100644 python/sglang/srt/server.py diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index db0cf2604..91dc58e53 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -24,6 +24,7 @@ from sglang.api import ( user_end, video, ) +from sglang.global_config import global_config from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.choices import ( greedy_token_selection, @@ -31,6 +32,7 @@ from sglang.lang.choices import ( unconditional_likelihood_normalized, ) from sglang.utils import LazyImport +from sglang.version import __version__ ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs") Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") @@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") -# Other configs -from sglang.global_config import global_config -from sglang.version import __version__ - __all__ = [ "Engine", "Runtime", diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 7d434782d..264505875 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -707,10 +707,6 @@ def sample_random_requests( # Download sharegpt if necessary if not os.path.isfile(dataset_path): - print( - "If you do not want to randomly sample from a dataset," - " please use --dataset-name random-ids." - ) dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. diff --git a/python/sglang/lang/__init__.py b/python/sglang/lang/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/python/sglang/lang/backend/anthropic.py b/python/sglang/lang/backend/anthropic.py index 5a36bd9ac..4918a1703 100644 --- a/python/sglang/lang/backend/anthropic.py +++ b/python/sglang/lang/backend/anthropic.py @@ -1,7 +1,3 @@ -from typing import List, Optional, Union - -import numpy as np - from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor diff --git a/python/sglang/lang/backend/base_backend.py b/python/sglang/lang/backend/base_backend.py index 725c0a91d..62dd50416 100644 --- a/python/sglang/lang/backend/base_backend.py +++ b/python/sglang/lang/backend/base_backend.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union from sglang.lang.chat_template import get_chat_template from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index c6e31f73d..502d2f3e4 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -2,7 +2,7 @@ import dataclasses import logging import time import warnings -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import numpy as np diff --git a/python/sglang/lang/backend/vertexai.py b/python/sglang/lang/backend/vertexai.py index c27733b3e..3d51fb137 100644 --- a/python/sglang/lang/backend/vertexai.py +++ b/python/sglang/lang/backend/vertexai.py @@ -1,6 +1,5 @@ import os import warnings -from typing import Optional from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py index 5e1b411fc..1284232f7 100644 --- a/python/sglang/lang/compiler.py +++ b/python/sglang/lang/compiler.py @@ -5,13 +5,7 @@ from typing import List, Union from sglang.global_config import global_config from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program -from sglang.lang.ir import ( - SglArgument, - SglConstantText, - SglExpr, - SglSamplingParams, - SglVariable, -) +from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable def compile_func(function, backend): diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py index 69f035b07..7b3c72804 100644 --- a/python/sglang/lang/tracer.py +++ b/python/sglang/lang/tracer.py @@ -1,20 +1,16 @@ """Tracing a program.""" import uuid -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional -from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.interpreter import ProgramState, ProgramStateGroup from sglang.lang.ir import ( SglArgument, - SglCommitLazy, - SglConcateAndAppend, SglConstantText, SglExpr, SglExprList, SglFork, - SglFunction, SglGen, SglGetForkItem, SglRoleBegin, @@ -230,8 +226,8 @@ class TracerProgramState(ProgramState): self.cur_role = None def _execute_var_scope_end(self, expr: SglVarScopeEnd): - new_node = SglVariable(name, source=self.last_node) - self.variables[name] = new_node + new_node = SglVariable(expr.name, source=self.last_node) + self.variables[expr.name] = new_node def get_var(self, name): ret = self.arguments.get(name, None) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 2e9db19f9..07c087bf6 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,10 +1,8 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import logging -import os from typing import List, Tuple import torch -import torch.library from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index aed8ec2f1..5e0f4bd1e 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -42,65 +42,3 @@ class CustomOp(nn.Module): return self.forward_hip else: return self.forward_native - - -if _is_cuda: - from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8 - - def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - use_per_token_if_dynamic: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 (8-bit floating point) format. - - Args: - input (torch.Tensor): Input tensor to be quantized - scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization. - If None, scales will be computed dynamically. - num_token_padding (Optional[int]): If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None), - determines the quantization granularity: - - True: compute scale per token - - False: compute single scale per tensor - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - quantized_tensor: The FP8 quantized version of input - - scale_tensor: The scaling factors used for quantization - - Raises: - AssertionError: If input is not 2D or if static scale's numel != 1 - """ - assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" - shape = input.shape - out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - # Dynamic scaling - if use_per_token_if_dynamic: - scale = torch.empty( - (shape[0], 1), device=input.device, dtype=torch.float32 - ) - sgl_per_token_quant_fp8(input, output, scale) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - sgl_per_tensor_quant_fp8( - input, output, scale, is_static=False - ) # False for dynamic - else: - # Static scaling - assert ( - scale.numel() == 1 - ), f"Expected scalar scale, got numel={scale.numel()}" - sgl_per_tensor_quant_fp8( - input, output, scale, is_static=True - ) # True for static - - return output, scale diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index ef139af27..e1ce84731 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -19,11 +19,10 @@ import torch.distributed as dist from PIL.Image import Image from torch.distributed.tensor import DeviceMesh, DTensor +from sglang.srt.entrypoints.engine import Engine from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions -from sglang.srt.server import Engine -from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 1ee10c0aa..9e7f933d7 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -21,13 +21,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_cuda_available - -_is_cuda = is_cuda_available() - -if _is_cuda: - from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul - from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( divide, @@ -35,7 +28,12 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import is_cuda_available, set_weight_attrs + +_is_cuda = is_cuda_available() + +if _is_cuda: + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 47dccc9f9..0359d7234 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn +from sglang.srt.custom_op import CustomOp from sglang.srt.utils import is_cuda_available _is_cuda = is_cuda_available() @@ -31,7 +32,6 @@ if _is_cuda: rmsnorm, ) -from sglang.srt.custom_op import CustomOp logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index dfecb63d9..a35d0b8d0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -2,6 +2,7 @@ import logging from typing import Callable, List, Optional, Tuple import torch +from torch.nn import Module try: from deep_gemm import ( @@ -13,8 +14,6 @@ try: except ImportError: use_deep_gemm = False -from torch.nn import Module - from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs - -_is_cuda = is_cuda() - -if _is_cuda: - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant -else: - from vllm import _custom_ops as vllm_ops - -logger = logging.getLogger(__name__) +from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs _is_hip = is_hip() -_buffer = None +if _is_hip: + from vllm._custom_ops import scaled_fp8_quant + +logger = logging.getLogger(__name__) class GroupedGemmRunner(torch.nn.Module): @@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod): ) for expert in range(layer.num_experts_per_partition): - if _is_cuda: - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - else: - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) return diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 282cfbf06..61400787a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -13,6 +13,7 @@ import triton import triton.language as tl from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, @@ -22,28 +23,25 @@ from sglang.srt.utils import ( ) _is_hip = is_hip() - - -logger = logging.getLogger(__name__) -padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 - -enable_moe_align_block_size_triton = bool( - int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) -) - _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul - - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant else: from vllm import _custom_ops as vllm_ops + from vllm._custom_ops import scaled_fp8_quant if _is_cuda or _is_hip: from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size +logger = logging.getLogger(__name__) +padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 +enable_moe_align_block_size_triton = bool( + int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) +) + + @triton.jit def write_zeros_to_output( c_ptr, @@ -770,14 +768,9 @@ def invoke_fused_moe_kernel( # activation tensor-wise fp8 quantization, dynamic or static padded_size = padding_size # activations apply per-token quantization when weights apply per-channel quantization by default - if _is_cuda: - A, A_scale = sgl_scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant - ) - else: - A, A_scale = vllm_ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant - ) + A, A_scale = scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_channel_quant + ) else: # activation block-wise fp8 quantization assert len(block_shape) == 2 diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index c12b9d019..f17b908cb 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -13,7 +13,6 @@ # ============================================================================== import math -import os from typing import Callable, Optional import torch @@ -29,6 +28,10 @@ _is_hip = is_hip() if _is_cuda: from sgl_kernel import moe_fused_gate +if _is_cuda or _is_hip: + from sgl_kernel import topk_softmax + + expert_distribution_recorder = ExpertDistributionRecorder() @@ -59,11 +62,6 @@ def fused_topk( topk: int, renormalize: bool, ): - if _is_cuda or _is_hip: - from sgl_kernel import topk_softmax - else: - from vllm import _custom_ops as vllm_ops - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" M, _ = hidden_states.shape @@ -76,20 +74,12 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - if _is_cuda or _is_hip: - topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), - ) - else: - vllm_ops.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), - ) + topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), + ) del token_expert_indicies if renormalize: diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 33bd46c05..978ec0ad0 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -7,8 +7,6 @@ from typing import Callable, Optional, Union import torch from torch.nn import Parameter -from sglang.srt.distributed import get_tensor_model_parallel_rank - __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index ce2155600..b0a664460 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 import logging @@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import ( is_activation_quantization_format, should_ignore_layer, ) -from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 569f2a2d6..b8d9d637e 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,22 +1,16 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 import enum import logging from enum import Enum -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import Callable, List, Optional import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy -if TYPE_CHECKING: - from sglang.srt.layers.moe.fused_moe_triton import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - ) - +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( all_close_1d, @@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs _is_cuda = is_cuda() -if _is_cuda: - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant -else: +if not _is_cuda: from vllm import _custom_ops as vllm_ops + from vllm._custom_ops import scaled_fp8_quant try: import vllm @@ -58,8 +51,6 @@ __all__ = [ class CompressedTensorsMoEMethod: def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - if cls is CompressedTensorsMoEMethod: return super().__new__(cls) return super().__new__(cls) @@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod: if quant_config._is_wNa16_group_channel(weight_quant, input_quant): if not VLLM_AVAILABLE: raise ImportError( - "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm" + "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm." ) return CompressedTensorsWNA16MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 ): - from sglang.srt.layers.moe.fused_moe_triton import ( - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - ) - self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( @@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight_scale[expert_id][shard_id], ) + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - if _is_cuda: - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - else: - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = vllm_ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id] - ) start += shard_size layer.w13_weight_scale = torch.nn.Parameter( @@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 ): - from sglang.srt.layers.moe.fused_moe_triton import ( - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - ) - self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): requires_grad=False, ) - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack( layer.w13_weight_packed, layer.w13_g_idx_sort_indices, layer.w13_weight_packed.shape[1] * self.packed_factor, @@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): self.num_bits, ) replace_tensor("w13_weight_packed", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, layer.w2_weight_packed.shape[1] * self.packed_factor, @@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - if not VLLM_AVAILABLE: - raise ImportError( - "vllm is not installed, to use fused_marlin_moe, please install vllm" - ) if expert_map is not None: raise NotImplementedError( "Expert Parallelism is not supported for " "fused Marlin MoE method." diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 6c624a070..8e4f84714 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( ) from sglang.srt.layers.quantization.fp8_utils import ( Fp8LinearOp, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale @@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): weight_loader: Callable, **kwargs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 3749abc34..5ba2b3fb8 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -8,15 +8,6 @@ import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.utils import ( - all_close_1d, - convert_to_channelwise, - is_layer_skipped, - per_tensor_dequantize, - requantize_with_max_scale, -) - try: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -27,11 +18,12 @@ try: except ImportError: MARLIN_FP8_AVAILABLE = False - def apply_fp8_marlin_linear(*args, **kwargs): - raise ImportError("vllm is not installed") + def dummy_func(*args, **kwargs): + raise ImportError( + "marlin FP8 requires some operators from vllm. Please install vllm." + ) - def prepare_fp8_layer_for_marlin(*args, **kwargs): - raise ImportError("vllm is not installed") + apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + scaled_fp8_quant, +) from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_w8a8_block_fp8_linear, @@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import ( input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.utils import ( + all_close_1d, + convert_to_channelwise, + is_layer_skipped, + per_tensor_dequantize, + requantize_with_max_scale, +) from sglang.srt.utils import ( get_bool_env_var, is_cuda, is_hip, - permute_weight, print_warning_once, set_weight_attrs, ) -ACTIVATION_SCHEMES = ["static", "dynamic"] - _is_hip = is_hip() +_is_cuda = is_cuda() if _is_hip: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4 from aiter.ops.shuffle import shuffle_weight -_is_cuda = is_cuda() +if not _is_cuda: + from vllm._custom_ops import scaled_fp8_quant -if _is_cuda: - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant -else: - from vllm import _custom_ops as vllm_ops + +ACTIVATION_SCHEMES = ["static", "dynamic"] logger = logging.getLogger(__name__) @@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase): ) layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype @@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_scale_inv.data, requires_grad=False ) return + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: if self.cutlass_fp8_supported or self.use_marlin: @@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase): ) if self.use_marlin: - try: - prepare_fp8_layer_for_marlin(layer) - # Activations not quantized for marlin. - del layer.input_scale - except ImportError: - self.use_marlin = False + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale def apply( self, @@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase): ) -> torch.Tensor: if self.use_marlin: - try: - return apply_fp8_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) - except ImportError: - self.use_marlin = False + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) if self.block_quant: return apply_w8a8_block_fp8_linear( @@ -516,7 +511,7 @@ class Fp8MoEMethod: ) # WEIGHTS - if get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -617,7 +612,7 @@ class Fp8MoEMethod: set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - if get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) @@ -649,7 +644,7 @@ class Fp8MoEMethod: layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if get_bool_env_var("USE_INT4_WEIGHT"): + if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): self.process_weights_hip_int4(layer) return @@ -706,20 +701,12 @@ class Fp8MoEMethod: requires_grad=False, ) for expert in range(layer.num_experts): - if _is_cuda: - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - else: - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) @@ -796,18 +783,10 @@ class Fp8MoEMethod: layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight_scale[expert_id][shard_id], ) - if _is_cuda: - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - else: - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = vllm_ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id] - ) + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) start += shard_size layer.w13_weight_scale = torch.nn.Parameter( @@ -930,41 +909,11 @@ class Fp8MoEMethod: correction_bias=correction_bias, ) - if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): - # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") - assert not no_combine, f"{no_combine=} is not supported." - return ck_moe_2stages_win4( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - layer.w13_weight_scale1, - layer.w2_weight_scale1, - activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu - ), - ) - if _is_hip and get_bool_env_var("CK_MOE"): - assert not no_combine, f"{no_combine=} is not supported." - if self.block_quant: - # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being. - assert ( - activation == "silu" - ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE" - return asm_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - layer.w13_weight_scale_inv, - layer.w2_weight_scale_inv, - block_shape=tuple(self.quant_config.weight_block_size), - expert_mask=None, - ) - else: - return ck_moe_2stages( + if _is_hip: + if get_bool_env_var("USE_INT4_WEIGHT"): + # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") + assert not no_combine, f"{no_combine=} is not supported." + return ck_moe_2stages_win4( x, layer.w13_weight, layer.w2_weight, @@ -978,33 +927,65 @@ class Fp8MoEMethod: else ActivationType.Gelu ), ) - else: - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - w1_scale=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_scale=( - layer.w2_weight_scale_inv - if self.block_quant - else layer.w2_weight_scale - ), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - no_combine=no_combine, - ) + + if get_bool_env_var("CK_MOE"): + assert not no_combine, f"{no_combine=} is not supported." + if self.block_quant: + # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being. + assert ( + activation == "silu" + ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE" + return asm_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + layer.w13_weight_scale_inv, + layer.w2_weight_scale_inv, + block_shape=tuple(self.quant_config.weight_block_size), + expert_mask=None, + ) + else: + return ck_moe_2stages( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + layer.w13_weight_scale1, + layer.w2_weight_scale1, + activation=( + ActivationType.Silu + if activation == "silu" + else ActivationType.Gelu + ), + ) + + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace and not no_combine, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + no_combine=no_combine, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 36060d374..71de14b46 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -34,15 +34,23 @@ from sglang.srt.utils import ( supports_custom_op, ) -_enable_jit_deepgemm = False - _is_hip = is_hip() -fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn - _is_cuda = is_cuda() +_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn +if _is_hip: + fp8_max = 224.0 +else: + fp8_max = torch.finfo(_fp8_type).max +fp8_min = -fp8_max + +_enable_jit_deepgemm = False if _is_cuda: import deep_gemm - from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 + from sgl_kernel import ( + sgl_per_tensor_quant_fp8, + sgl_per_token_group_quant_fp8, + sgl_per_token_quant_fp8, + ) sm_version = get_device_sm() if sm_version == 90 and get_bool_env_var( @@ -53,6 +61,7 @@ if _is_cuda: logger = logging.getLogger(__name__) + if supports_custom_op(): def deep_gemm_fp8_fp8_bf16_nt( @@ -179,7 +188,6 @@ def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, column_major_scales: bool = False, scale_tma_aligned: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -192,7 +200,6 @@ def per_token_group_quant_fp8( x: The input tenosr with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. - dtype: The dype of output tensor. Returns: Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. @@ -202,15 +209,7 @@ def per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - - if _is_hip: - fp8_max = 224.0 - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) M = x.numel() // group_size N = group_size if column_major_scales: @@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8( x: torch.Tensor, group_size: int, eps: float = 1e-10, - dtype: torch.dtype = fp8_type_, ): assert ( x.shape[-1] % group_size == 0 ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size + x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) x_s = torch.empty( x.shape[:-1] + (x.shape[-1] // group_size,), device=x.device, dtype=torch.float32, ) - sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) return x_q, x_s @@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8( def sglang_per_token_quant_fp8( x: torch.Tensor, - dtype: torch.dtype = fp8_type_, + dtype: torch.dtype = _fp8_type, ): assert x.is_contiguous(), "`x` is not contiguous" @@ -368,7 +358,6 @@ def static_quant_fp8( x: torch.Tensor, x_s: torch.Tensor, repeat_scale: bool = False, - dtype: torch.dtype = fp8_type_, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform static quantization using the given scale on an input tensor `x`. @@ -386,15 +375,8 @@ def static_quant_fp8( """ assert x.is_contiguous(), "`x` is not contiguous" assert x_s.numel() == 1, "only supports per-tensor scale" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - if _is_hip: - fp8_max = 224.0 - - fp8_min = -fp8_max - - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) M = x.numel() // x.shape[-1] N = x.shape[-1] if repeat_scale: @@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2( def per_tensor_quant_mla_fp8( - x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn + x: torch.Tensor, eps: float = 1e-12 ) -> Tuple[torch.Tensor, torch.Tensor]: """ This function quantizes input values to float8 values with tensor-wise quantization @@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8( """ assert x.dim() == 3, "`x` is not a 3d-tensor" - finfo = torch.finfo(dtype) - fp8_max = finfo.max - if _is_hip: - dtype = torch.float8_e4m3fnuz - fp8_max = 224.0 - - x_q = x.new_empty(x.size(), dtype=dtype) + x_q = x.new_empty(x.size(), dtype=_fp8_type) x_s = torch.zeros((1,), dtype=torch.float32, device=x.device) num_head, num_seq, head_size = x.shape @@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8( head_size, x.stride(0), x.stride(1), - -fp8_max, + fp8_min, fp8_max, BLOCK_SIZE, ) return x_q, x_s + + +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + use_per_token_if_dynamic: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 (8-bit floating point) format. + + Args: + input (torch.Tensor): Input tensor to be quantized + scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization. + If None, scales will be computed dynamically. + num_token_padding (Optional[int]): If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None), + determines the quantization granularity: + - True: compute scale per token + - False: compute single scale per tensor + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - quantized_tensor: The FP8 quantized version of input + - scale_tensor: The scaling factors used for quantization + + Raises: + AssertionError: If input is not 2D or if static scale's numel != 1 + """ + assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" + shape = input.shape + out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + # Dynamic scaling + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + sgl_per_token_quant_fp8(input, output, scale) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + sgl_per_tensor_quant_fp8( + input, output, scale, is_static=False + ) # False for dynamic + else: + # Static scaling + assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}" + sgl_per_tensor_quant_fp8( + input, output, scale, is_static=True + ) # True for static + + return output, scale diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index b9f4e2804..7acf95678 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,11 +1,19 @@ -import os from typing import List, Optional, Tuple import torch +try: + from vllm import _custom_ops as vllm_ops + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + from sglang.srt.layers.quantization.fp8_kernel import ( _enable_jit_deepgemm, per_token_group_quant_fp8, + scaled_fp8_quant, + sglang_per_token_quant_fp8, static_quant_fp8, w8a8_block_fp8_matmul, ) @@ -17,30 +25,20 @@ from sglang.srt.utils import ( is_hip, ) -try: - import vllm - from vllm import _custom_ops as ops - - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False - -use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") - _is_hip = is_hip() +_is_cuda = is_cuda() + if _is_hip and get_bool_env_var("CK_MOE"): from aiter import gemm_a8w8_blockscale -_is_cuda = is_cuda() if _is_cuda: from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant - from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 +use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +TORCH_DEVICE_IDENTITY = None _TORCH_VERSION = torch.__version__.split("+")[0] try: @@ -214,7 +212,7 @@ def block_quant_to_tensor_quant( x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] x_q_tensor, scale = ( - sgl_scaled_fp8_quant(x_dq_block) + scaled_fp8_quant(x_dq_block) if _is_cuda else input_to_float8(x_dq_block, dtype=x_q_block.dtype) ) @@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant( ) -> Tuple[torch.Tensor, torch.Tensor]: x_dq_channel = x_q_channel.to(torch.float32) * x_s x_q_tensor, scale = ( - sgl_scaled_fp8_quant(x_dq_channel) + scaled_fp8_quant(x_dq_channel) if _is_cuda else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype) ) @@ -264,7 +262,7 @@ def apply_fp8_linear( # final solution should be: 1. add support to per-tensor activation scaling. # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308) if _is_hip and weight_scale.numel() == 1: - qinput, x_scale = ops.scaled_fp8_quant( + qinput, x_scale = vllm_ops.scaled_fp8_quant( input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic, @@ -275,32 +273,29 @@ def apply_fp8_linear( ) if cutlass_fp8_supported: - try: - if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: - # Fall back to vllm cutlass w8a8 fp8 kernel - output = ops.cutlass_scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) - else: - assert ( - weight_scale.numel() == weight.shape[1] - ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" - output = fp8_scaled_mm( - qinput, - weight, - x_scale, - weight_scale, - out_dtype=input.dtype, - bias=bias, - ) - return output.view(*output_shape) - except (ImportError, NameError, AttributeError): - pass + if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: + # Fall back to vllm cutlass w8a8 fp8 kernel + output = vllm_ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + else: + assert ( + weight_scale.numel() == weight.shape[1] + ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" + output = fp8_scaled_mm( + qinput, + weight, + x_scale, + weight_scale, + out_dtype=input.dtype, + bias=bias, + ) + return output.view(*output_shape) # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token @@ -343,8 +338,10 @@ def apply_fp8_linear( # Making sure the dummy tensor is on the same device as the weight global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY.device != weight.device: - TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones( + 1, dtype=torch.float32, device=weight.device + ) # GEMM # This computes C = (X * W). @@ -372,13 +369,6 @@ def apply_fp8_linear( return output.to(dtype=input.dtype).view(*output_shape) -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - - # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py # TODO(luka): follow similar pattern for marlin and block-fp8-linear # https://github.com/vllm-project/vllm/issues/14397 @@ -405,9 +395,7 @@ class Fp8LinearOp: # We also don't pad when using torch.compile, # as it breaks with dynamic shapes. if pad_output is None: - enable_torch_compile = os.environ.get( - "SGLANG_ENABLE_TORCH_COMPILE", "0" - ).lower() in ("1", "true", "yes") + enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE") pad_output = not enable_torch_compile self.output_padding = 17 if pad_output else None @@ -439,13 +427,13 @@ class Fp8LinearOp: # for sgl-kernel fp8_scaled_mm, it support per channel W now if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]: if _is_cuda: - qinput, x_scale = sgl_scaled_fp8_quant( + qinput, x_scale = scaled_fp8_quant( input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic, ) else: - qinput, x_scale = ops.scaled_fp8_quant( + qinput, x_scale = vllm_ops.scaled_fp8_quant( input_2d, input_scale, scale_ub=input_scale_ub, @@ -455,7 +443,7 @@ class Fp8LinearOp: # Fused GEMM_DQ if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: # Fall back to vllm cutlass w8a8 fp8 kernel - output = ops.cutlass_scaled_mm( + output = vllm_ops.cutlass_scaled_mm( qinput, weight, out_dtype=input.dtype, @@ -482,14 +470,14 @@ class Fp8LinearOp: else: # Maybe apply padding to output, see comment in __init__ if _is_cuda: - qinput, x_scale = sgl_scaled_fp8_quant( + qinput, x_scale = scaled_fp8_quant( input_2d, input_scale, num_token_padding=self.output_padding, use_per_token_if_dynamic=use_per_token_if_dynamic, ) else: - qinput, x_scale = ops.scaled_fp8_quant( + qinput, x_scale = vllm_ops.scaled_fp8_quant( input_2d, input_scale, num_token_padding=self.output_padding, @@ -562,9 +550,12 @@ class Fp8LinearOp: # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place + # Making sure the dummy tensor is on the same device as the weight global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY.device != weight.device: - TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones( + 1, dtype=torch.float32, device=weight.device + ) output = torch._scaled_mm( qinput, diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 5615dbca3..fdea997e3 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,18 +1,17 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py from types import MappingProxyType -from typing import List, Mapping, Optional, Tuple, Union +from typing import List, Mapping, Tuple, Union import torch +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.utils import is_cuda _is_cuda = is_cuda() -if _is_cuda: - from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant -else: - from vllm import _custom_ops as vllm_ops +if not _is_cuda: + from vllm._custom_ops import scaled_fp8_quant def is_fp8_fnuz() -> bool: @@ -116,12 +115,7 @@ def requantize_with_max_scale( for idx, logical_width in enumerate(logical_widths): end = start + logical_width weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) - if _is_cuda: - weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale) - else: - weight[start:end, :], _ = vllm_ops.scaled_fp8_quant( - weight_dq, max_w_scale - ) + weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale) start = end return max_w_scale, weight diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 6df5693f8..df345a0a2 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,13 +1,6 @@ from typing import Any, Callable, Dict, List, Optional import torch - -from sglang.srt.utils import is_cuda_available, set_weight_attrs - -is_cuda = is_cuda_available() -if is_cuda: - from sgl_kernel import int8_scaled_mm - from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 +from sglang.srt.utils import is_cuda_available, set_weight_attrs + +is_cuda = is_cuda_available() +if is_cuda: + from sgl_kernel import int8_scaled_mm class W8A8Int8Config(QuantizationConfig): diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index b819e96f0..92f5f74e5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp from sglang.srt.utils import is_cuda_available _is_cuda_available = is_cuda_available() + if _is_cuda_available: from sgl_kernel import apply_rope_with_cos_sin_cache_inplace else: - from vllm import _custom_ops as ops + from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp): ) else: self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - ops.rotary_embedding( + vllm_rotary_embedding( positions, query, key, diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py deleted file mode 100644 index 7b76f90e5..000000000 --- a/python/sglang/srt/lora/backend/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from sglang.srt.lora.backend.base_backend import BaseLoRABackend - - -def get_backend_from_name(name: str) -> BaseLoRABackend: - """ - Get corresponding backend class from backend's name - """ - if name == "triton": - from sglang.srt.lora.backend.triton_backend import TritonLoRABackend - - return TritonLoRABackend - elif name == "flashinfer": - from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend - - return FlashInferLoRABackend - else: - raise ValueError(f"Invalid backend: {name}") - - -__all__ = [ - "BaseLoRABackend", - "FlashInferLoRABackend", - "TritonLoRABackend", - "get_backend_from_name", -] diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index c4346681c..e1bdc5408 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -75,7 +75,7 @@ class BaseLoRABackend: qkv_lora_a: torch.Tensor, qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], *args, - **kwargs + **kwargs, ) -> torch.Tensor: """Run the lora pass for QKV Layer. @@ -98,7 +98,7 @@ class BaseLoRABackend: gate_up_lora_a: torch.Tensor, gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], *args, - **kwargs + **kwargs, ) -> torch.Tensor: """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer. @@ -115,3 +115,19 @@ class BaseLoRABackend: def set_batch_info(self, batch_info: LoRABatchInfo): self.batch_info = batch_info + + +def get_backend_from_name(name: str) -> BaseLoRABackend: + """ + Get corresponding backend class from backend's name + """ + if name == "triton": + from sglang.srt.lora.backend.triton_backend import TritonLoRABackend + + return TritonLoRABackend + elif name == "flashinfer": + from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend + + return FlashInferLoRABackend + else: + raise ValueError(f"Invalid backend: {name}") diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py index 7505ba69a..0370c6c81 100644 --- a/python/sglang/srt/lora/backend/flashinfer_backend.py +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -2,7 +2,7 @@ from typing import Tuple import torch -from sglang.srt.lora.backend import BaseLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.utils import is_flashinfer_available diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 88eb87c76..d3a854b40 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,6 +1,6 @@ import torch -from sglang.srt.lora.backend import BaseLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.triton_ops import ( gate_up_lora_b_fwd, qkv_lora_b_fwd, diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index cafd8b7e0..8da81826f 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -16,7 +16,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding -from sglang.srt.lora.backend import BaseLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend class BaseLayerWithLoRA(nn.Module): diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 51e4d56f2..b0db40d6a 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -27,7 +27,7 @@ from torch import nn from sglang.srt.configs.load_config import LoadConfig from sglang.srt.hf_transformers_utils import AutoConfig -from sglang.srt.lora.backend import BaseLoRABackend +from sglang.srt.lora.backend.base_backend import BaseLoRABackend from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_loader.loader import DefaultModelLoader diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index fc0374ace..7b8f11629 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -22,7 +22,7 @@ import torch from sglang.srt.configs.load_config import LoadConfig from sglang.srt.hf_transformers_utils import AutoConfig -from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name +from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index ed7326360..aeae266eb 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -14,7 +14,6 @@ """DetokenizerManager is a process that detokenizes the token ids.""" import dataclasses -import json import logging import os import signal diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 045b4c13a..aa7ed1554 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1,7 +1,8 @@ """ - Multi-modality utils +Multi-modality utils """ +import logging from abc import abstractmethod from typing import Callable, List, Optional, Tuple @@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import ( MultimodalDataItem, MultimodalInputs, global_server_args_dict, - logger, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import print_warning_once -from sglang.utils import logger + +logger = logging.getLogger(__name__) class MultiModalityDataPaddingPattern: diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py index ac603235c..938ddd044 100644 --- a/python/sglang/srt/managers/multimodal_processor.py +++ b/python/sglang/srt/managers/multimodal_processor.py @@ -5,8 +5,6 @@ import logging import pkgutil from functools import lru_cache -from transformers import PROCESSOR_MAPPING - from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor, ) diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index 22ad7e797..6b0672a9f 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -8,8 +8,6 @@ from typing import List, Optional import numpy as np import PIL -from decord import VideoReader, cpu -from PIL import Image from transformers import BaseImageProcessorFast from sglang.srt.managers.schedule_batch import Modality @@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC): """ estimate the total frame count from all visual input """ + # Lazy import because decord is not available on some arm platforms. + from decord import VideoReader, cpu + # Before processing inputs estimated_frames_list = [] for image in image_data: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c71cae07a..b1fa22614 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.utils import get_available_gpu_memory, is_hip -_is_hip = is_hip() - if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +_is_hip = is_hip() + def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): for sub in model._modules.values(): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ca9ffbab2..ed899f080 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -320,7 +320,6 @@ class ModelRunner: logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}") if not self.use_mla_backend: - logger.info("Disable chunked prefix cache for non-MLA backend.") server_args.disable_chunked_prefix_cache = True elif self.page_size > 1: logger.info("Disable chunked prefix cache when page size > 1.") diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index be3466197..a1739ec78 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -48,7 +48,7 @@ _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import awq_dequantize else: - from vllm import _custom_ops as ops + from vllm._custom_ops import awq_dequantize class DeepseekModelNextN(nn.Module): @@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self_attn.kv_b_proj.qzeros, ).T else: - w = ops.awq_dequantize( + w = awq_dequantize( self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.scales, self_attn.kv_b_proj.qzeros, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 533c3169c..abb9aa4bb 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -51,6 +51,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE +from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -80,10 +81,8 @@ _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 - - from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher else: - from vllm import _custom_ops as ops + from vllm._custom_ops import awq_dequantize if _is_hip: from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( @@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module): ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( - q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn + q_nope.transpose(0, 1), ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 @@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module): ) elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( - attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn + attn_output.transpose(0, 1), ) attn_bmm_output = bmm_fp8( attn_output_val, @@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module): self_attn.kv_b_proj.qzeros, ).T else: - w = ops.awq_dequantize( + w = awq_dequantize( self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.scales, self_attn.kv_b_proj.qzeros, diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py index fe369896f..22a73fbe2 100644 --- a/python/sglang/srt/reasoning_parser.py +++ b/python/sglang/srt/reasoning_parser.py @@ -1,4 +1,3 @@ -import re from typing import Dict, Tuple diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 70a2443bf..66e6552c0 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -10,12 +10,11 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor -logger = logging.getLogger(__name__) - - if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch +logger = logging.getLogger(__name__) + @dataclasses.dataclass class SamplingBatchInfo: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py deleted file mode 100644 index 869a984d0..000000000 --- a/python/sglang/srt/server.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# Some shortcuts for backward compatibility. -# They will be removed in new versions. -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f9878d2c5..b45860aba 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -187,6 +187,7 @@ class ServerArgs: n_share_experts_fusion: int = 0 disable_shared_experts_fusion: bool = False disable_chunked_prefix_cache: bool = False + disable_fast_image_processor: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -198,9 +199,6 @@ class ServerArgs: disaggregation_bootstrap_port: int = 8998 disaggregation_transfer_backend: str = "mooncake" - # multimodal - disable_fast_image_processor: bool = False - def __post_init__(self): # Expert parallelism if self.enable_ep_moe: @@ -1136,6 +1134,11 @@ class ServerArgs: action="store_true", help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.", ) + parser.add_argument( + "--disable-fast-image-processor", + action="store_true", + help="Adopt base image processor instead of fast image processor.", + ) # Server warmups parser.add_argument( @@ -1187,13 +1190,6 @@ class ServerArgs: help="The backend for disaggregation transfer. Default is mooncake.", ) - # Multimodal - parser.add_argument( - "--disable-fast-image-processor", - action="store_true", - help="Adopt base image processor instead of fast image processor.", - ) - @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e812a7802..a6c2d910b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -55,7 +55,6 @@ import torch.distributed import torch.distributed as dist import triton import zmq -from decord import VideoReader, cpu from fastapi.responses import ORJSONResponse from packaging import version as pkg_version from PIL import Image @@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra def encode_video(video_path, frame_count_limit=None): + # Lazy import because decord is not available on some arm platforms. + from decord import VideoReader, cpu + if not os.path.exists(video_path): logger.error(f"Video {video_path} does not exist") return [] diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 66adc84a7..a7eaebdb2 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -26,8 +26,8 @@ from transformers import ( AutoProcessor, ) +from sglang.srt.entrypoints.engine import Engine from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.server import Engine from sglang.srt.utils import load_image from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l diff --git a/python/sglang/test/test_custom_ops.py b/python/sglang/test/test_custom_ops.py index 72b9f5ab3..873f9960e 100644 --- a/python/sglang/test/test_custom_ops.py +++ b/python/sglang/test/test_custom_ops.py @@ -3,7 +3,7 @@ import pytest import torch -from sglang.srt.custom_op import scaled_fp8_quant +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.utils import is_cuda diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py index 1f8d94b3a..42502277b 100644 --- a/test/srt/test_fp8_kernel.py +++ b/test/srt/test_fp8_kernel.py @@ -93,9 +93,7 @@ class TestPerTokenGroupQuantFP8(TestFP8Base): A, A_quant_gt, scale_gt = self._make_A( M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type ) - A_quant, scale = per_token_group_quant_fp8( - x=A, group_size=self.group_size, dtype=self.quant_type - ) + A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size) torch.testing.assert_close(scale, scale_gt) diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs() diff_count = (diff > 1e-5).count_nonzero() diff --git a/test/srt/test_triton_moe_channel_fp8_kernel.py b/test/srt/test_triton_moe_channel_fp8_kernel.py index 2de9a6790..89b5af650 100644 --- a/test/srt/test_triton_moe_channel_fp8_kernel.py +++ b/test/srt/test_triton_moe_channel_fp8_kernel.py @@ -3,9 +3,9 @@ import unittest import torch -from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.test.test_utils import CustomTestCase @@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): B, D = a.shape # Perform per-token quantization - a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True) # Repeat tokens to match topk a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) # Also repeat the scale @@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): # Activation function act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token - act_out_q, act_out_s = sgl_scaled_fp8_quant( + act_out_q, act_out_s = scaled_fp8_quant( act_out, use_per_token_if_dynamic=True )