Files
sglang/python/sglang/srt/layers/quantization/quark/quark_moe.py
2025-10-05 19:51:34 -07:00

216 lines
7.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from aiter import ActivationType, QuantType, biased_grouped_topk
from aiter.fused_moe import fused_moe
from aiter.utility.fp4_utils import e8m0_shuffle
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.quantization.quark.quark import QuarkConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Evaluated once at import time; apply() consults it to decide whether the
# top-k weights are cast to float32 before the aiter fused MoE call.
_is_hip = is_hip()
# Public names exported by this module.
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
# Divisor used by create_weights() to size the last dimension of the
# weight-scale tensors (one scale entry per OCP_MX_BLOCK_SIZE elements along
# that dimension).
OCP_MX_BLOCK_SIZE = 32
if TYPE_CHECKING:
from sglang.srt.layers.quantization import QuarkConfig
class QuarkMoEMethod(FusedMoEMethodBase):
    """Base class and factory for Quark-quantized fused MoE methods."""

    def __init__(self, quant_config: QuarkConfig):
        """Store the Quark quantization config on the instance."""
        self.quant_config = quant_config

    @staticmethod
    def get_moe_method(
        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str,
    ) -> "QuarkMoEMethod":
        """Select the fused MoE method matching a layer's Quark config.

        Args:
            quant_config: Quark config; queried through its
                ``_find_matched_config`` and ``_is_mx_fp4`` helpers.
            module: The module the layer config is looked up for.
            layer_name: Name of the layer the config is looked up for.

        Returns:
            A ``QuarkW4A4MXFp4MoEMethod`` built from the layer's weight and
            input-tensor configs.

        Raises:
            NotImplementedError: If the matched config has a truthy
                ``output_tensors`` or ``bias`` entry.
            RuntimeError: If the weight/input configs are not MX-FP4.
        """
        matched = quant_config._find_matched_config(layer_name, module)

        # Output-tensor and bias quantization have no implementation here.
        for unsupported_key in ("output_tensors", "bias"):
            if matched.get(unsupported_key):
                raise NotImplementedError(
                    "Currently, Quark models with "
                    "output_tensors and bias "
                    "quantized are not supported"
                )

        weight_config = matched.get("weight")
        input_config = matched.get("input_tensors")

        # MX-FP4 is the only scheme this factory knows how to build.
        if not quant_config._is_mx_fp4(weight_config, input_config):
            raise RuntimeError("Unsupported FusedMoe scheme")

        return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
    """Fused MoE method for Quark MX-FP4 weights and activations.

    Weights and their scales are stored as ``torch.uint8`` parameters and the
    expert computation is delegated to aiter's ``fused_moe`` with
    ``QuantType.per_1x32``.
    """

    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
        """Validate and store the weight / input quantization configs.

        Args:
            weight_config: Quark weight quantization config; its ``qscheme``
                entry must be ``"per_group"``.
            input_config: Quark input-tensor quantization config; its
                ``qscheme`` entry must be ``"per_group"``.

        Raises:
            ValueError: If either ``qscheme`` is not ``"per_group"``.
        """
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
            raise ValueError(
                "For MX(FP4) Fused MoE layers, only per-group scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}"
            )  # noqa E501

        # Input scales count as static unless the config marks them dynamic.
        self.static_input_scales = not self.input_quant.get("is_dynamic")
        self.with_bias = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weight and weight-scale parameters on ``layer``.

        Registers, in this order, ``w13_weight``, ``w2_weight``,
        ``w13_weight_scale`` and ``w2_weight_scale``. All are ``torch.uint8``
        and have ``requires_grad=False``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Intermediate size for this
                partition.
            params_dtype: Ignored; every parameter is created as
                ``torch.uint8``.
            **extra_weight_attrs: Attributes attached to each parameter via
                ``set_weight_attrs``; ``quant_method`` is set to the BLOCK
                scheme before they are attached.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        # The caller-supplied dtype is deliberately overridden: weights and
        # scales are stored as raw bytes.
        params_dtype = torch.uint8

        # WEIGHTS
        # The last dimension is halved relative to the logical size
        # (presumably two 4-bit values per uint8 byte -- confirm against the
        # checkpoint format).
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale entry per OCP_MX_BLOCK_SIZE elements along the last
        # dimension.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

    @staticmethod
    def _shuffle_weight_scale(scale: torch.nn.Parameter) -> None:
        """Apply ``e8m0_shuffle`` to a 3-D scale parameter in place.

        The parameter is flattened over its first two dimensions, shuffled,
        and viewed back with the original leading two dimensions. The result
        replaces ``scale.data`` so the Parameter object itself is preserved.
        """
        num_experts, rows, _ = scale.shape
        shuffled = e8m0_shuffle(scale.view(num_experts * rows, -1))
        scale.data = shuffled.view(num_experts, rows, -1)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Pre-shuffle both weight-scale parameters of ``layer`` in place."""
        self._shuffle_weight_scale(layer.w13_weight_scale)
        self._shuffle_weight_scale(layer.w2_weight_scale)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        """Remember the runner config; ``apply`` reads its ``activation``."""
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the fused MoE kernel on the dispatched tokens.

        Args:
            layer: Module holding ``w13_weight``, ``w2_weight`` and their
                ``*_weight_scale`` parameters.
            dispatch_output: Provides ``hidden_states`` and ``topk_output``;
                the latter unpacks into ``(topk_weights, topk_ids, _)``.

        Returns:
            A ``StandardCombineInput`` wrapping the kernel output as
            ``hidden_states``.
        """
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        moe_runner_config = self.moe_runner_config

        topk_weights, topk_ids, _ = topk_output
        if _is_hip:
            # aiter's moe_sorting requires topk_weights to be FP32
            topk_weights = topk_weights.to(torch.float32)

        # Reinterpret the uint8 storage as torch's packed FP4 dtype when this
        # torch build exposes it; otherwise hand the raw uint8 tensors over.
        if hasattr(torch, "float4_e2m1fn_x2"):
            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
        else:
            w13_weight = layer.w13_weight
            w2_weight = layer.w2_weight

        # NOTE(review): any activation other than "silu" is dispatched as
        # Gelu -- confirm no other activation can reach this path.
        output = fused_moe(
            x,
            w13_weight,
            w2_weight,
            topk_weights,
            topk_ids,
            quant_type=QuantType.per_1x32,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            activation=(
                ActivationType.Silu
                if moe_runner_config.activation == "silu"
                else ActivationType.Gelu
            ),
            doweight_stage1=False,
        )
        return StandardCombineInput(hidden_states=output)