Support new DeepGEMM input format in silu_and_mul_masked_post_quant_fwd (#7153)
This commit is contained in:
@@ -278,6 +278,7 @@ def _silu_and_mul_post_quant_kernel(
|
|||||||
fp8_min,
|
fp8_min,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
NUM_STAGE: tl.constexpr,
|
NUM_STAGE: tl.constexpr,
|
||||||
|
SCALE_UE8M0: tl.constexpr,
|
||||||
):
|
):
|
||||||
expert_id = tl.program_id(2)
|
expert_id = tl.program_id(2)
|
||||||
token_id = tl.program_id(1)
|
token_id = tl.program_id(1)
|
||||||
@@ -319,6 +320,8 @@ def _silu_and_mul_post_quant_kernel(
|
|||||||
gate_up = up * gate
|
gate_up = up * gate
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
|
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
|
||||||
output_s = _absmax / fp8_max
|
output_s = _absmax / fp8_max
|
||||||
|
if SCALE_UE8M0:
|
||||||
|
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
|
||||||
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
|
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
|
||||||
output_ptr.dtype.element_ty
|
output_ptr.dtype.element_ty
|
||||||
)
|
)
|
||||||
@@ -339,6 +342,7 @@ def silu_and_mul_masked_post_quant_fwd(
|
|||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
quant_group_size: int,
|
quant_group_size: int,
|
||||||
masked_m: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
input shape [expert_num, token_num_padded, hidden_dim]
|
input shape [expert_num, token_num_padded, hidden_dim]
|
||||||
@@ -395,6 +399,7 @@ def silu_and_mul_masked_post_quant_fwd(
|
|||||||
BLOCK_N=BLOCK_N,
|
BLOCK_N=BLOCK_N,
|
||||||
NUM_STAGE=NUM_STAGES,
|
NUM_STAGE=NUM_STAGES,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
|
SCALE_UE8M0=scale_ue8m0,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user