Combine fp4.py and mxfp4.py into one file and support dynamic mxfp4 quantization in mxfp4.py (#9049)
Co-authored-by: wunhuang <wunhuang@amd.com>
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
     is_hip,
     is_triton_kernels_available,
     log_info_on_rank0,
+    mxfp_supported,
     next_power_of_2,
     round_up,
     set_weight_attrs,
@@ -61,7 +62,14 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 OCP_MX_BLOCK_SIZE = 32
+_is_hip = is_hip()
+
+if _is_hip:
+    # import aiter
+    from aiter import ActivationType, QuantType, dtypes
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.utility.fp4_utils import e8m0_shuffle
-
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -162,13 +170,34 @@ except AttributeError as error:
 
 class Mxfp4Config(QuantizationConfig):
 
-    def __init__(self, ignored_layers: Optional[list[str]] = None):
+    def __init__(
+        self,
+        ignored_layers: Optional[list[str]] = None,
+        is_checkpoint_mxfp4_serialized: bool = False,
+    ):
         super().__init__()
+        self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
         self.ignored_layers = ignored_layers
 
     @classmethod
     def from_config(cls, config):
-        return cls()
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
+
+        if _is_hip:
+            if mxfp_supported():
+                return cls(
+                    is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
+                )
+            else:
+
+                platform = torch.cuda.get_device_properties(0).gcnArchName
+                raise ValueError(
+                    f"Current platform {platform} not support mxfp4 computation"
+                )
+
+        return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)
 
     @classmethod
     def get_min_capability(cls) -> int:
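The new from_config logic above is the fork between the two modes: the flag simply records whether the checkpoint's quant_method string mentions mxfp4. A minimal sketch of both outcomes, assuming illustrative config dicts rather than a real checkpoint's quantization config:

    # Illustrative dicts; only the "quant_method" key is consulted here.
    static_cfg = Mxfp4Config.from_config({"quant_method": "mxfp4"})
    assert static_cfg.is_checkpoint_mxfp4_serialized       # weights already in mxfp4

    dynamic_cfg = Mxfp4Config.from_config({"quant_method": "bf16"})
    assert not dynamic_cfg.is_checkpoint_mxfp4_serialized  # quantize after loading

On ROCm both outcomes additionally require mxfp_supported(); otherwise from_config raises with the detected gcnArchName.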
@@ -186,6 +215,9 @@ class Mxfp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def is_static_cfg(self):
+        return self.is_checkpoint_mxfp4_serialized
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
@@ -201,10 +233,16 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
+            elif _is_hip:
+                return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(prefix)
+            if self.is_checkpoint_mxfp4_serialized:
+                return Mxfp4MoEMethod(prefix=prefix)
+            else:
+                return Mxfp4DynamicQuantMoEMethod()
         else:
-            raise NotImplementedError("Mxfp4 attention layer is not implemented")
+            if self.is_checkpoint_mxfp4_serialized:
+                raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
 
     def get_scaled_act_names(self) -> List[str]:
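Given that flag, the dispatch above routes FusedMoE layers to one of two methods. A rough sketch of the dynamic route, with a hypothetical fused_moe_layer and prefix:

    # Hypothetical layer and prefix; dynamic route on a ROCm build.
    cfg = Mxfp4Config(is_checkpoint_mxfp4_serialized=False)
    method = cfg.get_quant_method(fused_moe_layer, prefix="model.layers.0.mlp.experts")
    # -> Mxfp4DynamicQuantMoEMethod: create_weights() allocates plain bf16 buffers,
    #    then process_weights_after_loading() quantizes them to mxfp4 in place.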
@@ -655,3 +693,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             b1=layer.w13_weight_bias,
             b2=layer.w2_weight_bias,
         )
+
+
+class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def mxfp4_quantize(self, w):
+        w_shape = w.shape
+        w_need_reshape = True if w.dim() != 2 else False
+
+        if w_need_reshape:
+            w_last_dim_size = w_shape[-1]
+            w = w.view(-1, w_last_dim_size)
+
+        w, mx_scales = dynamic_mxfp4_quant(w)
+
+        if w_need_reshape:
+            w_new_shape = w_shape[:-1] + (w.shape[-1],)
+            w = w.view(w_new_shape)
+
+        mx_scales = e8m0_shuffle(mx_scales)
+
+        return w, mx_scales
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
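A note on the reshape dance in mxfp4_quantize: aiter's dynamic_mxfp4_quant expects a 2-D input, while the MoE weights carry a leading expert dimension, so the method flattens, quantizes, and restores the shape. A minimal sketch of that round trip, assuming a ROCm build with aiter installed and illustrative shapes (the last dimension of the quantized tensor shrinks because fp4 values are packed):

    import torch
    from aiter.ops.triton.quant import dynamic_mxfp4_quant
    from aiter.utility.fp4_utils import e8m0_shuffle

    E, N, K = 8, 256, 128                  # experts, rows, cols (illustrative)
    w = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

    w2d = w.view(-1, K)                    # flatten experts: (E*N, K)
    wq, scales = dynamic_mxfp4_quant(w2d)  # one e8m0 scale per 1x32 block
    wq = wq.view(E, N, wq.shape[-1])       # restore expert dim; last dim is packed
    scales = e8m0_shuffle(scales)          # scale layout expected by fused_moe

The 1x32 block here is the same OCP_MX_BLOCK_SIZE = 32 granularity that apply() later hands to fused_moe via QuantType.per_1x32.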