ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (#4152)

This commit is contained in:
HAI
2025-03-06 15:33:02 -08:00
committed by GitHub
parent 9854a18a51
commit 13bc39c5d6
3 changed files with 124 additions and 23 deletions

View File

@@ -460,7 +460,11 @@ class Fp8MoEMethod:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
params_dtype = (
torch.int32
if get_bool_env_var("USE_INT4_WEIGHT")
else torch.float8_e4m3fn
)
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
@@ -485,21 +489,40 @@ class Fp8MoEMethod:
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
if get_bool_env_var("USE_INT4_WEIGHT"):
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size,
hidden_size // 8,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
),
requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -538,7 +561,9 @@ class Fp8MoEMethod:
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if is_hip_ and get_bool_env_var("CK_MOE"):
if (
is_hip_
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -565,6 +590,13 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if get_bool_env_var("USE_INT4_WEIGHT"):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
set_weight_attrs(w2_weight_scale1, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -590,6 +622,53 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
# We won't do requant each expert's fp8 weight (not direct available),
# instead we adjust half of INT4 w13_weight_scale1 numbers
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
int4_rescale = (
layer.w13_weight_scale[expert_id][shard_id]
/ max_w13_scale_fp8
)
layer.w13_weight_scale1[expert_id][
start : start + shard_size
] *= int4_rescale
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
for expert_id in range(layer.num_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
return
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
padding_size, # Avoid circular import
)
@@ -823,8 +902,24 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
activation=activation,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
return asm_moe(
@@ -835,10 +930,6 @@ class Fp8MoEMethod:
topk_ids,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
None,
None,
False,
None,
block_shape=tuple(self.quant_config.weight_block_size),
expert_mask=None,
)
@@ -851,9 +942,6 @@ class Fp8MoEMethod:
topk_ids,
layer.w13_weight_scale1,
layer.w2_weight_scale1,
None,
None,
False,
)
else:
# Expert fusion with FP8 quantization