ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (#4152)
This commit is contained in:
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 2.0
|
||||
|
||||
# this is needed for compressed-tensors only
|
||||
loaded_weight = loaded_weight.to(param.data.device)
|
||||
|
||||
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
|
||||
# specific to each case
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the INT4 per-channel (column-wise) weight scale for e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 0.5
|
||||
|
||||
self._load_per_channel_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust the FP8 per-tensor weight scale for e4m3fnuz (AMD)
|
||||
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
loaded_weight = loaded_weight * 2.0
|
||||
|
||||
self._load_per_tensor_weight_scale(
|
||||
shard_id=shard_id,
|
||||
param=param,
|
||||
|
||||
Reference in New Issue
Block a user