# SPDX-License-Identifier: Apache-2.0
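"""Quark MXFP4 fused-MoE quantization methods for SGLang on ROCm.

Covers weight/scale creation, post-load E8M0 scale shuffling, and dispatch to
the aiter ``fused_moe`` kernel for W4A4 OCP microscaling (MXFP4) MoE layers.
"""
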
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
from aiter import ActivationType, QuantType, biased_grouped_topk
from aiter.fused_moe import fused_moe
from aiter.utility.fp4_utils import e8m0_shuffle

from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        StandardDispatchOutput,
    )
    from sglang.srt.layers.quantization.quark.quark import QuarkConfig

logger = logging.getLogger(__name__)

_is_hip = is_hip()

__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]

OCP_MX_BLOCK_SIZE = 32

if TYPE_CHECKING:
    from sglang.srt.layers.quantization import QuarkConfig


class QuarkMoEMethod(FusedMoEMethodBase):
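    """Base class for Quark fused-MoE quantization methods.

    ``get_moe_method`` resolves the layer-specific Quark config and returns the
    concrete scheme (currently only MXFP4 via ``QuarkW4A4MXFp4MoEMethod``).
    """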

    def __init__(self, quant_config: QuarkConfig):
        self.quant_config = quant_config

    @staticmethod
    def get_moe_method(
        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str,
    ) -> "QuarkMoEMethod":
        layer_quant_config = quant_config._find_matched_config(layer_name, module)

        if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with "
                "output_tensors and bias "
                "quantized are not supported"
            )
        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")


class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
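    """MXFP4 (OCP microscaling) W4A4 fused-MoE method.

    Weights are FP4 (E2M1) values packed two per uint8 byte, with one shared
    E8M0 scale per block of ``OCP_MX_BLOCK_SIZE`` (32) elements; activations
    use per-group scales (dynamic unless the Quark config marks them static).
    """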

    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
            raise ValueError(
                "For MX(FP4) Fused MoE layers, only per-group scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}"
            )  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")
        self.with_bias = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        params_dtype = torch.uint8
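
        # FP4 (E2M1) values are packed two per uint8 byte, so the last weight
        # dimension below is half the logical element count (e.g. with
        # hidden_size = 4096, each w13_weight row stores 4096 // 2 = 2048 bytes).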

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)

        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)

        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
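        # One E8M0 scale per OCP_MX_BLOCK_SIZE (32) consecutive unpacked weight
        # elements, so the last scale dimension is the logical element count
        # divided by 32 (e.g. hidden_size = 4096 -> 128 scales per row).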
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        float_dtype = torch.get_default_dtype()

        # Pre-shuffle weight scales
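        # e8m0_shuffle rearranges the E8M0 scales into the layout expected by
        # the aiter fused_moe kernel; the (experts, rows) leading dimensions are
        # flattened for the shuffle and restored afterwards.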
        s0, s1, _ = layer.w13_weight_scale.shape
        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
        # layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False)
        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)

        s0, s1, _ = layer.w2_weight_scale.shape
        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
        # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:

        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        moe_runner_config = self.moe_runner_config
        topk_weights, topk_ids, _ = topk_output
        if _is_hip:
            topk_weights = topk_weights.to(
                torch.float32
            )  # aiter's moe_sorting requires topk_weights to be FP32
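
        # Newer PyTorch builds expose a packed-FP4 dtype; reinterpreting the
        # uint8 storage as float4_e2m1fn_x2 lets downstream code see the actual
        # element type, while older builds fall back to the raw uint8 view.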
        if hasattr(torch, "float4_e2m1fn_x2"):
            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
        else:
            w13_weight = layer.w13_weight
            w2_weight = layer.w2_weight
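
        # QuantType.per_1x32 selects MX-style quantization (one scale per 1x32
        # block, matching OCP_MX_BLOCK_SIZE); the runner config's activation
        # string is mapped onto aiter's ActivationType enum.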
        output = fused_moe(
            x,
            w13_weight,
            w2_weight,
            topk_weights,
            topk_ids,
            quant_type=QuantType.per_1x32,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            activation=(
                ActivationType.Silu
                if moe_runner_config.activation == "silu"
                else ActivationType.Gelu
            ),
            doweight_stage1=False,
        )
        return StandardCombineInput(hidden_states=output)
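

# Illustrative call sequence only (not executed here): SGLang's FusedMoE layer
# drives this method roughly as follows, assuming `layer`, the expert sizes,
# `quark_config`, `moe_runner_config`, and `dispatch_output` exist elsewhere:
#
#   method = QuarkMoEMethod.get_moe_method(quark_config, module, layer_name)
#   method.create_weights(layer, num_experts, hidden_size,
#                         intermediate_size_per_partition, torch.bfloat16)
#   # ... checkpoint weights are loaded into layer.w13_weight / layer.w2_weight ...
#   method.process_weights_after_loading(layer)
#   method.create_moe_runner(layer, moe_runner_config)
#   combine_input = method.apply(layer, dispatch_output)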