ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (#4152)

This commit is contained in:
HAI
2025-03-06 15:33:02 -08:00
committed by GitHub
parent 9854a18a51
commit 13bc39c5d6
3 changed files with 124 additions and 23 deletions

View File

@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
# Moving loaded_weight to the parameter's device is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the INT4 per-channel (column-wise) weight scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank,
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the FP8 per-tensor weight scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,