diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index a365f8481..2c3f722ce 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
+
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)
 
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
             # specific to each case
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
+                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                    loaded_weight = loaded_weight * 0.5
+
                 self._load_per_channel_weight_scale(
                     shard_id=shard_id,
                     shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
                     tp_rank=tp_rank,
                 )
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
+                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                    loaded_weight = loaded_weight * 2.0
+
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index c61adbdae..e296756b5 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -460,7 +460,11 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
+            params_dtype = (
+                torch.int32
+                if get_bool_env_var("USE_INT4_WEIGHT")
+                else torch.float8_e4m3fn
+            )
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -485,21 +489,40 @@ class Fp8MoEMethod:
                     )
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            # INT4 MoE weight - INT32 packed
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size,
+                    hidden_size // 8,
+                    dtype=params_dtype,
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+        else:
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
@@ -538,7 +561,9 @@ class Fp8MoEMethod:
 
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - if is_hip_ and get_bool_env_var("CK_MOE"): + if ( + is_hip_ + ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), @@ -565,6 +590,13 @@ class Fp8MoEMethod: set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) + if get_bool_env_var("USE_INT4_WEIGHT"): + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale1, extra_weight_attrs) + set_weight_attrs(w2_weight_scale1, extra_weight_attrs) + # INPUT_SCALES if self.quant_config.activation_scheme == "static": if not self.quant_config.is_checkpoint_fp8_serialized: @@ -590,6 +622,53 @@ class Fp8MoEMethod: layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: + if get_bool_env_var("USE_INT4_WEIGHT"): + # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added + # INT4-FP8 (INT4 MoE Weight, FP8 Compute) + # Weight Permutation + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + + # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale + # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert. + # We won't do requant each expert's fp8 weight (not direct available), + # instead we adjust half of INT4 w13_weight_scale1 numbers + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + max_w13_scale_fp8 = max_w13_scales[expert_id] + for shard_id in range(2): + if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8: + int4_rescale = ( + layer.w13_weight_scale[expert_id][shard_id] + / max_w13_scale_fp8 + ) + layer.w13_weight_scale1[expert_id][ + start : start + shard_size + ] *= int4_rescale + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + + # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling + # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post + for expert_id in range(layer.num_experts): + layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id] + layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id] + return + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( padding_size, # Avoid circular import ) @@ -823,8 +902,24 @@ class Fp8MoEMethod: correction_bias=correction_bias, ) - if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu": + if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"): + # TODO: add triton kernel and add check get_bool_env_var("CK_MOE") + assert not no_combine, f"{no_combine=} is not supported." 
+            return asm_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=activation,
+            )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert (
+                activation == "silu"
+            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
                 return asm_moe(
@@ -835,10 +930,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale_inv,
                     layer.w2_weight_scale_inv,
-                    None,
-                    None,
-                    False,
-                    None,
                     block_shape=tuple(self.quant_config.weight_block_size),
                     expert_mask=None,
                 )
@@ -851,9 +942,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
-                    None,
-                    None,
-                    False,
                 )
         else:
             # Expert fusion with FP8 quantization
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 76c05749d..1ce2862f9 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1269,7 +1269,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)
 
     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
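
Note on the scale folding in process_weights_after_loading above: the two per-shard FP8 scales of w13 (gate and up) are collapsed into one per-expert scale, and the per-column INT4 scales of the smaller shard are rescaled to compensate, since asm_moe applies (weight_scale1 * weight_scale) after the GEMM. The following is a minimal standalone sketch of that computation, not part of the patch; shapes are made up and tensor names only mirror the layer attributes used in the diff.

    import torch

    num_experts, shard_size = 4, 8                               # hypothetical sizes
    w13_weight_scale = torch.rand(num_experts, 2)                # per-shard FP8 scales (gate, up)
    w13_weight_scale1 = torch.rand(num_experts, 2 * shard_size)  # per-column INT4 scales

    # Keep the larger FP8 scale per expert; rescale the other shard's INT4 columns to match.
    max_w13_scales = w13_weight_scale.max(dim=1).values
    for expert_id in range(num_experts):
        start = 0
        for shard_id in range(2):
            if w13_weight_scale[expert_id][shard_id] != max_w13_scales[expert_id]:
                int4_rescale = w13_weight_scale[expert_id][shard_id] / max_w13_scales[expert_id]
                w13_weight_scale1[expert_id][start : start + shard_size] *= int4_rescale
            start += shard_size

    # Pre-multiply the folded per-expert FP8 scale into the per-column scales
    # (the patch does this in a per-expert loop; vectorized here).
    w13_weight_scale1 *= max_w13_scales.unsqueeze(1)

After this, a single FP8 scale per expert suffices for the FP8 GEMM while dequantization of the INT4 weights stays per-column, which is what the USE_INT4_WEIGHT branch above relies on.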