Optimize Triton Kernel of Group GEMM in DeepGEMM Benchmark (#4014)
This commit is contained in:
@@ -115,17 +115,17 @@ def fp8_gemm_group_triton_kernel(
|
|||||||
):
|
):
|
||||||
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
|
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
|
||||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||||
|
|
||||||
|
Note: Block sizes must be multiples of 32 for optimal TMA performance.
|
||||||
"""
|
"""
|
||||||
# Map program ids to the block of C it should compute
|
# Map program ids to the block of C it should compute
|
||||||
pid = tl.program_id(axis=0)
|
pid_group = tl.program_id(axis=0) # Group ID
|
||||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
pid_n = tl.program_id(axis=1) # N dimension ID
|
||||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
|
||||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
# Compute the M block ID within this group
|
||||||
group_id = pid // num_pid_in_group
|
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
|
||||||
first_pid_m = group_id * GROUP_SIZE_M
|
pid_m_within_group = tl.program_id(axis=2) % group_size_m
|
||||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
|
||||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
|
||||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
|
||||||
|
|
||||||
# Create pointers for the first blocks of A and B
|
# Create pointers for the first blocks of A and B
|
||||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||||
@@ -153,20 +153,15 @@ def fp8_gemm_group_triton_kernel(
|
|||||||
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
|
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Perform matrix multiplication in FP8
|
||||||
|
res = tl.dot(a, b)
|
||||||
|
|
||||||
# Load scaling factors for the current block
|
# Load scaling factors for the current block
|
||||||
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
|
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
|
||||||
b_scale = tl.load(b_scale_ptrs)
|
b_scale = tl.load(b_scale_ptrs)
|
||||||
|
|
||||||
# Convert FP8 to FP32 for computation
|
# Apply scaling factors to the accumulated result
|
||||||
a = a.to(tl.float32)
|
accumulator += res * a_scale * b_scale
|
||||||
b = b.to(tl.float32)
|
|
||||||
|
|
||||||
# Apply scaling factors to the current block
|
|
||||||
a = a * a_scale
|
|
||||||
b = b * b_scale
|
|
||||||
|
|
||||||
# Accumulate matmul for the current block
|
|
||||||
accumulator += tl.dot(a, b)
|
|
||||||
|
|
||||||
# Advance pointers
|
# Advance pointers
|
||||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
@@ -183,13 +178,14 @@ def fp8_gemm_group_triton_kernel(
|
|||||||
tl.store(c_ptrs, c, mask=c_mask)
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
|
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||||
"""
|
"""
|
||||||
Perform matrix multiplication with FP8 inputs and proper scaling.
|
Perform matrix multiplication with FP8 inputs and proper scaling.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
|
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
|
||||||
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
|
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
|
||||||
|
c: Output tensor in BF16 format
|
||||||
num_groups: Number of groups for grouped GEMM
|
num_groups: Number of groups for grouped GEMM
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -199,32 +195,21 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
|
|||||||
a, a_scale = a_tuple
|
a, a_scale = a_tuple
|
||||||
b, b_scale = b_tuple
|
b, b_scale = b_tuple
|
||||||
|
|
||||||
# Check constraints
|
|
||||||
assert a.shape[1] == b.shape[1], "Incompatible dimensions"
|
|
||||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
|
||||||
|
|
||||||
M, K = a.shape
|
M, K = a.shape
|
||||||
N, K_b = b.shape
|
_, N = b.shape
|
||||||
assert K == K_b, f"Incompatible K dimensions: {K} vs {K_b}"
|
|
||||||
|
|
||||||
# Transpose b to match kernel expectations (K,N format)
|
# Configure block sizes - must be multiples of 32 for TMA alignment
|
||||||
b = b.T.contiguous()
|
BLOCK_SIZE_M = 128
|
||||||
|
BLOCK_SIZE_N = 128
|
||||||
|
BLOCK_SIZE_K = 128
|
||||||
|
|
||||||
# Allocate output in bfloat16 (not float16)
|
# Calculate grid dimensions
|
||||||
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
|
||||||
|
|
||||||
# Prepare scale factors
|
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
|
||||||
# Ensure scales are in the right format and contiguous
|
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
|
||||||
a_scale = a_scale.contiguous()
|
|
||||||
b_scale = b_scale.contiguous()
|
|
||||||
|
|
||||||
# 1D launch kernel
|
|
||||||
grid = lambda META: (
|
|
||||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate K blocks (128 elements per block)
|
|
||||||
K_blocks = triton.cdiv(K, 128)
|
|
||||||
|
|
||||||
fp8_gemm_group_triton_kernel[grid](
|
fp8_gemm_group_triton_kernel[grid](
|
||||||
a,
|
a,
|
||||||
@@ -245,9 +230,9 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
|
|||||||
1, # Stride in the K dimension may be 1
|
1, # Stride in the K dimension may be 1
|
||||||
b_scale.stride(0),
|
b_scale.stride(0),
|
||||||
1 if b_scale.dim() > 1 else 0,
|
1 if b_scale.dim() > 1 else 0,
|
||||||
BLOCK_SIZE_M=128,
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
BLOCK_SIZE_N=128,
|
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||||
BLOCK_SIZE_K=128,
|
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
GROUP_SIZE_M=num_groups,
|
GROUP_SIZE_M=num_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -264,6 +249,73 @@ def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(m: int, n: int, k: int, num_groups: int):
|
||||||
|
print(f"Shape (m={m}, n={n}, k={k}")
|
||||||
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||||
|
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||||
|
)
|
||||||
|
m_per_group = m // num_groups
|
||||||
|
out_deepgemm = out.clone()
|
||||||
|
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||||
|
m_indices = (
|
||||||
|
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
fp8_gemm_group_deepgemm(
|
||||||
|
x_fp8_grouped,
|
||||||
|
y_fp8_grouped,
|
||||||
|
out_deepgemm,
|
||||||
|
m_indices,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Prepare inputs for Triton
|
||||||
|
a, a_scale = x_fp8_flat
|
||||||
|
b, b_scale = y_fp8_flat
|
||||||
|
b = b.T.contiguous()
|
||||||
|
# Ensure scales are in the right format and contiguous
|
||||||
|
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||||
|
M, _ = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||||
|
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
|
||||||
|
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
|
||||||
|
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
|
||||||
|
|
||||||
|
print(f"Shape m={m}, n={n}, k={k}:")
|
||||||
|
print(f"Torch output: {out_torch[0, 0:5]}")
|
||||||
|
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||||
|
print(f"Triton output: {out_triton[0, 0:5]}")
|
||||||
|
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
|
||||||
|
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
|
||||||
|
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
|
||||||
|
|
||||||
|
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
|
||||||
|
triton_torch_diff = calc_diff(out_triton, out_torch)
|
||||||
|
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
|
||||||
|
|
||||||
|
DIFF_THRESHOLD = 0.001
|
||||||
|
all_match = (
|
||||||
|
deepgemm_torch_diff < DIFF_THRESHOLD
|
||||||
|
and triton_torch_diff < DIFF_THRESHOLD
|
||||||
|
and deepgemm_triton_diff < DIFF_THRESHOLD
|
||||||
|
)
|
||||||
|
if all_match:
|
||||||
|
print("✅ All implementations match\n")
|
||||||
|
else:
|
||||||
|
print("❌ Some implementations differ:")
|
||||||
|
print(
|
||||||
|
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_weight_shapes(tp_size):
|
def get_weight_shapes(tp_size):
|
||||||
# cannot TP
|
# cannot TP
|
||||||
total = [
|
total = [
|
||||||
@@ -310,65 +362,6 @@ def create_benchmark_configs(tp_size):
|
|||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
def calculate_diff(m: int, n: int, k: int, num_groups: int):
|
|
||||||
print(f"Shape (m={m}, n={n}, k={k}")
|
|
||||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
|
||||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
|
||||||
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
|
||||||
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
|
||||||
)
|
|
||||||
m_per_group = m // num_groups
|
|
||||||
out_deepgemm = out.clone()
|
|
||||||
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
|
||||||
m_indices = (
|
|
||||||
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
fp8_gemm_group_deepgemm(
|
|
||||||
x_fp8_grouped,
|
|
||||||
y_fp8_grouped,
|
|
||||||
out_deepgemm,
|
|
||||||
m_indices,
|
|
||||||
)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Quantized x and y
|
|
||||||
out_triton = fp8_gemm_group_triton(x_fp8_flat, y_fp8_flat, num_groups)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
|
|
||||||
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
|
|
||||||
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
|
|
||||||
|
|
||||||
print(f"Shape m={m}, n={n}, k={k}:")
|
|
||||||
print(f"Torch output: {out_torch[0, 0:5]}")
|
|
||||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
|
||||||
print(f"Triton output: {out_triton[0, 0:5]}")
|
|
||||||
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
|
|
||||||
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
|
|
||||||
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
|
|
||||||
|
|
||||||
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
|
|
||||||
triton_torch_diff = calc_diff(out_triton, out_torch)
|
|
||||||
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
|
|
||||||
|
|
||||||
DIFF_THRESHOLD = 0.001
|
|
||||||
all_match = (
|
|
||||||
deepgemm_torch_diff < DIFF_THRESHOLD
|
|
||||||
and triton_torch_diff < DIFF_THRESHOLD
|
|
||||||
and deepgemm_triton_diff < DIFF_THRESHOLD
|
|
||||||
)
|
|
||||||
if all_match:
|
|
||||||
print("✅ All implementations match\n")
|
|
||||||
else:
|
|
||||||
print("❌ Some implementations differ:")
|
|
||||||
print(
|
|
||||||
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
|
|
||||||
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
|
|
||||||
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_benchmark(tp_size):
|
def get_benchmark(tp_size):
|
||||||
all_configs = create_benchmark_configs(tp_size)
|
all_configs = create_benchmark_configs(tp_size)
|
||||||
|
|
||||||
@@ -416,10 +409,21 @@ def get_benchmark(tp_size):
|
|||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
elif provider == "triton":
|
elif provider == "triton":
|
||||||
|
# Prepare inputs for Triton
|
||||||
|
# We did it outside of the lambda function to make it fair comparison like deepgemm
|
||||||
|
a, a_scale = x_fp8_flat
|
||||||
|
b, b_scale = y_fp8_flat
|
||||||
|
b = b.T.contiguous()
|
||||||
|
# Ensure scales are in the right format and contiguous
|
||||||
|
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||||
|
M, _ = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: fp8_gemm_group_triton(
|
lambda: fp8_gemm_group_triton(
|
||||||
x_fp8_flat,
|
(a, a_scale),
|
||||||
y_fp8_flat,
|
(b, b_scale),
|
||||||
|
c,
|
||||||
num_groups,
|
num_groups,
|
||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
@@ -429,13 +433,8 @@ def get_benchmark(tp_size):
|
|||||||
flops = 2 * m * n * k # multiply-adds
|
flops = 2 * m * n * k # multiply-adds
|
||||||
tflops = flops / (ms * 1e-3) / 1e12
|
tflops = flops / (ms * 1e-3) / 1e12
|
||||||
|
|
||||||
# Print shape-specific results with TFLOPS
|
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||||
print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}")
|
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||||
return (
|
|
||||||
ms,
|
|
||||||
max_ms,
|
|
||||||
min_ms,
|
|
||||||
) # return in seconds for consistency with triton benchmark
|
|
||||||
|
|
||||||
return benchmark
|
return benchmark
|
||||||
|
|
||||||
@@ -478,6 +477,7 @@ if __name__ == "__main__":
|
|||||||
calculate_diff(8192, 2048, 7168, 4)
|
calculate_diff(8192, 2048, 7168, 4)
|
||||||
calculate_diff(4096, 7168, 4096, 8)
|
calculate_diff(4096, 7168, 4096, 8)
|
||||||
calculate_diff(4096, 2048, 7168, 8)
|
calculate_diff(4096, 2048, 7168, 8)
|
||||||
|
calculate_diff(4096, 576, 7168, 8)
|
||||||
|
|
||||||
# Get the benchmark function with the specified tp_size
|
# Get the benchmark function with the specified tp_size
|
||||||
benchmark = get_benchmark(args.tp_size)
|
benchmark = get_benchmark(args.tp_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user