[feat]Support fusion kernel for constructing quant input and scale factor for fp8_blockwise_scaled_grouped_mm (#8023)
This commit is contained in:
@@ -1166,3 +1166,88 @@ def scaled_fp8_quant(
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
|
||||
for block_m in [16, 32, 64, 128]
|
||||
for num_warps in [2, 4, 8]
|
||||
],
|
||||
key=["K", "BLOCK_K", "M_ALIGNMENT"],
|
||||
)
|
||||
@triton.jit
|
||||
def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
a, # (M, K):(K, 1)
|
||||
expert_offsets, # (num_experts,)
|
||||
problem_sizes, # (num_experts, 3)
|
||||
a_fp8, # (M, K):(K, 1)
|
||||
sfa, # (M, k)
|
||||
K: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
M_ALIGNMENT: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, # tune
|
||||
):
|
||||
k_offset = tl.program_id(0)
|
||||
expert_id = tl.program_id(1)
|
||||
|
||||
m = tl.load(problem_sizes + expert_id * 3)
|
||||
current_expert_offset = tl.load(expert_offsets + expert_id).to(tl.int64)
|
||||
tl.multiple_of(m, M_ALIGNMENT)
|
||||
tl.multiple_of(current_expert_offset, M_ALIGNMENT)
|
||||
|
||||
coord_k = k_offset * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
for i in tl.range(tl.cdiv(m, BLOCK_M)):
|
||||
coord_m = i * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
a_ptrs = a + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
|
||||
a_mask = (coord_m < m)[:, None] & (coord_k < K)[None, :]
|
||||
|
||||
inp = tl.load(a_ptrs, mask=a_mask).to(tl.float32) # [BLOCK_M, BLOCK_K]
|
||||
inp_amax = tl.max(tl.abs(inp), axis=1) # [BLOCK_M,]
|
||||
inp_amax = tl.clamp(inp_amax, min=1e-4, max=float("inf"))
|
||||
inp_fp8 = (inp * (448.0 / inp_amax[:, None])).to(tl.float8e4nv)
|
||||
|
||||
# Store fp8
|
||||
a_fp8_ptrs = (
|
||||
a_fp8 + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
|
||||
)
|
||||
tl.store(a_fp8_ptrs, inp_fp8, mask=a_mask)
|
||||
|
||||
# Store sfa
|
||||
k = tl.cdiv(K, BLOCK_K)
|
||||
sfa_ptrs = (
|
||||
sfa + current_expert_offset * k + k_offset * m + coord_m
|
||||
) # MN-Major with sfa
|
||||
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
|
||||
|
||||
|
||||
def per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
A: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
problem_sizes: torch.Tensor,
|
||||
group_size: int,
|
||||
expert_tokens_alignment: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A.dim() == 2
|
||||
assert A.is_contiguous(), "`A` is not contiguous"
|
||||
assert (
|
||||
A.shape[-1] % group_size == 0
|
||||
), "the last dimension of `A` cannot be divisible by `group_size`"
|
||||
|
||||
a_q = torch.empty_like(A, device=A.device, dtype=fp8_dtype)
|
||||
M, K = A.shape[0], A.shape[1]
|
||||
k = K // group_size
|
||||
sfa = torch.empty((M, k), device=A.device, dtype=torch.float32)
|
||||
num_experts = problem_sizes.shape[0]
|
||||
grid = (k, num_experts)
|
||||
_per_token_group_quant_fp8_hopper_moe_mn_major[grid](
|
||||
A,
|
||||
expert_offsets,
|
||||
problem_sizes,
|
||||
a_q,
|
||||
sfa,
|
||||
K,
|
||||
group_size,
|
||||
expert_tokens_alignment,
|
||||
)
|
||||
return a_q, sfa
|
||||
|
||||
Reference in New Issue
Block a user