Clean up imports (#5467)

This commit is contained in:
Lianmin Zheng
2025-04-16 15:26:49 -07:00
committed by GitHub
parent d7bc19a46a
commit 177320a582
51 changed files with 376 additions and 573 deletions

View File

@@ -24,6 +24,7 @@ from sglang.api import (
user_end,
video,
)
from sglang.global_config import global_config
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
unconditional_likelihood_normalized,
)
from sglang.utils import LazyImport
from sglang.version import __version__
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
# Other configs
from sglang.global_config import global_config
from sglang.version import __version__
__all__ = [
"Engine",
"Runtime",

View File

@@ -707,10 +707,6 @@ def sample_random_requests(
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
print(
"If you do not want to randomly sample from a dataset,"
" please use --dataset-name random-ids."
)
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.

View File

@@ -1,7 +1,3 @@
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor

View File

@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod

View File

@@ -2,7 +2,7 @@ import dataclasses
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union
import numpy as np

View File

@@ -1,6 +1,5 @@
import os
import warnings
from typing import Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template

View File

@@ -5,13 +5,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
def compile_func(function, backend):

View File

@@ -1,20 +1,16 @@
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(name, source=self.last_node)
self.variables[name] = new_node
new_node = SglVariable(expr.name, source=self.last_node)
self.variables[expr.name] = new_node
def get_var(self, name):
ret = self.arguments.get(name, None)

View File

@@ -1,10 +1,8 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
import os
from typing import List, Tuple
import torch
import torch.library
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

View File

@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
return self.forward_hip
else:
return self.forward_native
if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.
Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization
Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty(
(shape[0], 1), device=input.device, dtype=torch.float32
)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert (
scale.numel() == 1
), f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static
return output, scale

View File

@@ -19,11 +19,10 @@ import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

View File

@@ -21,13 +21,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda_available, set_weight_attrs
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
logger = logging.getLogger(__name__)

View File

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available
_is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@ if _is_cuda:
rmsnorm,
)
from sglang.srt.custom_op import CustomOp
logger = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ import logging
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
try:
from deep_gemm import (
@@ -13,8 +14,6 @@ try:
except ImportError:
use_deep_gemm = False
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
@@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
logger = logging.getLogger(__name__)
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
_is_hip = is_hip()
_buffer = None
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
logger = logging.getLogger(__name__)
class GroupedGemmRunner(torch.nn.Module):
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
)
for expert in range(layer.num_experts_per_partition):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return

View File

@@ -13,6 +13,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
)
_is_hip = is_hip()
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
@triton.jit
def write_zeros_to_output(
c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
# activations apply per-token quantization when weights apply per-channel quantization by default
if _is_cuda:
A, A_scale = sgl_scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
A, A_scale = scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2

View File

@@ -13,7 +13,6 @@
# ==============================================================================
import math
import os
from typing import Callable, Optional
import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
if _is_cuda:
from sgl_kernel import moe_fused_gate
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
expert_distribution_recorder = ExpertDistributionRecorder()
@@ -59,11 +62,6 @@ def fused_topk(
topk: int,
renormalize: bool,
):
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if _is_cuda or _is_hip:
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
else:
vllm_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies
if renormalize:

View File

@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_rank
__all__ = [
"BasevLLMParameter",
"PackedvLLMParameter",

View File

@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import logging
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
logger = logging.getLogger(__name__)

View File

@@ -1,22 +1,16 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
if not _is_cuda:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
try:
import vllm
@@ -58,8 +51,6 @@ __all__ = [
class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
)
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(
"vllm is not installed, to use fused_marlin_moe, please install vllm"
)
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method."

View File

@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
)
from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes

View File

@@ -8,15 +8,6 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
try:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
except ImportError:
MARLIN_FP8_AVAILABLE = False
def apply_fp8_marlin_linear(*args, **kwargs):
raise ImportError("vllm is not installed")
def dummy_func(*args, **kwargs):
raise ImportError(
"marlin FP8 requires some operators from vllm. Please install vllm."
)
def prepare_fp8_layer_for_marlin(*args, **kwargs):
raise ImportError("vllm is not installed")
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_hip,
permute_weight,
print_warning_once,
set_weight_attrs,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from aiter.ops.shuffle import shuffle_weight
_is_cuda = is_cuda()
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale_inv.data, requires_grad=False
)
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.use_marlin:
try:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
except ImportError:
self.use_marlin = False
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
def apply(
self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor:
if self.use_marlin:
try:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
except ImportError:
self.use_marlin = False
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
)
# WEIGHTS
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
self.process_weights_hip_int4(layer)
return
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
requires_grad=False,
)
for expert in range(layer.num_experts):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
if _is_cuda:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
else:
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = vllm_ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id]
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
),
)
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
if _is_hip:
if get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages_win4(
x,
layer.w13_weight,
layer.w2_weight,
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
else ActivationType.Gelu
),
)
else:
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
if get_bool_env_var("CK_MOE"):
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
else:
return ck_moe_2stages(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@@ -34,15 +34,23 @@ from sglang.srt.utils import (
supports_custom_op,
)
_enable_jit_deepgemm = False
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if _is_hip:
fp8_max = 224.0
else:
fp8_max = torch.finfo(_fp8_type).max
fp8_min = -fp8_max
_enable_jit_deepgemm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
from sgl_kernel import (
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var(
@@ -53,6 +61,7 @@ if _is_cuda:
logger = logging.getLogger(__name__)
if supports_custom_op():
def deep_gemm_fp8_fp8_bf16_nt(
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // group_size
N = group_size
if column_major_scales:
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_type_,
dtype: torch.dtype = _fp8_type,
):
assert x.is_contiguous(), "`x` is not contiguous"
@@ -368,7 +358,6 @@ def static_quant_fp8(
x: torch.Tensor,
x_s: torch.Tensor,
repeat_scale: bool = False,
dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform static quantization using the given scale on an input tensor `x`.
@@ -386,15 +375,8 @@ def static_quant_fp8(
"""
assert x.is_contiguous(), "`x` is not contiguous"
assert x_s.numel() == 1, "only supports per-tensor scale"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if repeat_scale:
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with tensor-wise quantization
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
x_q = x.new_empty(x.size(), dtype=dtype)
x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
num_head, num_seq, head_size = x.shape
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
head_size,
x.stride(0),
x.stride(1),
-fp8_max,
fp8_min,
fp8_max,
BLOCK_SIZE,
)
return x_q, x_s
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 2D activation tensor to FP8.

    Args:
        input (torch.Tensor): 2D tensor to quantize.
        scale (Optional[torch.Tensor]): Scalar scale for static quantization;
            when None, the scale is computed dynamically by the kernel.
        num_token_padding (Optional[int]): Minimum size for the output's first
            dimension; the output is padded up to this many rows if needed.
        use_per_token_if_dynamic (bool): For the dynamic path only — True
            computes one scale per token (row), False one scale per tensor.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The FP8 tensor and the scale(s)
        that were used (or computed) for quantization.

    Raises:
        AssertionError: If input is not 2D, or a static scale is not a scalar.
    """
    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
    num_rows, num_cols = input.shape
    if num_token_padding:
        # Pad the row dimension up to the requested minimum.
        num_rows = max(num_token_padding, num_rows)
    # ROCm uses the fnuz FP8 encoding; CUDA uses the standard e4m3fn one.
    fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
    output = torch.empty((num_rows, num_cols), device=input.device, dtype=fp8_dtype)

    if scale is not None:
        # Static scaling: caller provides a single pre-computed scalar scale.
        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
        sgl_per_tensor_quant_fp8(
            input, output, scale, is_static=True
        )  # True for static
    elif use_per_token_if_dynamic:
        # Dynamic per-token scaling: one scale per (possibly padded) row.
        scale = torch.empty((num_rows, 1), device=input.device, dtype=torch.float32)
        sgl_per_token_quant_fp8(input, output, scale)
    else:
        # Dynamic per-tensor scaling: a single scale computed by the kernel.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        sgl_per_tensor_quant_fp8(
            input, output, scale, is_static=False
        )  # False for dynamic
    return output, scale

View File

@@ -1,11 +1,19 @@
import os
from typing import List, Optional, Tuple
import torch
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm,
per_token_group_quant_fp8,
scaled_fp8_quant,
sglang_per_token_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
)
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
is_hip,
)
try:
import vllm
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
TORCH_DEVICE_IDENTITY = None
_TORCH_VERSION = torch.__version__.split("+")[0]
try:
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_block)
scaled_fp8_quant(x_dq_block)
if _is_cuda
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
)
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
) -> Tuple[torch.Tensor, torch.Tensor]:
x_dq_channel = x_q_channel.to(torch.float32) * x_s
x_q_tensor, scale = (
sgl_scaled_fp8_quant(x_dq_channel)
scaled_fp8_quant(x_dq_channel)
if _is_cuda
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
)
@@ -264,7 +262,7 @@ def apply_fp8_linear(
# final solution should be: 1. add support to per-tensor activation scaling.
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
if _is_hip and weight_scale.numel() == 1:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -275,32 +273,29 @@ def apply_fp8_linear(
)
if cutlass_fp8_supported:
try:
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
except (ImportError, NameError, AttributeError):
pass
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
)
else:
assert (
weight_scale.numel() == weight.shape[1]
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
output = fp8_scaled_mm(
qinput,
weight,
x_scale,
weight_scale,
out_dtype=input.dtype,
bias=bias,
)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
@@ -343,8 +338,10 @@ def apply_fp8_linear(
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
# GEMM
# This computes C = (X * W).
@@ -372,13 +369,6 @@ def apply_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape)
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
@@ -405,9 +395,7 @@ class Fp8LinearOp:
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if pad_output is None:
enable_torch_compile = os.environ.get(
"SGLANG_ENABLE_TORCH_COMPILE", "0"
).lower() in ("1", "true", "yes")
enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
pad_output = not enable_torch_compile
self.output_padding = 17 if pad_output else None
@@ -439,13 +427,13 @@ class Fp8LinearOp:
# for sgl-kernel fp8_scaled_mm, it support per channel W now
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
scale_ub=input_scale_ub,
@@ -455,7 +443,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
# Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
output = vllm_ops.cutlass_scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
@@ -482,14 +470,14 @@ class Fp8LinearOp:
else:
# Maybe apply padding to output, see comment in __init__
if _is_cuda:
qinput, x_scale = sgl_scaled_fp8_quant(
qinput, x_scale = scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(
qinput, x_scale = vllm_ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
@@ -562,9 +550,12 @@ class Fp8LinearOp:
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(
1, dtype=torch.float32, device=weight.device
)
output = torch._scaled_mm(
qinput,

View File

@@ -1,18 +1,17 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import List, Mapping, Tuple, Union
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
if _is_cuda:
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
else:
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
weight_dq, max_w_scale
)
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
start = end
return max_w_scale, weight

View File

@@ -1,13 +1,6 @@
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda_available, set_weight_attrs
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import int8_scaled_mm
class W8A8Int8Config(QuantizationConfig):

View File

@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available
_is_cuda_available = is_cuda_available()
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
else:
from vllm import _custom_ops as ops
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
)
else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
vllm_rotary_embedding(
positions,
query,
key,

View File

@@ -1,25 +0,0 @@
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get corresponding backend class from backend's name
"""
if name == "triton":
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
elif name == "flashinfer":
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
return FlashInferLoRABackend
else:
raise ValueError(f"Invalid backend: {name}")
__all__ = [
"BaseLoRABackend",
"FlashInferLoRABackend",
"TritonLoRABackend",
"get_backend_from_name",
]

View File

@@ -75,7 +75,7 @@ class BaseLoRABackend:
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
@@ -98,7 +98,7 @@ class BaseLoRABackend:
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
@@ -115,3 +115,19 @@ class BaseLoRABackend:
def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info
def get_backend_from_name(name: str) -> BaseLoRABackend:
    """Resolve a LoRA backend class from its registered name.

    Backend modules are imported lazily so that optional dependencies
    (e.g. flashinfer) are only loaded when that backend is requested.

    Raises:
        ValueError: If ``name`` does not match a known backend.
    """
    if name == "triton":
        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

        return TritonLoRABackend
    if name == "flashinfer":
        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend

        return FlashInferLoRABackend
    raise ValueError(f"Invalid backend: {name}")

View File

@@ -2,7 +2,7 @@ from typing import Tuple
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available

View File

@@ -1,6 +1,6 @@
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
gate_up_lora_b_fwd,
qkv_lora_b_fwd,

View File

@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module):

View File

@@ -27,7 +27,7 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader

View File

@@ -22,7 +22,7 @@ import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig

View File

@@ -14,7 +14,6 @@
"""DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses
import json
import logging
import os
import signal

View File

@@ -1,7 +1,8 @@
"""
Multi-modality utils
Multi-modality utils
"""
import logging
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import print_warning_once
from sglang.utils import logger
logger = logging.getLogger(__name__)
class MultiModalityDataPaddingPattern:

View File

@@ -5,8 +5,6 @@ import logging
import pkgutil
from functools import lru_cache
from transformers import PROCESSOR_MAPPING
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)

View File

@@ -8,8 +8,6 @@ from typing import List, Optional
import numpy as np
import PIL
from decord import VideoReader, cpu
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
"""
estimate the total frame count from all visual input
"""
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
# Before processing inputs
estimated_frames_list = []
for image in image_data:

View File

@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
_is_hip = is_hip()
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():

View File

@@ -320,7 +320,6 @@ class ModelRunner:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")

View File

@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize
else:
from vllm import _custom_ops as ops
from vllm._custom_ops import awq_dequantize
class DeepseekModelNextN(nn.Module):
@@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self_attn.kv_b_proj.qzeros,
).T
else:
w = ops.awq_dequantize(
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,

View File

@@ -51,6 +51,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -80,10 +81,8 @@ _is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else:
from vllm import _custom_ops as ops
from vllm._custom_ops import awq_dequantize
if _is_hip:
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
q_nope.transpose(0, 1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module):
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
attn_output.transpose(0, 1),
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.kv_b_proj.qzeros,
).T
else:
w = ops.awq_dequantize(
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,

View File

@@ -1,4 +1,3 @@
import re
from typing import Dict, Tuple

View File

@@ -10,12 +10,11 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class SamplingBatchInfo:

View File

@@ -1,18 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Some shortcuts for backward compatibility.
# They will be removed in new versions.
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server

View File

@@ -187,6 +187,7 @@ class ServerArgs:
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
@@ -198,9 +199,6 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake"
# multimodal
disable_fast_image_processor: bool = False
def __post_init__(self):
# Expert parallelism
if self.enable_ep_moe:
@@ -1136,6 +1134,11 @@ class ServerArgs:
action="store_true",
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
)
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
# Server warmups
parser.add_argument(
@@ -1187,13 +1190,6 @@ class ServerArgs:
help="The backend for disaggregation transfer. Default is mooncake.",
)
# Multimodal
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size

View File

@@ -55,7 +55,6 @@ import torch.distributed
import torch.distributed as dist
import triton
import zmq
from decord import VideoReader, cpu
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from PIL import Image
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
def encode_video(video_path, frame_count_limit=None):
# Lazy import because decord is not available on some arm platforms.
from decord import VideoReader, cpu
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []

View File

@@ -26,8 +26,8 @@ from transformers import (
AutoProcessor,
)
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine
from sglang.srt.utils import load_image
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

View File

@@ -3,7 +3,7 @@
import pytest
import torch
from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda