Clean up imports (#5467)

This commit is contained in:
Lianmin Zheng
2025-04-16 15:26:49 -07:00
committed by GitHub
parent d7bc19a46a
commit 177320a582
51 changed files with 376 additions and 573 deletions

View File

@@ -24,6 +24,7 @@ from sglang.api import (
user_end, user_end,
video, video,
) )
from sglang.global_config import global_config
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import ( from sglang.lang.choices import (
greedy_token_selection, greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
unconditional_likelihood_normalized, unconditional_likelihood_normalized,
) )
from sglang.utils import LazyImport from sglang.utils import LazyImport
from sglang.version import __version__
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs") ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
# Other configs
from sglang.global_config import global_config
from sglang.version import __version__
__all__ = [ __all__ = [
"Engine", "Engine",
"Runtime", "Runtime",

View File

@@ -707,10 +707,6 @@ def sample_random_requests(
# Download sharegpt if necessary # Download sharegpt if necessary
if not os.path.isfile(dataset_path): if not os.path.isfile(dataset_path):
print(
"If you do not want to randomly sample from a dataset,"
" please use --dataset-name random-ids."
)
dataset_path = download_and_cache_file(SHAREGPT_URL) dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset. # Load the dataset.

View File

@@ -1,7 +1,3 @@
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor from sglang.lang.interpreter import StreamExecutor

View File

@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union from typing import List, Optional, Union
from sglang.lang.chat_template import get_chat_template from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod

View File

@@ -2,7 +2,7 @@ import dataclasses
import logging import logging
import time import time
import warnings import warnings
from typing import Callable, List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np

View File

@@ -1,6 +1,5 @@
import os import os
import warnings import warnings
from typing import Optional
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template from sglang.lang.chat_template import get_chat_template

View File

@@ -5,13 +5,7 @@ from typing import List, Union
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import ( from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
def compile_func(function, backend): def compile_func(function, backend):

View File

@@ -1,20 +1,16 @@
"""Tracing a program.""" """Tracing a program."""
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import ( from sglang.lang.ir import (
SglArgument, SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText, SglConstantText,
SglExpr, SglExpr,
SglExprList, SglExprList,
SglFork, SglFork,
SglFunction,
SglGen, SglGen,
SglGetForkItem, SglGetForkItem,
SglRoleBegin, SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
self.cur_role = None self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd): def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(name, source=self.last_node) new_node = SglVariable(expr.name, source=self.last_node)
self.variables[name] = new_node self.variables[expr.name] = new_node
def get_var(self, name): def get_var(self, name):
ret = self.arguments.get(name, None) ret = self.arguments.get(name, None)

View File

@@ -1,10 +1,8 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging import logging
import os
from typing import List, Tuple from typing import List, Tuple
import torch import torch
import torch.library
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

View File

@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
return self.forward_hip return self.forward_hip
else: else:
return self.forward_native return self.forward_native
if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.
Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization
Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty(
(shape[0], 1), device=input.device, dtype=torch.float32
)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert (
scale.numel() == 1
), f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static
return output, scale

View File

@@ -19,11 +19,10 @@ import torch.distributed as dist
from PIL.Image import Image from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

View File

@@ -21,13 +21,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from sglang.srt.custom_op import CustomOp from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import ( from sglang.srt.distributed import (
divide, divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import is_cuda_available, set_weight_attrs
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available() _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
rmsnorm, rmsnorm,
) )
from sglang.srt.custom_op import CustomOp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ import logging
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
from torch.nn import Module
try: try:
from deep_gemm import ( from deep_gemm import (
@@ -13,8 +14,6 @@ try:
except ImportError: except ImportError:
use_deep_gemm = False use_deep_gemm = False
from torch.nn import Module
from sglang.srt.custom_op import CustomOp from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
@@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
logger = logging.getLogger(__name__)
_is_hip = is_hip() _is_hip = is_hip()
_buffer = None if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
logger = logging.getLogger(__name__)
class GroupedGemmRunner(torch.nn.Module): class GroupedGemmRunner(torch.nn.Module):
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
) )
for expert in range(layer.num_experts_per_partition): for expert in range(layer.num_experts_per_partition):
if _is_cuda: w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) )
) w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) )
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return return

