ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (#4152)
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
+
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)
 
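The `* 2.0` on input_scale follows the usual e4m3fn-to-e4m3fnuz normalization on ROCm: AMD's float8_e4m3fnuz uses exponent bias 8 instead of 7, so the same bit pattern decodes to half the value, and a scale serialized for e4m3fn must be doubled to keep the dequantized result unchanged. A minimal sketch of that normalization (an illustrative helper, not this repo's code):

import torch

def normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale, input_scale=None):
    # Bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; zero it out first.
    bits = weight.view(torch.int8)
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # The same bits now decode to half the value, so double the scales.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale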
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
         # specific to each case
         quant_method = getattr(param, "quant_method", None)
         if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 0.5
+
             self._load_per_channel_weight_scale(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
         elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
+
             self._load_per_tensor_weight_scale(
                 shard_id=shard_id,
                 param=param,
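Read together with the channel-scale adjustment in the previous hunk, the two factors cancel: if the effective dequantization is roughly int4_code * per-column scale1 * per-tensor FP8 scale, then halving scale1 while doubling the FP8 scale for e4m3fnuz leaves the product unchanged. A quick numeric check of that reading (the composition is an assumption, not taken from the kernel):

int4_code, col_scale, fp8_scale = 7, 0.02, 0.5
before = int4_code * col_scale * fp8_scale
after = int4_code * (col_scale * 0.5) * (fp8_scale * 2.0)
assert abs(before - after) < 1e-12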
@@ -460,7 +460,11 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
+            params_dtype = (
+                torch.int32
+                if get_bool_env_var("USE_INT4_WEIGHT")
+                else torch.float8_e4m3fn
+            )
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
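params_dtype becomes torch.int32 because the INT4 weights are stored packed, eight 4-bit values per 32-bit word; that is also why the packed dimension shrinks by 8x (hidden_size // 8, intermediate_size // 8) in the parameter shapes in the next hunk. A sketch of the packing idea (the real kernel layout and nibble order may differ):

import torch

def pack_int4_rows(nibbles: torch.Tensor) -> torch.Tensor:
    # nibbles: (..., K) integer tensor with values in [0, 15], K divisible by 8
    grouped = nibbles.reshape(*nibbles.shape[:-1], -1, 8).to(torch.int32)
    packed = torch.zeros(grouped.shape[:-1], dtype=torch.int32)
    for i in range(8):
        packed |= grouped[..., i] << (4 * i)
    return packed

packed = pack_int4_rows(torch.randint(0, 16, (2, 16)))
print(packed.shape)  # torch.Size([2, 2]): 16 nibbles -> 2 int32 words per row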
@@ -485,21 +489,40 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            # INT4 MoE weight - INT32 packed
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size,
+                    hidden_size // 8,
+                    dtype=params_dtype,
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+        else:
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
@@ -538,7 +561,9 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if (
+            is_hip_
+        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -565,6 +590,13 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+            )
+            set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale1, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
             if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -590,6 +622,53 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+            # Weight Permutation
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
+            # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+            # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+            # We won't do requant each expert's fp8 weight (not direct available),
+            # instead we adjust half of INT4 w13_weight_scale1 numbers
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                max_w13_scale_fp8 = max_w13_scales[expert_id]
+                for shard_id in range(2):
+                    if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                        int4_rescale = (
+                            layer.w13_weight_scale[expert_id][shard_id]
+                            / max_w13_scale_fp8
+                        )
+                        layer.w13_weight_scale1[expert_id][
+                            start : start + shard_size
+                        ] *= int4_rescale
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
+
+            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+            for expert_id in range(layer.num_experts):
+                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+            return
+
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
             padding_size,  # Avoid circular import
         )
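The scale-folding loop above can be checked with a small example: with two shards whose FP8 scales differ, multiplying the smaller shard's per-column INT4 scales by shard_scale / max_scale lets a single per-expert FP8 scale stand in for both, assuming the effective per-column factor is scale1[col] * shard FP8 scale:

import torch

shard_scales = torch.tensor([0.25, 1.0])   # e.g. gate shard vs. up shard
col_scales = torch.ones(8)                 # per-column INT4 scales, 4 per shard
shard_size = 4
max_scale = shard_scales.max()

expected = torch.cat(
    [col_scales[:4] * shard_scales[0], col_scales[4:] * shard_scales[1]]
)
for shard_id in range(2):
    if shard_scales[shard_id] != max_scale:
        sl = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        col_scales[sl] *= shard_scales[shard_id] / max_scale
assert torch.allclose(col_scales * max_scale, expected)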
@@ -823,8 +902,24 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+        if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            assert not no_combine, f"{no_combine=} is not supported."
+            return asm_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=activation,
+            )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert (
+                activation == "silu"
+            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
                 return asm_moe(
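The new dispatch is taken only on ROCm and only when the USE_INT4_WEIGHT environment variable is set before the server starts; CK_MOE remains a separate switch for the FP8 CK path (the TODO notes it will be re-checked once a Triton kernel lands). Assuming get_bool_env_var treats "1"/"true" as enabled, turning the path on looks like:

import os

# Set before model load so create_weights() allocates packed INT4 parameters
# and apply() routes through the INT4-FP8 asm_moe call above.
os.environ["USE_INT4_WEIGHT"] = "1"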
@@ -835,10 +930,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale_inv,
                     layer.w2_weight_scale_inv,
-                    None,
-                    None,
-                    False,
-                    None,
                     block_shape=tuple(self.quant_config.weight_block_size),
                     expert_mask=None,
                 )
@@ -851,9 +942,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
-                    None,
-                    None,
-                    False,
                 )
         else:
             # Expert fusion with FP8 quantization
@@ -1269,7 +1269,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)
 
     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
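For the packed int32 weights, permute_weight now regroups the last dimension as (k/8, 2, 4) instead of the fp8/int8 grouping (k/64, 4, 16). A small shape sanity check under assumed sizes (names mirror the function's locals; the final reshape back to the original shape is omitted):

import torch

b_, n_, k_ = 8, 32, 512 // 8   # 8 experts, 32 rows, 512 INT4 values -> 64 int32 words
x_ = torch.zeros(b_, n_, k_, dtype=torch.int32)
x_ = x_.view(b_, n_ // 16, 16, k_ // 8, 2, 4)
x_ = x_.permute(0, 1, 3, 4, 2, 5).contiguous()
assert x_.numel() == 8 * 32 * 64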