Support new DeepGEMM input format in silu_and_mul_masked_post_quant_fwd (#7153)
@@ -278,6 +278,7 @@ def _silu_and_mul_post_quant_kernel(
     fp8_min,
     BLOCK_N: tl.constexpr,
     NUM_STAGE: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     expert_id = tl.program_id(2)
     token_id = tl.program_id(1)
@@ -319,6 +320,8 @@ def _silu_and_mul_post_quant_kernel(
         gate_up = up * gate
         _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
         output_s = _absmax / fp8_max
+        if SCALE_UE8M0:
+            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
         output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
             output_ptr.dtype.element_ty
         )
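
The new SCALE_UE8M0 branch rounds each per-group scale up to the nearest power of two, which keeps the scale exactly representable in DeepGEMM's UE8M0 format (an exponent-only encoding with no mantissa bits). Because output_s only ever grows under this rounding, the subsequent clamp to [fp8_min, fp8_max] still cannot overflow. A minimal host-side sketch of the same rounding in plain PyTorch (the helper name is hypothetical):

import torch

def round_scale_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # Same computation as the Triton branch above: ceil(log2(s)) then
    # exp2, i.e. round every scale up to the nearest power of two.
    return torch.exp2(torch.ceil(torch.log2(scale.abs())))

s = torch.tensor([0.7500, 1.0000, 0.0013])
print(round_scale_ue8m0(s))  # tensor([1.0000, 1.0000, 0.0020]) -> 2**0, 2**0, 2**-9
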
@@ -339,6 +342,7 @@ def silu_and_mul_masked_post_quant_fwd(
     output_scale: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     """
     input shape [expert_num, token_num_padded, hidden_dim]
@@ -395,6 +399,7 @@ def silu_and_mul_masked_post_quant_fwd(
         BLOCK_N=BLOCK_N,
         NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
+        SCALE_UE8M0=scale_ue8m0,
     )
     return
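
For context, a hedged usage sketch of the extended entry point. The shapes follow the docstring's [expert_num, token_num_padded, hidden_dim] layout; the concrete sizes, dtypes, and output/scale shapes below are illustrative assumptions rather than the documented contract:

import torch
# silu_and_mul_masked_post_quant_fwd is imported from wherever this
# kernel file lives in the sglang tree.

expert_num, token_num_padded, hidden_dim = 8, 128, 4096
group_size = 128  # quant_group_size; 128 matches DeepGEMM's usual blocking

x = torch.randn(expert_num, token_num_padded, hidden_dim,
                device="cuda", dtype=torch.bfloat16)
# SiLU-and-mul halves the hidden dimension; one fp32 scale per quant group.
out = torch.empty(expert_num, token_num_padded, hidden_dim // 2,
                  device="cuda", dtype=torch.float8_e4m3fn)
out_scale = torch.empty(expert_num, token_num_padded, hidden_dim // 2 // group_size,
                        device="cuda", dtype=torch.float32)
masked_m = torch.full((expert_num,), 64, device="cuda", dtype=torch.int32)

# scale_ue8m0=True opts into the new DeepGEMM input format: every entry
# of out_scale is rounded up to a power of two before quantization.
silu_and_mul_masked_post_quant_fwd(
    x, out, out_scale, group_size, masked_m, scale_ue8m0=True
)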