View File

@@ -13,6 +13,7 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import ( from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_bool_env_var, get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
) )
_is_hip = is_hip() _is_hip = is_hip()
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else: else:
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda or _is_hip: if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
@triton.jit @triton.jit
def write_zeros_to_output( def write_zeros_to_output(
c_ptr, c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
# activation tensor-wise fp8 quantization, dynamic or static # activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size padded_size = padding_size
# activations apply per-token quantization when weights apply per-channel quantization by default # activations apply per-token quantization when weights apply per-channel quantization by default
if _is_cuda: A, A_scale = scaled_fp8_quant(
A, A_scale = sgl_scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_channel_quant
A, A_scale, use_per_token_if_dynamic=per_channel_quant )
)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else: else:
# activation block-wise fp8 quantization # activation block-wise fp8 quantization
assert len(block_shape) == 2 assert len(block_shape) == 2

View File

@@ -13,7 +13,6 @@
# ============================================================================== # ==============================================================================
import math import math
import os
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
if _is_cuda: if _is_cuda:
from sgl_kernel import moe_fused_gate from sgl_kernel import moe_fused_gate
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
expert_distribution_recorder = ExpertDistributionRecorder() expert_distribution_recorder = ExpertDistributionRecorder()
@@ -59,11 +62,6 @@ def fused_topk(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
): ):
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
if _is_cuda or _is_hip: topk_softmax(
topk_softmax( topk_weights,
topk_weights, topk_ids,
topk_ids, token_expert_indicies,
token_expert_indicies, gating_output.float(),
gating_output.float(), )
)
else:
vllm_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies del token_expert_indicies
if renormalize: if renormalize:

View File

@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_rank
__all__ = [ __all__ = [
"BasevLLMParameter", "BasevLLMParameter",
"PackedvLLMParameter", "PackedvLLMParameter",

View File

@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format, is_activation_quantization_format,
should_ignore_layer, should_ignore_layer,
) )
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,22 +1,16 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum import enum
import logging import logging
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional from typing import Callable, List, Optional
import torch import torch
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
if TYPE_CHECKING: from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import ( from sglang.srt.layers.quantization.utils import (
all_close_1d, all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if not _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
try: try:
import vllm import vllm
@@ -58,8 +51,6 @@ __all__ = [
class CompressedTensorsMoEMethod: class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if cls is CompressedTensorsMoEMethod: if cls is CompressedTensorsMoEMethod:
return super().__new__(cls) return super().__new__(cls)
return super().__new__(cls) return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
if quant_config._is_wNa16_group_channel(weight_quant, input_quant): if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE: if not VLLM_AVAILABLE:
raise ImportError( raise ImportError(
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm" "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
) )
return CompressedTensorsWNA16MoEMethod(quant_config) return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant): elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
): ):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get( self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id], layer.w13_weight_scale[expert_id][shard_id],
) )
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size start += shard_size
layer.w13_weight_scale = torch.nn.Parameter( layer.w13_weight_scale = torch.nn.Parameter(
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
): ):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False, requires_grad=False,
) )
marlin_w13_qweight = ops.gptq_marlin_moe_repack( marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w13_weight_packed, layer.w13_weight_packed,
layer.w13_g_idx_sort_indices, layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor, layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.num_bits, self.num_bits,
) )
replace_tensor("w13_weight_packed", marlin_w13_qweight) replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w2_weight_packed, layer.w2_weight_packed,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor, layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None, correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use fused_marlin_moe, please install vllm"
)
if expert_map is not None: if expert_map is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method." "Expert Parallelism is not supported for " "fused Marlin MoE method."

View File

@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
) )
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp, Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes

View File

