ROCm: Fix MoE padding for none FP8 cases (#2111)
This commit is contained in:
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
padded_size = padding_size
|
||||
if not use_fp8:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
# MOE_PADDING FP8 only
|
||||
padded_size = 0
|
||||
else:
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
assert B_scale is not None
|
||||
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
|
||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
K = B.shape[2] - padding_size
|
||||
K = B.shape[2] - padded_size
|
||||
if K % config["BLOCK_SIZE_K"] == 0:
|
||||
even_ks = True
|
||||
else:
|
||||
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
B.shape[2] - padding_size,
|
||||
B.shape[2] - padded_size,
|
||||
sorted_token_ids.shape[0],
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
@@ -480,8 +483,12 @@ def fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not use_fp8:
|
||||
# MOE_PADDING FP8 only
|
||||
padded_size = 0
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
@@ -498,7 +505,7 @@ def fused_experts(
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
(w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
|
||||
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
|
||||
topk_ids.shape[1],
|
||||
"float8" if use_fp8 else None,
|
||||
override_config=override_config,
|
||||
|
||||
Reference in New Issue
Block a user