diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 328d82215..3749abc34 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 _is_hip = is_hip()
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ class Fp8MoEMethod:
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ class Fp8MoEMethod:
 
             if get_bool_env_var("CK_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
+                    # permute_weight(layer.w13_weight.data),
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
+                    # permute_weight(layer.w2_weight.data),
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
@@ -928,7 +933,7 @@ class Fp8MoEMethod:
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -936,15 +941,17 @@
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 block_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,7 +964,7 @@
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,6 +972,11 @@
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
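
A minimal sketch of the activation handling introduced above: the new ck_moe_2stages / ck_moe_2stages_win4 call sites translate sglang's string activation name into aiter's ActivationType enum. The snippet below mirrors that inline conditional using a stand-in Enum, so it runs without ROCm or aiter installed; the ActivationType stand-in and the to_aiter_activation helper are illustrative only and are not part of the PR.

# Minimal sketch, not part of the PR: reproduces the string -> enum mapping
# used at the new ck_moe_2stages / ck_moe_2stages_win4 call sites.
from enum import Enum


class ActivationType(Enum):
    # Stand-in for aiter.ActivationType; only the two members the diff references.
    Silu = 0
    Gelu = 1


def to_aiter_activation(activation: str) -> ActivationType:
    # The diff treats "silu" specially and maps every other name to Gelu.
    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu


assert to_aiter_activation("silu") is ActivationType.Silu
assert to_aiter_activation("gelu") is ActivationType.Gelu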