ROCm: Fix MoE padding for none FP8 cases (#2111)
This commit is contained in:
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
|
|||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
|
padded_size = padding_size
|
||||||
if not use_fp8:
|
if not use_fp8:
|
||||||
assert A_scale is None
|
assert A_scale is None
|
||||||
assert B_scale is None
|
assert B_scale is None
|
||||||
|
# MOE_PADDING FP8 only
|
||||||
|
padded_size = 0
|
||||||
else:
|
else:
|
||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
|
|||||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
K = B.shape[2] - padding_size
|
K = B.shape[2] - padded_size
|
||||||
if K % config["BLOCK_SIZE_K"] == 0:
|
if K % config["BLOCK_SIZE_K"] == 0:
|
||||||
even_ks = True
|
even_ks = True
|
||||||
else:
|
else:
|
||||||
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
|
|||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.shape[1],
|
B.shape[1],
|
||||||
B.shape[2] - padding_size,
|
B.shape[2] - padded_size,
|
||||||
sorted_token_ids.shape[0],
|
sorted_token_ids.shape[0],
|
||||||
topk_ids.numel(),
|
topk_ids.numel(),
|
||||||
A.stride(0),
|
A.stride(0),
|
||||||
@@ -480,8 +483,12 @@ def fused_experts(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
padded_size = padding_size
|
||||||
|
if not use_fp8:
|
||||||
|
# MOE_PADDING FP8 only
|
||||||
|
padded_size = 0
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
|
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
@@ -498,7 +505,7 @@ def fused_experts(
|
|||||||
get_config_func = functools.partial(
|
get_config_func = functools.partial(
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
w1.shape,
|
w1.shape,
|
||||||
(w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
|
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
|
||||||
topk_ids.shape[1],
|
topk_ids.shape[1],
|
||||||
"float8" if use_fp8 else None,
|
"float8" if use_fp8 else None,
|
||||||
override_config=override_config,
|
override_config=override_config,
|
||||||
|
|||||||
Reference in New Issue
Block a user