Combine fp4.py and mxfp4.py into one file and support dynamic mxfp4 quantization in mxfp4.py (#9049)

Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
kk
2025-08-17 10:01:54 +08:00
committed by GitHub
parent 384f8ab5ce
commit 1c1f8a118e
7 changed files with 760 additions and 557 deletions

View File

@@ -38,6 +38,7 @@ from sglang.srt.utils import (
is_hip,
is_triton_kernels_available,
log_info_on_rank0,
mxfp_supported,
next_power_of_2,
round_up,
set_weight_attrs,
@@ -61,7 +62,14 @@ if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
_is_hip = is_hip()
if _is_hip:
# import aiter
from aiter import ActivationType, QuantType, dtypes
from aiter.fused_moe import fused_moe
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility.fp4_utils import e8m0_shuffle
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -162,13 +170,34 @@ except AttributeError as error:
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
def __init__(
self,
ignored_layers: Optional[list[str]] = None,
is_checkpoint_mxfp4_serialized: bool = False,
):
super().__init__()
self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
if _is_hip:
if mxfp_supported():
return cls(
is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
)
else:
platform = torch.cuda.get_device_properties(0).gcnArchName
raise ValueError(
f"Current platform {platform} not support mxfp4 computation"
)
return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)
@classmethod
def get_min_capability(cls) -> int:
@@ -186,6 +215,9 @@ class Mxfp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
def is_static_cfg(self):
    """Return True when the checkpoint was serialized with static mxfp4
    quantization (weights already stored in mxfp4), False when weights
    must be quantized dynamically at load time."""
    return self.is_checkpoint_mxfp4_serialized
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
@@ -201,10 +233,16 @@ class Mxfp4Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
elif _is_hip:
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(prefix)
if self.is_checkpoint_mxfp4_serialized:
return Mxfp4MoEMethod(prefix=prefix)
else:
return Mxfp4DynamicQuantMoEMethod()
else:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
if self.is_checkpoint_mxfp4_serialized:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
def get_scaled_act_names(self) -> List[str]:
@@ -655,3 +693,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
)
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
    """MoE method that quantizes checkpoint weights to MXFP4 at load time.

    Weights are created in the checkpoint dtype (``params_dtype``), then
    converted to MXFP4 with per-1x32-block e8m0 scales in
    ``process_weights_after_loading``, and dispatched through aiter's
    ``fused_moe`` kernel in ``apply``.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate unquantized expert weights plus placeholder scales.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Number of experts (leading dim of every tensor).
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Per-TP-rank intermediate size.
            params_dtype: Dtype the checkpoint weights are loaded in.
            **extra_weight_attrs: Loader attributes applied to the weights.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # w13 packs the gate and up projections: [E, 2 * I, H].
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        # w2 is the down projection: [E, H, I].
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        # These are placeholders only; process_weights_after_loading
        # replaces them with the real per-block MXFP4 scales.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly.
        # NOTE(review): this update runs after set_weight_attrs above, so the
        # weight params themselves never see "quant_method" — confirm intended.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )

        # Dynamic quantization: no static input (activation) scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def mxfp4_quantize(self, w):
        """Quantize tensor ``w`` to MXFP4, returning (packed weight, scales).

        ``dynamic_mxfp4_quant`` expects a 2-D input, so higher-rank tensors
        (e.g. [E, rows, cols]) are flattened to 2-D and restored afterwards.
        The scales are shuffled into the layout aiter's kernel expects.
        """
        w_shape = w.shape
        w_need_reshape = w.dim() != 2  # idiomatic: no `True if ... else False`

        if w_need_reshape:
            w_last_dim_size = w_shape[-1]
            w = w.view(-1, w_last_dim_size)

        w, mx_scales = dynamic_mxfp4_quant(w)

        if w_need_reshape:
            # Quantization packs the last dim (two fp4 values per byte),
            # so rebuild the original shape with the new last-dim size.
            w_new_shape = w_shape[:-1] + (w.shape[-1],)
            w = w.view(w_new_shape)

        mx_scales = e8m0_shuffle(mx_scales)
        return w, mx_scales

    def process_weights_after_loading(self, layer: Module) -> None:
        """Replace the loaded bf16/fp16 weights with MXFP4 weights + scales."""
        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)

        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)

        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
        """Run the fused MoE kernel on ``x`` with MXFP4 weights.

        Args:
            layer: Layer holding the quantized weights and scales.
            x: Input activations.
            topk_output: (topk_weights, topk_ids, ...) from the router.
            moe_runner_config: Runner config; only ``activation`` is read.

        Returns:
            The MoE output tensor.
        """
        topk_weights, topk_ids, _ = topk_output
        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            quant_type=QuantType.per_1x32,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            activation=(
                ActivationType.Silu
                if moe_runner_config.activation == "silu"
                else ActivationType.Gelu
            ),
            doweight_stage1=False,
        )