ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations (#5228)
This commit is contained in:
@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe
|
from aiter import ActivationType
|
||||||
|
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -487,7 +488,7 @@ class Fp8MoEMethod:
|
|||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = (
|
params_dtype = (
|
||||||
torch.int32
|
torch.uint32
|
||||||
if get_bool_env_var("USE_INT4_WEIGHT")
|
if get_bool_env_var("USE_INT4_WEIGHT")
|
||||||
else torch.float8_e4m3fn
|
else torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
|
|||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
||||||
# Weight Permutation
|
# Weight Permutation
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
permute_weight(layer.w13_weight.data),
|
# permute_weight(layer.w13_weight.data),
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
permute_weight(layer.w2_weight.data),
|
# permute_weight(layer.w2_weight.data),
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -867,12 +870,14 @@ class Fp8MoEMethod:
|
|||||||
|
|
||||||
if get_bool_env_var("CK_MOE"):
|
if get_bool_env_var("CK_MOE"):
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
permute_weight(layer.w13_weight.data),
|
# permute_weight(layer.w13_weight.data),
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
permute_weight(layer.w2_weight.data),
|
# permute_weight(layer.w2_weight.data),
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -928,7 +933,7 @@ class Fp8MoEMethod:
|
|||||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
return asm_moe(
|
return ck_moe_2stages_win4(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
@@ -936,15 +941,17 @@ class Fp8MoEMethod:
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
layer.w13_weight_scale1,
|
layer.w13_weight_scale1,
|
||||||
layer.w2_weight_scale1,
|
layer.w2_weight_scale1,
|
||||||
activation=activation,
|
activation=(
|
||||||
|
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
|
|
||||||
assert (
|
|
||||||
activation == "silu"
|
|
||||||
), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
|
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
||||||
|
assert (
|
||||||
|
activation == "silu"
|
||||||
|
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
||||||
return asm_moe(
|
return asm_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@@ -957,7 +964,7 @@ class Fp8MoEMethod:
|
|||||||
expert_mask=None,
|
expert_mask=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return asm_moe(
|
return ck_moe_2stages(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
@@ -965,6 +972,11 @@ class Fp8MoEMethod:
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
layer.w13_weight_scale1,
|
layer.w13_weight_scale1,
|
||||||
layer.w2_weight_scale1,
|
layer.w2_weight_scale1,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu
|
||||||
|
if activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
|
|||||||
Reference in New Issue
Block a user