ROCm: Fix MoE padding for none FP8 cases (#2111)

This commit is contained in:
HAI
2024-11-21 12:23:21 -08:00
committed by GitHub
parent 7f8fcd39cd
commit f35cb46cc3

View File

@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
padded_size = padding_size
if not use_fp8:
assert A_scale is None
assert B_scale is None
# MOE_PADDING FP8 only
padded_size = 0
else:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
K = B.shape[2] - padding_size
K = B.shape[2] - padded_size
if K % config["BLOCK_SIZE_K"] == 0:
even_ks = True
else:
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2] - padding_size,
B.shape[2] - padded_size,
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
padded_size = padding_size
if not use_fp8:
# MOE_PADDING FP8 only
padded_size = 0
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@ def fused_experts(
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
(w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,