ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations (#5228)
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 _is_hip = is_hip()
 
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
 _is_cuda = is_cuda()
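The ActivationType import added here is what lets both CK 2-stage kernels accept either activation. The string-to-enum mapping used repeatedly later in this commit amounts to the small helper sketched below (a paraphrase of the inline expressions in the diff, assuming an aiter install; the helper name to_aiter_activation is ours):

from aiter import ActivationType  # ROCm-only dependency

def to_aiter_activation(activation: str) -> ActivationType:
    # Same mapping as the inline expressions in the apply() hunks further down.
    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu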
@@ -487,7 +488,7 @@ class Fp8MoEMethod:
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
            )
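The INT4 path stores weights as packed 4-bit nibbles, eight per 32-bit word, and the dtype change makes that container explicitly unsigned (torch.uint32). A minimal illustration of the packing arithmetic, not the actual checkpoint loader:

# Pack eight 4-bit values into one 32-bit word, then unpack them again.
nibbles = [1, 2, 3, 4, 5, 6, 7, 15]  # example weights, each in [0, 15]
packed = 0
for i, v in enumerate(nibbles):
    packed |= (v & 0xF) << (4 * i)
unpacked = [(packed >> (4 * i)) & 0xF for i in range(8)]
assert unpacked == nibbles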
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ class Fp8MoEMethod:
 
             if get_bool_env_var("CK_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
+                    # permute_weight(layer.w13_weight.data),
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
+                    # permute_weight(layer.w2_weight.data),
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
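Both the INT4-FP8 branch above and this CK_MOE branch now pre-shuffle w13 and w2 the same way. A condensed sketch of that shared step, assuming aiter provides shuffle_weight as imported at the top of the file (the helper name preshuffle_moe_weights is ours):

import torch
from aiter.ops.shuffle import shuffle_weight  # ROCm-only dependency

def preshuffle_moe_weights(layer) -> None:
    # Swap the old permute_weight layout for the (16, 16)-tile shuffle,
    # presumably the layout the CK 2-stage kernels expect.
    for name in ("w13_weight", "w2_weight"):
        weight = getattr(layer, name)
        setattr(
            layer,
            name,
            torch.nn.Parameter(
                shuffle_weight(weight.data, (16, 16)), requires_grad=False
            ),
        )
        torch.cuda.empty_cache()  # free the pre-shuffle copy, mirroring the diff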
@@ -928,7 +933,7 @@ class Fp8MoEMethod:
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -936,15 +941,17 @@ class Fp8MoEMethod:
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 block_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 block_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,7 +964,7 @@ class Fp8MoEMethod:
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,6 +972,11 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
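Taken together, the apply() path after this commit picks a kernel roughly as follows. This is a condensed sketch of the branch structure only, not the full method, and the name of the non-HIP fallback is deliberately left generic:

def select_moe_kernel(is_hip: bool, use_int4_weight: bool, ck_moe: bool,
                      block_quant: bool, activation: str) -> str:
    # Condensed view of Fp8MoEMethod.apply's dispatch after this commit.
    if is_hip and use_int4_weight:
        return "ck_moe_2stages_win4"   # INT4-FP8 weights, silu or gelu
    if is_hip and ck_moe:
        if block_quant:
            # FP8 block_quant still runs through asm_moe and only supports silu.
            assert activation == "silu"
            return "asm_moe"
        return "ck_moe_2stages"        # per-tensor FP8, silu or gelu
    return "default fused-experts path"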