@@ -8,15 +8,6 @@ import torch.nn.functional as F
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
try: try:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
except ImportError: except ImportError:
MARLIN_FP8_AVAILABLE = False MARLIN_FP8_AVAILABLE = False
def apply_fp8_marlin_linear(*args, **kwargs): def dummy_func(*args, **kwargs):
raise ImportError("vllm is not installed") raise ImportError(
"marlin FP8 requires some operators from vllm. Please install vllm."
)
def prepare_fp8_layer_for_marlin(*args, **kwargs): apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
raise ImportError("vllm is not installed")
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
)
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear, apply_fp8_linear,
apply_w8a8_block_fp8_linear, apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
input_to_float8, input_to_float8,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
is_cuda, is_cuda,
is_hip, is_hip,
permute_weight,
print_warning_once, print_warning_once,
set_weight_attrs, set_weight_attrs,
) )
ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip: if _is_hip:
from aiter import ActivationType from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4 from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from aiter.ops.shuffle import shuffle_weight from aiter.ops.shuffle import shuffle_weight
_is_cuda = is_cuda() if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"]
else:
from vllm import _custom_ops as vllm_ops
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
) )
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale_inv.data, requires_grad=False layer.weight_scale_inv.data, requires_grad=False
) )
return return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin: if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
) )
if self.use_marlin: if self.use_marlin:
try: prepare_fp8_layer_for_marlin(layer)
prepare_fp8_layer_for_marlin(layer) # Activations not quantized for marlin.
# Activations not quantized for marlin. del layer.input_scale
del layer.input_scale
except ImportError:
self.use_marlin = False
def apply( def apply(
self, self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_marlin: if self.use_marlin:
try: return apply_fp8_marlin_linear(
return apply_fp8_marlin_linear( input=x,
input=x, weight=layer.weight,
weight=layer.weight, weight_scale=layer.weight_scale,
weight_scale=layer.weight_scale, workspace=layer.workspace,
workspace=layer.workspace, size_n=layer.output_size_per_partition,
size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition,
size_k=layer.input_size_per_partition, bias=bias,
bias=bias, )
)
except ImportError:
self.use_marlin = False
if self.block_quant: if self.block_quant:
return apply_w8a8_block_fp8_linear( return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
) )
# WEIGHTS # WEIGHTS
if get_bool_env_var("USE_INT4_WEIGHT"): if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed # INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if get_bool_env_var("USE_INT4_WEIGHT"): if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update( extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
) )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"): if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
self.process_weights_hip_int4(layer) self.process_weights_hip_int4(layer)
return return
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
requires_grad=False, requires_grad=False,
) )
for expert in range(layer.num_experts): for expert in range(layer.num_experts):
if _is_cuda: w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) )
) w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) )
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id], layer.w13_weight_scale[expert_id][shard_id],
) )
if _is_cuda: (
( layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight[expert_id][start : start + shard_size, :], _,
_, ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size start += shard_size
layer.w13_weight_scale = torch.nn.Parameter( layer.w13_weight_scale = torch.nn.Parameter(
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
correction_bias=correction_bias, correction_bias=correction_bias,
) )
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"): if _is_hip:
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE") if get_bool_env_var("USE_INT4_WEIGHT"):
assert not no_combine, f"{no_combine=} is not supported." # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
return ck_moe_2stages_win4( assert not no_combine, f"{no_combine=} is not supported."
x, return ck_moe_2stages_win4(
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
else ActivationType.Gelu else ActivationType.Gelu
), ),
) )
else:
# Expert fusion with FP8 quantization if get_bool_env_var("CK_MOE"):
return fused_experts( assert not no_combine, f"{no_combine=} is not supported."
x, if self.block_quant:
layer.w13_weight, # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
layer.w2_weight, assert (
topk_weights=topk_weights, activation == "silu"
topk_ids=topk_ids, ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
inplace=inplace and not no_combine, return asm_moe(
activation=activation, x,
apply_router_weight_on_input=apply_router_weight_on_input, layer.w13_weight,
use_fp8_w8a8=True, layer.w2_weight,
w1_scale=( topk_weights,
layer.w13_weight_scale_inv topk_ids,
if self.block_quant layer.w13_weight_scale_inv,
else layer.w13_weight_scale layer.w2_weight_scale_inv,
), block_shape=tuple(self.quant_config.weight_block_size),
w2_scale=( expert_mask=None,
layer.w2_weight_scale_inv )
if self.block_quant else:
else layer.w2_weight_scale return ck_moe_2stages(
), x,
a1_scale=layer.w13_input_scale, layer.w13_weight,
a2_scale=layer.w2_input_scale, layer.w2_weight,
block_shape=self.quant_config.weight_block_size, topk_weights,
no_combine=no_combine, topk_ids,
) layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@@ -34,15 +34,23 @@ from sglang.srt.utils import (
supports_custom_op, supports_custom_op,
) )
_enable_jit_deepgemm = False
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda() _is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
_enable_jit_deepgemm = False
if _is_cuda: if _is_cuda:
import deep_gemm import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 from sgl_kernel import (
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
sm_version = get_device_sm() sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var( if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if supports_custom_op(): if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt( def deep_gemm_fp8_fp8_bf16_nt(
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
group_size: int, group_size: int,
eps: float = 1e-10, eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False, column_major_scales: bool = False,
scale_tma_aligned: bool = False, scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
x: The input tenosr with ndim >= 2. x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization. group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero. eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor.
Returns: Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`" ), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype) x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
fp8_max = finfo.max
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size M = x.numel() // group_size
N = group_size N = group_size
if column_major_scales: if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
group_size: int, group_size: int,
eps: float = 1e-10, eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
): ):
assert ( assert (
x.shape[-1] % group_size == 0 x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`" ), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype) x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
fp8_max = finfo.max
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty( x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,), x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device, device=x.device,
dtype=torch.float32, dtype=torch.float32,
) )
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s return x_q, x_s
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
def sglang_per_token_quant_fp8( def sglang_per_token_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
dtype: torch.dtype = fp8_type_, dtype: torch.dtype = _fp8_type,
): ):
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
@@ -368,7 +358,6 @@ def static_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
x_s: torch.Tensor, x_s: torch.Tensor,
repeat_scale: bool = False, repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`. """Function to perform static quantization using the given scale on an input tensor `x`.
@@ -386,15 +375,8 @@ def static_quant_fp8(
""" """
assert x.is_contiguous(), "`x` is not contiguous" assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale" assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip: x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // x.shape[-1] M = x.numel() // x.shape[-1]
N = x.shape[-1] N = x.shape[-1]
if repeat_scale: if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
def per_tensor_quant_mla_fp8( def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn x: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
This function quantizes input values to float8 values with tensor-wise quantization This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
""" """
assert x.dim() == 3, "`x` is not a 3d-tensor" assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype) x_q = x.new_empty(x.size(), dtype=_fp8_type)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
x_q = x.new_empty(x.size(), dtype=dtype)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device) x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
num_head, num_seq, head_size = x.shape num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
head_size, head_size,
x.stride(0), x.stride(0),
x.stride(1), x.stride(1),
-fp8_max, fp8_min,
fp8_max, fp8_max,
BLOCK_SIZE, BLOCK_SIZE,
) )
return x_q, x_s return x_q, x_s
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 (8-bit floating point) format.

    The output uses the module-level ``_fp8_type`` (the same dtype every other
    quantization helper in this module allocates), so the platform-specific
    FP8 flavour is chosen in exactly one place.

    Args:
        input (torch.Tensor): Input tensor to be quantized. Must be 2D.
        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static
            quantization. If None, scales will be computed dynamically.
        num_token_padding (Optional[int]): If specified (and non-zero), pad the
            first dimension of the output to at least this value.
        use_per_token_if_dynamic (bool): When using dynamic scaling
            (scale=None), determines the quantization granularity:
            - True: compute scale per token
            - False: compute single scale per tensor

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - quantized_tensor: The FP8 quantized version of input, with the
              (possibly padded) output shape.
            - scale_tensor: The scaling factors used for quantization. For the
              static path this is the ``scale`` object that was passed in.

    Raises:
        AssertionError: If input is not 2D or if static scale's numel != 1
    """
    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"

    shape = input.shape
    if num_token_padding:
        # Only ever grow the token dimension; never truncate the input.
        shape = (max(num_token_padding, input.shape[0]), shape[1])

    # Shared module constant instead of re-deriving the dtype from _is_hip.
    output = torch.empty(shape, device=input.device, dtype=_fp8_type)

    if scale is None:
        # Dynamic scaling
        if use_per_token_if_dynamic:
            # One scale per output row; note this uses the padded row count.
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            sgl_per_token_quant_fp8(input, output, scale)
        else:
            # Zero-initialised single scale handed to the per-tensor kernel.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            sgl_per_tensor_quant_fp8(
                input, output, scale, is_static=False
            )  # False for dynamic
    else:
        # Static scaling
        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
        sgl_per_tensor_quant_fp8(
            input, output, scale, is_static=True
        )  # True for static

    return output, scale

View File

@@ -1,11 +1,19 @@
import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm, _enable_jit_deepgemm,
per_token_group_quant_fp8, per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
static_quant_fp8, static_quant_fp8,
w8a8_block_fp8_matmul, w8a8_block_fp8_matmul,
) )
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
is_hip, is_hip,
) )
try:
import vllm
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip and get_bool_env_var("CK_MOE"): if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) TORCH_DEVICE_IDENTITY = None
_TORCH_VERSION = torch.__version__.split("+")[0] _TORCH_VERSION = torch.__version__.split("+")[0]
try: try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_q_tensor, scale = ( x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_block) scaled_fp8_quant(x_dq_block)
if _is_cuda if _is_cuda
else input_to_float8(x_dq_block, dtype=x_q_block.dtype) else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
) )
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = ( x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_channel) scaled_fp8_quant(x_dq_channel)
if _is_cuda if _is_cuda
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype) else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
) )
@@ -264,7 +262,7 @@ def apply_fp8_linear(
# final solution should be: 1. add support to per-tensor activation scaling. # final solution should be: 1. add support to per-tensor activation scaling.
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308) # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
if _is_hip and weight_scale.numel() == 1: if _is_hip and weight_scale.numel() == 1:
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic, use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,32 +273,29 @@ def apply_fp8_linear(
) )
if cutlass_fp8_supported: if cutlass_fp8_supported:
try: if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: # Fall back to vllm cutlass w8a8 fp8 kernel
# Fall back to vllm cutlass w8a8 fp8 kernel output = vllm_ops.cutlass_scaled_mm(
output = ops.cutlass_scaled_mm( qinput,
qinput, weight,
weight, out_dtype=input.dtype,
out_dtype=input.dtype, scale_a=x_scale,
scale_a=x_scale, scale_b=weight_scale,
scale_b=weight_scale, bias=bias,
bias=bias, )
) else:
else: assert (
assert ( weight_scale.numel() == weight.shape[1]
weight_scale.numel() == weight.shape[1] ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" output = fp8_scaled_mm(
output = fp8_scaled_mm( qinput,
qinput, weight,
weight, x_scale,
x_scale, weight_scale,
weight_scale, out_dtype=input.dtype,
out_dtype=input.dtype, bias=bias,
bias=bias, )
) return output.view(*output_shape)
return output.view(*output_shape)
except (ImportError, NameError, AttributeError):
pass
# torch.scaled_mm supports per tensor weights + activations only # torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token # so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight # Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device: if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
@@ -372,13 +369,6 @@ def apply_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape) return output.to(dtype=input.dtype).view(*output_shape)
def maybe_create_device_identity():
    """Lazily create the module-level dummy ones tensor for torch._scaled_mm.

    Does nothing when ``TORCH_DEVICE_IDENTITY`` has already been set;
    otherwise stores a one-element float32 tensor of ones in it.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is not None:
        # Already allocated by an earlier call.
        return
    TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear # TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397 # https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
# We also don't pad when using torch.compile, # We also don't pad when using torch.compile,
# as it breaks with dynamic shapes. # as it breaks with dynamic shapes.
if pad_output is None: if pad_output is None:
enable_torch_compile = os.environ.get( enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
"SGLANG_ENABLE_TORCH_COMPILE", "0"
).lower() in ("1", "true", "yes")
pad_output = not enable_torch_compile pad_output = not enable_torch_compile
self.output_padding = 17 if pad_output else None self.output_padding = 17 if pad_output else None
@@ -439,13 +427,13 @@ class Fp8LinearOp:
# for sgl-kernel fp8_scaled_mm, it support per channel W now # for sgl-kernel fp8_scaled_mm, it support per channel W now
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]: if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
if _is_cuda: if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant( qinput, x_scale = scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic, use_per_token_if_dynamic=use_per_token_if_dynamic,
) )
else: else:
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
scale_ub=input_scale_ub, scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ # Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel # Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm( output = vllm_ops.cutlass_scaled_mm(
qinput, qinput,
weight, weight,
out_dtype=input.dtype, out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
else: else:
# Maybe apply padding to output, see comment in __init__ # Maybe apply padding to output, see comment in __init__
if _is_cuda: if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant( qinput, x_scale = scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
num_token_padding=self.output_padding, num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic, use_per_token_if_dynamic=use_per_token_if_dynamic,
) )
else: else:
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d, input_2d,
input_scale, input_scale,
num_token_padding=self.output_padding, num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device: if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
output = torch._scaled_mm( output = torch._scaled_mm(
qinput, qinput,

View File

@@ -1,18 +1,17 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union from typing import List, Mapping, Tuple, Union
import torch import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if not _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant from vllm._custom_ops import scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
def is_fp8_fnuz() -> bool: def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths): for idx, logical_width in enumerate(logical_widths):
end = start + logical_width end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
if _is_cuda: weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
else:
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
weight_dq, max_w_scale
)
start = end start = end
return max_w_scale, weight return max_w_scale, weight

View File

@@ -1,13 +1,6 @@
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
class W8A8Int8Config(QuantizationConfig): class W8A8Int8Config(QuantizationConfig):

View File

@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available
_is_cuda_available = is_cuda_available() _is_cuda_available = is_cuda_available()
if _is_cuda_available: if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
else: else:
from vllm import _custom_ops as ops from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
) )
else: else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding( vllm_rotary_embedding(
positions, positions,
query, query,
key, key,

View File

@@ -1,25 +0,0 @@
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
    """
    Resolve a LoRA backend name to its backend class.

    The module implementing a backend is imported only when that backend is
    requested. The class object itself is returned; it is not instantiated
    here.

    Args:
        name: Backend identifier, either "triton" or "flashinfer".

    Returns:
        ``TritonLoRABackend`` for "triton", ``FlashInferLoRABackend`` for
        "flashinfer".

    Raises:
        ValueError: If ``name`` is neither of the supported identifiers.
    """
    if name == "triton":
        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

        return TritonLoRABackend

    if name == "flashinfer":
        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend

        return FlashInferLoRABackend

    # Nothing matched: report the offending name to the caller.
    raise ValueError(f"Invalid backend: {name}")
__all__ = [
"BaseLoRABackend",
"FlashInferLoRABackend",
"TritonLoRABackend",
"get_backend_from_name",
]

View File

@@ -75,7 +75,7 @@ class BaseLoRABackend:
qkv_lora_a: torch.Tensor, qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args, *args,
**kwargs **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
"""Run the lora pass for QKV Layer. """Run the lora pass for QKV Layer.
@@ -98,7 +98,7 @@ class BaseLoRABackend:
gate_up_lora_a: torch.Tensor, gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args, *args,
**kwargs **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer. """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
@@ -115,3 +115,19 @@ class BaseLoRABackend:
def set_batch_info(self, batch_info: LoRABatchInfo): def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info self.batch_info = batch_info
def get_backend_from_name(name: str) -> BaseLoRABackend:
    """
    Look up the LoRA backend class registered under ``name``.

    Each backend module is imported on demand inside its own branch, so
    asking for one backend never imports the other. The returned value is
    the backend class, not an instance of it.

    Args:
        name: "triton" or "flashinfer".

    Returns:
        ``TritonLoRABackend`` when ``name`` is "triton",
        ``FlashInferLoRABackend`` when ``name`` is "flashinfer".

    Raises:
        ValueError: For any other ``name``.
    """
    if name == "triton":
        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

        return TritonLoRABackend

    if name == "flashinfer":
        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend

        return FlashInferLoRABackend

    # Unknown identifier: fail loudly with the name that was passed in.
    raise ValueError(f"Invalid backend: {name}")

View File

@@ -2,7 +2,7 @@ from typing import Tuple
import torch import torch
from sglang.srt.lora.backend import BaseLoRABackend from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_flashinfer_available

View File

@@ -1,6 +1,6 @@
import torch import torch
from sglang.srt.lora.backend import BaseLoRABackend from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import ( from sglang.srt.lora.triton_ops import (
gate_up_lora_b_fwd, gate_up_lora_b_fwd,
qkv_lora_b_fwd, qkv_lora_b_fwd,

View File

@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend from sglang.srt.lora.backend.base_backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):

View File

@@ -27,7 +27,7 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.loader import DefaultModelLoader

View File

@@ -22,7 +22,7 @@ import torch
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_config import LoRAConfig

View File

@@ -14,7 +14,6 @@
"""DetokenizerManager is a process that detokenizes the token ids.""" """DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses import dataclasses
import json
import logging import logging
import os import os
import signal import signal

View File

@@ -1,7 +1,8 @@
""" """
Multi-modality utils Multi-modality utils
""" """
import logging
from abc import abstractmethod from abc import abstractmethod
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict, global_server_args_dict,
logger,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once from sglang.srt.utils import print_warning_once
from sglang.utils import logger
logger = logging.getLogger(__name__)
class MultiModalityDataPaddingPattern: class MultiModalityDataPaddingPattern:

View File

@@ -5,8 +5,6 @@ import logging
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache
from transformers import PROCESSOR_MAPPING
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor, BaseMultimodalProcessor,
) )

View File

@@ -8,8 +8,6 @@ from typing import List, Optional
import numpy as np import numpy as np
import PIL import PIL
from decord import VideoReader, cpu
from PIL import Image
from transformers import BaseImageProcessorFast from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
""" """
estimate the total frame count from all visual input estimate the total frame count from all visual input
""" """
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
# Before processing inputs # Before processing inputs
estimated_frames_list = [] estimated_frames_list = []
for image in image_data: for image in image_data:

View File

@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
_is_hip = is_hip()
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values(): for sub in model._modules.values():

View File

@@ -320,7 +320,6 @@ class ModelRunner:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}") logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend: if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1: elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.") logger.info("Disable chunked prefix cache when page size > 1.")

View File

@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import awq_dequantize from sgl_kernel import awq_dequantize
else: else:
from vllm import _custom_ops as ops from vllm._custom_ops import awq_dequantize
class DeepseekModelNextN(nn.Module): class DeepseekModelNextN(nn.Module):
@@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self_attn.kv_b_proj.qzeros, self_attn.kv_b_proj.qzeros,
).T ).T
else: else:
w = ops.awq_dequantize( w = awq_dequantize(
self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales, self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros, self_attn.kv_b_proj.qzeros,

View File

@@ -51,6 +51,7 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -80,10 +81,8 @@ _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else: else:
from vllm import _custom_ops as ops from vllm._custom_ops import awq_dequantize
if _is_hip: if _is_hip:
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module):
) )
elif self.w_kc.dtype == torch.float8_e4m3fn: elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn q_nope.transpose(0, 1),
) )
q_nope_out = bmm_fp8( q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module):
) )
elif self.w_vc.dtype == torch.float8_e4m3fn: elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn attn_output.transpose(0, 1),
) )
attn_bmm_output = bmm_fp8( attn_bmm_output = bmm_fp8(
attn_output_val, attn_output_val,
@@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.kv_b_proj.qzeros, self_attn.kv_b_proj.qzeros,
).T ).T
else: else:
w = ops.awq_dequantize( w = awq_dequantize(
self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales, self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros, self_attn.kv_b_proj.qzeros,

View File

@@ -1,4 +1,3 @@
import re
from typing import Dict, Tuple from typing import Dict, Tuple

View File

@@ -10,12 +10,11 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
logger = logging.getLogger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class SamplingBatchInfo: class SamplingBatchInfo:

View File

@@ -1,18 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Some shortcuts for backward compatibility.
# They will be removed in new versions.
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server

View File

@@ -187,6 +187,7 @@ class ServerArgs:
n_share_experts_fusion: int = 0 n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
# Debug tensor dumps # Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None debug_tensor_dump_output_folder: Optional[str] = None
@@ -198,9 +199,6 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake" disaggregation_transfer_backend: str = "mooncake"
# multimodal
disable_fast_image_processor: bool = False
def __post_init__(self): def __post_init__(self):
# Expert parallelism # Expert parallelism
if self.enable_ep_moe: if self.enable_ep_moe:
@@ -1136,6 +1134,11 @@ class ServerArgs:
action="store_true", action="store_true",
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.", help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
) )
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
# Server warmups # Server warmups
parser.add_argument( parser.add_argument(
@@ -1187,13 +1190,6 @@ class ServerArgs:
help="The backend for disaggregation transfer. Default is mooncake.", help="The backend for disaggregation transfer. Default is mooncake.",
) )
# Multimodal
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size

View File

@@ -55,7 +55,6 @@ import torch.distributed
import torch.distributed as dist import torch.distributed as dist
import triton import triton
import zmq import zmq
from decord import VideoReader, cpu
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from PIL import Image from PIL import Image
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
def encode_video(video_path, frame_count_limit=None): def encode_video(video_path, frame_count_limit=None):
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
if not os.path.exists(video_path): if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist") logger.error(f"Video {video_path} does not exist")
return [] return []

View File

@@ -26,8 +26,8 @@ from transformers import (
AutoProcessor, AutoProcessor,
) )
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine
from sglang.srt.utils import load_image from sglang.srt.utils import load_image
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

View File

@@ -3,7 +3,7 @@
import pytest import pytest
import torch import torch
from sglang.srt.custom_op import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda

View File

@@ -93,9 +93,7 @@ class TestPerTokenGroupQuantFP8(TestFP8Base):
A, A_quant_gt, scale_gt = self._make_A( A, A_quant_gt, scale_gt = self._make_A(
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
) )
A_quant, scale = per_token_group_quant_fp8( A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
x=A, group_size=self.group_size, dtype=self.quant_type
)
torch.testing.assert_close(scale, scale_gt) torch.testing.assert_close(scale, scale_gt)
diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs() diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
diff_count = (diff > 1e-5).count_nonzero() diff_count = (diff > 1e-5).count_nonzero()

View File

@@ -3,9 +3,9 @@ import unittest
import torch import torch
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
@@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
B, D = a.shape B, D = a.shape
# Perform per-token quantization # Perform per-token quantization
a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True) a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
# Repeat tokens to match topk # Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale # Also repeat the scale
@@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
# Activation function # Activation function
act_out = SiluAndMul().forward_native(inter_out) act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token # Quantize activation output with per-token
act_out_q, act_out_s = sgl_scaled_fp8_quant( act_out_q, act_out_s = scaled_fp8_quant(
act_out, use_per_token_if_dynamic=True act_out, use_per_token_if_dynamic=True
) )