diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 0cc44be55..1c770193f 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1,12 +1,12 @@ """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" +from __future__ import annotations + import itertools import logging -from abc import abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from sglang.srt.distributed import ( @@ -17,7 +17,6 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import ( RowvLLMParameter, _ColumnvLLMParameter, ) -from sglang.srt.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from sglang.srt.utils import ( - cpu_has_amx_support, - is_cpu, - is_npu, - set_weight_attrs, - use_intel_amx_backend, -) +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + ) logger = logging.getLogger(__name__) @@ -59,7 +55,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "IPEXAWQLinearMethod", ] -_is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() _is_npu = is_npu() @@ -110,91 +105,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): return param[shard_id], loaded_weight -class LinearMethodBase(QuantizeMethodBase): - """Base class for different (maybe quantized) linear methods.""" - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """Create weights for a linear layer. - The weights will be set as attributes of the layer. - - Args: - layer: The layer that is using the LinearMethodBase factory. - input_size_per_partition: Size of the weight input dim on rank X. - output_partition_sizes: Sizes of the output dim of each logical - weight on rank X. E.g., output_partition_sizes for QKVLinear - is a list contains the width of Wq, Wk, Wv on rank X. - input_size: Size of the input dim of the weight across all ranks. - output_size: Size of the output dim of the weight across all ranks. - params_dtype: Datatype of the parameters. - """ - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" - raise NotImplementedError - - -class UnquantizedLinearMethod(LinearMethodBase): - """Linear method without quantization.""" - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight = Parameter( - torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _is_cpu and _is_cpu_amx_available: - _amx_process_weight_after_loading(layer, ["weight"]) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if use_intel_amx_backend(layer): - return torch.ops.sgl_kernel.weight_packed_linear( - x, layer.weight, bias, True # is_vnni - ) - - return F.linear(x, layer.weight, bias) - - class LinearBase(torch.nn.Module): """Base linear layer. @@ -310,7 +220,7 @@ class ReplicatedLinear(LinearBase): assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None output = self.quant_method.apply(self, x, bias) @@ -845,7 +755,7 @@ class QKVParallelLinear(ColumnParallelLinear): bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + quant_config: Optional["QuantizationConfig"] = None, prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e8bfadfb6..a839b47fe 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -27,22 +27,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( silu_and_mul_triton_kernel, tma_align_input_scale, ) -from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, - scaled_fp8_quant, sglang_per_token_group_quant_fp8, sglang_per_token_quant_fp8, ) -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -53,7 +51,6 @@ from sglang.srt.utils import ( get_bool_env_var, is_hip, is_npu, - set_weight_attrs, ) _is_hip = is_hip() @@ -904,324 +901,6 @@ class EPMoE(torch.nn.Module): param_data[expert_id] = loaded_weight -class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): - - def create_weights( - self, - layer: torch.nn.Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # scale - layer.register_parameter("w13_input_scale", None) - layer.register_parameter("w13_weight_scale", None) - - ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) - - w2_input_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - raise NotImplementedError - - -class Fp8EPMoEMethod(Fp8MoEMethod): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None - - def create_weights( - self, - layer: Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - tp_size = get_tensor_model_parallel_world_size() - if self.block_quant: - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], - ) - # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. - # Required by column parallel or enabling merged weights - if intermediate_size % block_n != 0: - raise ValueError( - f"The output_size of gate's and up's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_n = {block_n}." - ) - if tp_size > 1: - # Required by row parallel - if intermediate_size % block_k != 0: - raise ValueError( - f"The input_size of down's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_k = {block_k}." - ) - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - if self.block_quant: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - 2 * ((intermediate_size + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - (hidden_size + block_n - 1) // block_n, - (intermediate_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) - layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) - assert self.quant_config.activation_scheme == "dynamic" - else: - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, 2, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} - if self.block_quant - else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - - w13_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - layer.w13_weight_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts_per_partition, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) - - for expert in range(layer.num_experts_per_partition): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - if self.quant_config.activation_scheme == "static": - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - layer.w13_weight_scale = torch.nn.Parameter( - torch.max(layer.w13_weight_scale, dim=1).values, - requires_grad=False, - ) - if self.block_quant: - # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: - # activation_scheme: dynamic - w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w13_weight, - weight_scale=layer.w13_weight_scale_inv, - input_scale=None, - ) - w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w2_weight, - weight_scale=layer.w2_weight_scale_inv, - input_scale=None, - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter( - w13_weight, requires_grad=False - ) - layer.w13_weight_scale_inv = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) - layer.w13_input_scale = None - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - layer.w2_input_scale = None - if _use_aiter: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - raise NotImplementedError - - class DeepEPMoE(EPMoE): """ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index 839b659fe..6d8aee852 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -9,7 +9,6 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( ) from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, - FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) @@ -31,11 +30,9 @@ def get_config() -> Optional[Dict[str, Any]]: __all__ = [ "FusedMoE", - "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", "override_config", "get_config", - "fused_moe", "fused_experts", "get_config_file_name", "moe_align_block_size", diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index ad495d595..41ae6274b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,60 +1,28 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py -import importlib -from abc import abstractmethod +import logging from enum import Enum from typing import Callable, List, Optional, Tuple import torch -from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.moe.fused_moe_native import moe_forward_native -from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight -from sglang.srt.utils import ( - cpu_has_amx_support, - get_bool_env_var, - is_cpu, - is_hip, - set_weight_attrs, - use_intel_amx_backend, -) - -has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None - -if torch.cuda.is_available(): - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - - if has_triton_kernels: - from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( - triton_kernel_moe_forward, - ) -else: - fused_experts = None # type: ignore - -import logging +from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip _is_hip = is_hip() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() -_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip - -if _use_aiter: - from aiter import ActivationType - from aiter.fused_moe import fused_moe - from aiter.fused_moe_bf16_asm import ck_moe_2stages - from aiter.ops.shuffle import shuffle_weight logger = logging.getLogger(__name__) @@ -66,333 +34,6 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -class FusedMoEMethodBase(QuantizeMethodBase): - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - ) -> torch.Tensor: - raise NotImplementedError - - -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): - """MoE method without quantization.""" - - def __init__(self, use_triton_kernels: bool = False): - super().__init__() - self.use_triton_kernels = use_triton_kernels - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Fused gate_up_proj (column parallel) - w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size - if self.use_triton_kernels: - w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n - w13_weight = torch.nn.Parameter( - torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight_n, w2_weight_k = ( - hidden_size, - intermediate_size, - ) - if self.use_triton_kernels: - w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n - w2_weight = torch.nn.Parameter( - torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _use_aiter: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - - # Pack weight for get better performance on CPU - if _is_cpu and _is_cpu_amx_available: - _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) - - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - return self.forward( - x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - ) - - def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - - if self.use_triton_kernels: - return triton_kernel_moe_forward( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - - if _use_aiter: - assert not no_combine, "unsupported" - if apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 (float32) - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if activation == "silu" - else ActivationType.Gelu - ), - ) - else: - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - ) - - def forward_cpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - assert activation == "silu", f"activation = {activation} is not supported." - - if use_intel_amx_backend(layer) and not apply_router_weight_on_input: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - - # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel - return torch.ops.sgl_kernel.fused_experts_cpu( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - False, # inplace # See [Note] inplace should be False in fused_experts. - False, # use_int8_w8a8 - False, # use_fp8_w8a16 - None, # w1_scale - None, # w2_scale - None, # block_size - None, # a1_scale - None, # a2_scale - True, # is_vnni - ) - else: - return moe_forward_native( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, - ) - - def forward_npu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - return moe_forward_native( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, - ) - - def forward_tpu(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("The TPU backend currently does not support MoE.") - - forward_native = forward_cpu - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -553,7 +194,7 @@ class FusedMoE(torch.nn.Module): shard_dim: int, expert_data: torch.Tensor, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): # Load grouped weight scales for group quantization @@ -580,7 +221,7 @@ class FusedMoE(torch.nn.Module): expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): # for per channel weight quantization @@ -600,7 +241,7 @@ class FusedMoE(torch.nn.Module): expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): @@ -645,7 +286,7 @@ class FusedMoE(torch.nn.Module): expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): """Load w2 weights for down projection. @@ -717,7 +358,7 @@ class FusedMoE(torch.nn.Module): shard_id: str, expert_data: torch.Tensor, shard_dim: int, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 18f3dea8d..1c8d219e4 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -19,15 +19,11 @@ import torch import torch.nn.functional as F from sglang.srt.eplb import expert_location_dispatch -from sglang.srt.eplb.expert_distribution import ( - ExpertDistributionRecorder, - get_global_expert_distribution_recorder, -) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 7507a5b62..e0f436343 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,8 +1,6 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py import builtins import inspect -import re -from copy import deepcopy from typing import Callable, Dict, Optional, Type, Union import torch @@ -45,7 +43,6 @@ except ImportError: ) = QQQConfig = Int8TpuConfig = DummyConfig -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config @@ -66,6 +63,10 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.qoq import QoQConfig +from sglang.srt.layers.quantization.utils import ( + get_dynamic_override, + get_linear_quant_method, +) from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -120,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -# Match dynamic rules with module name (prefix) and override quantize -# config if module (prefix) matches a rule -def override_config(config: QuantizationConfig, prefix: str): - weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) - if isinstance(weight_bits, int): - config.weight_bits = weight_bits - group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) - if isinstance(group_size, int): - config.group_size = group_size - desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) - if isinstance(desc_act, bool): - config.desc_act = desc_act - - config.pack_factor = 32 // config.weight_bits # packed into int32 - if config.get_name() == "gptq_marlin": - is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) - if isinstance(is_sym, bool): - config.is_sym = is_sym - - if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: - raise ValueError( - "Unsupported quantization config: " - f"bits={config.weight_bits}, sym={config.is_sym}" - ) - - config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] - elif config.get_name() == "gptq": - if config.weight_bits not in [2, 3, 4, 8]: - raise ValueError( - "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {config.weight_bits} bits." - ) - - -def get_dynamic_override( - config: QuantizationConfig, - layer_name: str, - key: Optional[str] = None, - default_value: Union[int, bool, None] = None, -) -> Union[Dict, int, bool, None]: - for pattern, pattern_dict in config.dynamic.items(): - # Negative match: matched modules are excluded from quantized init - if pattern.startswith("-:"): - if re.match(pattern.removeprefix("-:"), layer_name): - return False - # Positive match: matched modules have quant properties overrides - # base quant config - elif re.match(pattern.removeprefix("+:"), layer_name): - if key is None: - return pattern_dict - else: - return pattern_dict.get(key, default_value) - return default_value - - -def get_linear_quant_method( - config: QuantizationConfig, - layer: torch.nn.Module, - prefix: str, - linear_method_cls: type, -): - # Move import here to avoid circular import. This is only used in monkey patching - # of vllm's QuantizationConfig. - from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - UnquantizedEmbeddingMethod, - ) - - cloned_config = deepcopy(config) - parallel_lm_head_quantized = ( - isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized - ) - - if isinstance(layer, LinearBase) or parallel_lm_head_quantized: - # False = skip module, None = no override, else = Positive match - if ( - get_dynamic_override( # noqa: E712 - cloned_config, layer_name=prefix # noqa: E712 - ) - == False - ): # noqa: E712 - if parallel_lm_head_quantized: - return UnquantizedEmbeddingMethod() - return UnquantizedLinearMethod() - - if prefix: - # Dynamic per module/layer rules may override base config - override_config(cloned_config, prefix=prefix) - - return linear_method_cls(cloned_config) - return None - - def gptq_get_quant_method(self, layer, prefix): from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 9f14ac4c1..6265f2217 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -1,16 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional import torch -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter -from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, + QuantizationConfig, +) +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import is_cuda _is_cuda = is_cuda() @@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig): ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + def from_config(cls, config: Dict[str, Any]) -> AWQConfig: weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) @@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["LinearMethodBase"]: + ) -> Optional[LinearMethodBase]: + from sglang.srt.layers.linear import LinearBase if isinstance(layer, LinearBase): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 6058702c9..607151671 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -18,14 +18,14 @@ class QuantizeMethodBase(ABC): """Create weights for a layer. The weights will be set as attributes of the layer.""" - raise NotImplementedError + raise NotImplementedError() @abstractmethod def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: """Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.""" - raise NotImplementedError + raise NotImplementedError() def process_weights_after_loading(self, layer: nn.Module) -> None: """Process the weight after loading. @@ -35,6 +35,74 @@ class QuantizeMethodBase(ABC): return +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError() + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError() + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError() + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + ) -> torch.Tensor: + raise NotImplementedError() + + class QuantizationConfig(ABC): """Base class for quantization configs.""" @@ -46,12 +114,12 @@ class QuantizationConfig(ABC): @abstractmethod def get_name(self) -> str: """Name of the quantization method.""" - raise NotImplementedError + raise NotImplementedError() @abstractmethod def get_supported_act_dtypes(self) -> List[torch.dtype]: """List of supported activation dtypes.""" - raise NotImplementedError + raise NotImplementedError() @classmethod @abstractmethod @@ -62,19 +130,19 @@ class QuantizationConfig(ABC): This requirement is due to the custom CUDA kernels used by the quantization method. """ - raise NotImplementedError + raise NotImplementedError() @staticmethod @abstractmethod def get_config_filenames() -> List[str]: """List of filenames to search for in the model directory.""" - raise NotImplementedError + raise NotImplementedError() @classmethod @abstractmethod def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" - raise NotImplementedError + raise NotImplementedError() @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: @@ -117,7 +185,7 @@ class QuantizationConfig(ABC): The quantize method. None if the given layer doesn't support quant method. """ - raise NotImplementedError + raise NotImplementedError() @abstractmethod def get_scaled_act_names(self) -> List[str]: @@ -125,7 +193,7 @@ class QuantizationConfig(ABC): For now, this is only used by AWQ. """ - raise NotImplementedError + raise NotImplementedError() def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool: diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index f38857595..a1da999b3 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -1,5 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py +from __future__ import annotations + import logging from typing import Any, Callable, Dict, List, Optional @@ -7,17 +9,15 @@ import torch from torch.nn import Module from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs @@ -78,7 +78,7 @@ class BlockInt8Config(QuantizationConfig): return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": + def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_int8_serialized = "int8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) @@ -93,7 +93,8 @@ class BlockInt8Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -230,7 +231,7 @@ class BlockInt8LinearMethod(LinearMethodBase): ) -class BlockInt8MoEMethod: +class BlockInt8MoEMethod(FusedMoEMethodBase): """MoE method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic activation scale. @@ -242,25 +243,7 @@ class BlockInt8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: BlockInt8Config): self.quant_config = quant_config assert self.quant_config.weight_block_size is not None assert self.quant_config.is_checkpoint_int8_serialized diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 7ce89345f..50d90406d 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import logging from contextlib import suppress @@ -18,12 +19,8 @@ from compressed_tensors.quantization import ( ) from pydantic import BaseModel -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import ( is_activation_quantization_format, should_ignore_layer, ) +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod try: import vllm @@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig): self.config = config self.packed_modules_mapping = packed_modules_mapping - def get_linear_method(self) -> "CompressedTensorsLinearMethod": + def get_linear_method(self) -> CompressedTensorsLinearMethod: return CompressedTensorsLinearMethod(self) def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig): self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase # Check if the layer is skipped for quantization. # TODO (@robertgshaw2): support module names @@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig): return None @classmethod - def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig: ignore: List[str] = cast(List[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) target_scheme_map = cls._quantization_scheme_map_from_config(config=config) @@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig): def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel - ) -> "CompressedTensorsScheme": + ) -> CompressedTensorsScheme: # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): @@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig): def get_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> Optional["CompressedTensorsScheme"]: + ) -> Optional[CompressedTensorsScheme]: """ compressed-tensors supports non uniform in the following way: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 4d886de91..38588c809 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,7 +1,9 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py +from __future__ import annotations + import logging -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -28,17 +30,14 @@ except ImportError: from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( all_close_1d, convert_to_channelwise, @@ -77,6 +77,9 @@ from sglang.srt.utils import ( use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config + _is_hip = is_hip() _is_cuda = is_cuda() _is_npu = is_npu() @@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig): return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + def from_config(cls, config: Dict[str, Any]) -> Fp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) @@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]): + def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase): ) -class Fp8MoEMethod: +class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -499,25 +503,7 @@ class Fp8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -1169,6 +1155,254 @@ class Fp8MoEMethod: return None +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.max(layer.w13_weight_scale, dim=1).values, + requires_grad=False, + ) + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if _is_fp8_fnuz: + # activation_scheme: dynamic + w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w13_weight, + weight_scale=layer.w13_weight_scale_inv, + input_scale=None, + ) + w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w2_weight, + weight_scale=layer.w2_weight_scale_inv, + input_scale=None, + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter( + w13_weight, requires_grad=False + ) + layer.w13_weight_scale_inv = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + layer.w13_input_scale = None + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + layer.w2_input_scale = None + if _use_aiter: + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 3658d0b85..af56c3be7 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import dataclass from fractions import Fraction @@ -5,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, @@ -16,6 +17,8 @@ from sglang.srt.layers.parameter import ( permute_param_layout_, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -34,7 +37,11 @@ from sglang.srt.layers.quantization.marlin_utils import ( verify_marlin_supported, ) from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types -from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols +from sglang.srt.layers.quantization.utils import ( + get_linear_quant_method, + replace_parameter, + unpack_cols, +) try: from vllm import _custom_ops as ops @@ -49,8 +56,6 @@ if _is_cuda: from sgl_kernel import fused_marlin_moe -FusedMoEMethodBase = QuantizeMethodBase - logger = logging.getLogger(__name__) @@ -179,7 +184,7 @@ class GPTQConfig(QuantizationConfig): return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + def from_config(cls, config: Dict[str, Any]) -> GPTQConfig: dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -191,10 +196,10 @@ class GPTQConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["LinearMethodBase"]: + ) -> Optional[LinearMethodBase]: # Delay the import to avoid circular dependency + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.layers.quantization import get_linear_quant_method if isinstance(layer, LinearBase): return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) @@ -303,7 +308,7 @@ class GPTQMarlinConfig(QuantizationConfig): return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig: dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -354,7 +359,6 @@ class GPTQMarlinConfig(QuantizationConfig): ) -> Optional[QuantizeMethodBase]: # Delay the import to avoid circular dependency from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.layers.quantization import get_linear_quant_method if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) @@ -832,6 +836,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): **extra_weight_attrs, ): # Delay the import to avoid circular dependency + from sglang.srt.layers.linear import set_weight_attrs from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported intermediate_size = extra_weight_attrs.pop("intermediate_size") diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py index 503c3d003..1edc672ab 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils.py +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -1,25 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py +from __future__ import annotations + import logging -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import numpy import torch -from sglang.srt.layers.linear import LinearBase, LinearMethodBase from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedvLLMParameter, ) -from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, + QuantizationConfig, +) from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.utils import get_device_capability +if TYPE_CHECKING: + from sglang.srt.layers.linear import LinearBase + try: from vllm import _custom_ops as ops except ImportError: @@ -617,7 +623,10 @@ class MarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["MarlinLinearMethod"]: + ) -> Optional[MarlinLinearMethod]: + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + if isinstance(layer, LinearBase) or ( isinstance(layer, ParallelLMHead) and self.lm_head_quantized ): diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 85be4f8f4..5263f3b92 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1,4 +1,5 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py +from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional @@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( is_sm100_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, is_layer_skipped, @@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig): return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config: quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get( "kv_cache_quant_algo" @@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + if self.exclude_modules and any( module in prefix or ( @@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig): if self.kv_cache_quant_method and isinstance(layer, RadixAttention): return ModelOptFp8KVCacheMethod(self) - # Add MoE support - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - if isinstance(layer, FusedMoE): return ModelOptFp8MoEMethod(self) @@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): super().__init__(quant_config) -class ModelOptFp8MoEMethod: +class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. @@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod: quant_config: The ModelOpt quantization config. """ - def __new__(cls, *args, **kwargs): - """ - Dynamic class composition pattern. - - This allows us to effectively "inject" FusedMoEMethodBase as a parent class - at runtime while avoiding circular import issues. - """ - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig): return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config": + def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: quant_config = cls.get_from_keys(config, ["quantization"]) quant_method = quant_config["quant_algo"] if not quant_method in ["FP8", "NVFP4"]: @@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase): return out.view(*output_shape) -class ModelOptNvFp4FusedMoEMethod: +class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): """ MoE Method for FP4 Quantization with Blockscales and PerTensorScales Args: quant_config: NVFP4 Quant Config """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: ModelOptFp4Config): self.quant_config = quant_config if not is_sm100_supported(): diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index fe812595a..f83b9bb1f 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -1,4 +1,5 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py +from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional @@ -7,13 +8,14 @@ import torch from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tp_group -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) @@ -118,7 +120,7 @@ class MoeWNA16Config(QuantizationConfig): raise NotImplementedError @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config: quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) @@ -177,8 +179,9 @@ class MoeWNA16Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: # avoid circular import + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if is_layer_skipped_quant(prefix, self.modules_to_not_convert): @@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): return any(module_name in prefix for module_name in modules_to_not_convert) -class MoeWNA16Method: +class MoeWNA16Method(FusedMoEMethodBase): """Linear method for MOE WNA16 (W8A16/W4A16) quantization. Args: quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __new__(cls, *args, **kwargs): - # avoid circular import - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config diff --git a/python/sglang/srt/layers/quantization/qoq.py b/python/sglang/srt/layers/quantization/qoq.py index 3e3a3dfb6..ec0fda482 100644 --- a/python/sglang/srt/layers/quantization/qoq.py +++ b/python/sglang/srt/layers/quantization/qoq.py @@ -1,16 +1,17 @@ -from typing import Any, Callable, Dict, List, Optional +from __future__ import annotations + +from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, GroupQuantScaleParameter, ModelWeightParameter, ) from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig): return 80 @classmethod - def get_name(self) -> str: + def get_name(cls) -> str: return "qoq" @classmethod @@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig): ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "QoQConfig": + def from_config(cls, config: Dict[str, Any]) -> QoQConfig: weight_bits = cls.get_from_keys(config, ["wbits"]) group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) @@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig): self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase if isinstance(layer, LinearBase): diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py new file mode 100644 index 000000000..28d006255 --- /dev/null +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -0,0 +1,515 @@ +import importlib +from typing import Callable, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from sglang.srt.custom_op import CustomOp +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, + QuantizeMethodBase, +) +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_hip, + set_weight_attrs, + use_intel_amx_backend, +) + +has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None + + +_is_cpu_amx_available = cpu_has_amx_support() +_is_hip = is_hip() +_is_cpu = is_cpu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + from aiter import ActivationType + from aiter.fused_moe import fused_moe + from aiter.fused_moe_bf16_asm import ck_moe_2stages + from aiter.ops.shuffle import shuffle_weight + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for embedding layer.""" + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return F.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _is_cpu and _is_cpu_amx_available: + _amx_process_weight_after_loading(layer, ["weight"]) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if use_intel_amx_backend(layer): + return torch.ops.sgl_kernel.weight_packed_linear( + x, layer.weight, bias, True # is_vnni + ) + + return F.linear(x, layer.weight, bias) + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def __init__(self, use_triton_kernels: bool = False): + super().__init__() + self.use_triton_kernels = use_triton_kernels + + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + + if torch.cuda.is_available(): + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + if has_triton_kernels: + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward, + ) + else: + triton_kernel_moe_forward = None + else: + fused_experts = None # type: ignore + triton_kernel_moe_forward = None + + self.moe_forward_native = moe_forward_native + self.fused_experts = fused_experts + self.triton_kernel_moe_forward = triton_kernel_moe_forward + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size + if self.use_triton_kernels: + w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight_n, w2_weight_k = ( + hidden_size, + intermediate_size, + ) + if self.use_triton_kernels: + w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _use_aiter: + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + + # Pack weight for get better performance on CPU + if _is_cpu and _is_cpu_amx_available: + _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + ) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + + if self.use_triton_kernels: + return self.triton_kernel_moe_forward( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + else: + from sglang.srt.layers.moe.topk import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + if _use_aiter: + assert not no_combine, "unsupported" + if apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if activation == "silu" + else ActivationType.Gelu + ), + ) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace and not no_combine, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + ) + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + assert activation == "silu", f"activation = {activation} is not supported." + + if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + + from sglang.srt.layers.moe.topk import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel + return torch.ops.sgl_kernel.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + False, # inplace # See [Note] inplace should be False in fused_experts. + False, # use_int8_w8a8 + False, # use_fp8_w8a16 + None, # w1_scale + None, # w2_scale + None, # block_size + None, # a1_scale + None, # a2_scale + True, # is_vnni + ) + else: + return self.moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + inplace, + no_combine, + routed_scaling_factor, + ) + + def forward_npu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + return self.moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + inplace, + no_combine, + routed_scaling_factor, + ) + + def forward_tpu(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("The TPU backend currently does not support MoE.") + + forward_native = forward_cpu + + +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + layer.register_parameter("w13_input_scale", None) + layer.register_parameter("w13_weight_scale", None) + + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 2371208f7..51d70255d 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,7 +1,11 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py +from __future__ import annotations + +import re +from copy import deepcopy from types import MappingProxyType -from typing import List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union import numpy import torch @@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.scalar_type import ScalarType from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import QuantizationConfig + _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -147,6 +154,94 @@ def replace_parameter( mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError( + "Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}" + ) + + config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits." + ) + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, None] = None, +) -> Union[Dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.quantization.unquant import ( + UnquantizedEmbeddingMethod, + UnquantizedLinearMethod, + ) + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + + cloned_config = deepcopy(config) + parallel_lm_head_quantized = ( + isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized + ) + + if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if get_dynamic_override(cloned_config, layer_name=prefix) is False: + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None + + def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index c2820bdfc..1c9dc5d33 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional @@ -5,12 +7,13 @@ import torch from torch.nn import Module from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs @@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig): return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config": + def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = "fp8" in quant_method is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method @@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig): return [] -class W4AFp8MoEMethod: +class W4AFp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: W4AFp8Config): self.quant_config = quant_config diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index b2e606f4d..871a4534c 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -1,11 +1,14 @@ +from __future__ import annotations + from typing import Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -64,7 +67,7 @@ class W8A8Fp8Config(QuantizationConfig): return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": + def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = ( "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method @@ -75,7 +78,7 @@ class W8A8Fp8Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE @@ -183,7 +186,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): ) -class W8A8FP8MoEMethod: +class W8A8FP8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -194,25 +197,7 @@ class W8A8FP8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: W8A8Fp8Config): self.quant_config = quant_config def create_weights( diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 49e6f0e8c..c8a024bf3 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import sys from types import MappingProxyType @@ -11,21 +13,19 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.linear import ( - LinearMethodBase, - RowParallelLinear, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import ( apply_module_patch, cpu_has_amx_support, @@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig): return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": + def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config: return cls(config) def get_quant_method( self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE @@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase): ) -class W8A8Int8MoEMethod: +class W8A8Int8MoEMethod(FusedMoEMethodBase): """MoE method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic/static activation scale. @@ -385,25 +385,7 @@ class W8A8Int8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: W8A8Int8Config): self.quant_config = quant_config def create_weights( @@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + from sglang.srt.layers.linear import RowParallelLinear + if isinstance(layer, RowParallelLinear): tp_rank = get_tensor_model_parallel_rank() return self.quant_method.apply(layer, x, bias, tp_rank) return self.quant_method.apply(layer, x, bias) -class NPU_W8A8MoEMethod: +class NPU_W8A8MoEMethod(FusedMoEMethodBase): """MoE method for NPU quantization. This class search for specific quantization diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 0e075a251..d925506f5 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from sglang.srt.distributed import ( @@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding, ) +from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -32,44 +32,6 @@ _is_cpu = is_cpu() logger = logging.getLogger(__name__) -class UnquantizedEmbeddingMethod(QuantizeMethodBase): - """Unquantized method for embeddings.""" - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """Create weights for embedding layer.""" - weight = Parameter( - torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.linear(x, layer.weight, bias) - - def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: - return F.embedding(input_, layer.weight) - - def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to