diff --git a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
index 539b3e249..2c3e8dfcc 100644
--- a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
+++ b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
@@ -115,17 +115,17 @@ def fp8_gemm_group_triton_kernel(
 ):
     """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+
+    Note: Block sizes must be multiples of 32 for optimal TMA performance.
     """
     # Map program ids to the block of C it should compute
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
+    pid_group = tl.program_id(axis=0)  # Group ID
+    pid_n = tl.program_id(axis=1)  # N dimension ID
+
+    # Compute the M block ID within this group
+    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
+    pid_m_within_group = tl.program_id(axis=2) % group_size_m
+    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
 
     # Create pointers for the first blocks of A and B
     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
@@ -153,20 +153,15 @@ def fp8_gemm_group_triton_kernel(
             pid_n * stride_b_scale_n + k_block * stride_b_scale_k
         )
 
+        # Perform matrix multiplication in FP8
+        res = tl.dot(a, b)
+
         # Load scaling factors for the current block
        a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
         b_scale = tl.load(b_scale_ptrs)
 
-        # Convert FP8 to FP32 for computation
-        a = a.to(tl.float32)
-        b = b.to(tl.float32)
-
-        # Apply scaling factors to the current block
-        a = a * a_scale
-        b = b * b_scale
-
-        # Accumulate matmul for the current block
-        accumulator += tl.dot(a, b)
+        # Apply scaling factors to the accumulated result
+        accumulator += res * a_scale * b_scale
 
         # Advance pointers
         a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -183,13 +178,14 @@ def fp8_gemm_group_triton_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
+def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
     """
     Perform matrix multiplication with FP8 inputs and proper scaling.

     Args:
         a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
         b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
+        c: Output tensor in BF16 format
         num_groups: Number of groups for grouped GEMM
 
     Returns:
@@ -199,32 +195,21 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
     a, a_scale = a_tuple
     b, b_scale = b_tuple
 
-    # Check constraints
-    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
-    assert a.is_contiguous(), "Matrix A must be contiguous"
-
     M, K = a.shape
-    N, K_b = b.shape
-    assert K == K_b, f"Incompatible K dimensions: {K} vs {K_b}"
+    _, N = b.shape
 
-    # Transpose b to match kernel expectations (K,N format)
-    b = b.T.contiguous()
+    # Configure block sizes - must be multiples of 32 for TMA alignment
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_N = 128
+    BLOCK_SIZE_K = 128
 
-    # Allocate output in bfloat16 (not float16)
-    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+    # Calculate grid dimensions
+    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
+    num_groups_grid = triton.cdiv(num_pid_m, num_groups)
 
-    # Prepare scale factors
-    # Ensure scales are in the right format and contiguous
-    a_scale = a_scale.contiguous()
-    b_scale = b_scale.contiguous()
-
-    # 1D launch kernel
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
-
-    # Calculate K blocks (128 elements per block)
-    K_blocks = triton.cdiv(K, 128)
+    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
+    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
 
     fp8_gemm_group_triton_kernel[grid](
         a,
@@ -245,9 +230,9 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, num_groups):
         1,  # Stride in the K dimension may be 1
         b_scale.stride(0),
         1 if b_scale.dim() > 1 else 0,
-        BLOCK_SIZE_M=128,
-        BLOCK_SIZE_N=128,
-        BLOCK_SIZE_K=128,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
         GROUP_SIZE_M=num_groups,
     )
 
@@ -264,6 +249,73 @@ def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
     return out
 
 
+def calculate_diff(m: int, n: int, k: int, num_groups: int):
+    print(f"Shape (m={m}, n={n}, k={k})")
+    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
+    x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
+        construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
+    )
+    m_per_group = m // num_groups
+    out_deepgemm = out.clone()
+    m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
+    m_indices = (
+        m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
+    )
+
+    fp8_gemm_group_deepgemm(
+        x_fp8_grouped,
+        y_fp8_grouped,
+        out_deepgemm,
+        m_indices,
+    )
+    torch.cuda.synchronize()
+
+    # Prepare inputs for Triton
+    a, a_scale = x_fp8_flat
+    b, b_scale = y_fp8_flat
+    b = b.T.contiguous()
+    # Ensure scales are in the right format and contiguous
+    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
+    M, _ = a.shape
+    _, N = b.shape
+    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
+    torch.cuda.synchronize()
+
+    diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
+    diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
+    diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
+
+    print(f"Shape m={m}, n={n}, k={k}:")
+    print(f"Torch output: {out_torch[0, 0:5]}")
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}") + print(f"Triton output: {out_triton[0, 0:5]}") + print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}") + print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}") + print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}") + + deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch) + triton_torch_diff = calc_diff(out_triton, out_torch) + deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton) + + DIFF_THRESHOLD = 0.001 + all_match = ( + deepgemm_torch_diff < DIFF_THRESHOLD + and triton_torch_diff < DIFF_THRESHOLD + and deepgemm_triton_diff < DIFF_THRESHOLD + ) + if all_match: + print("✅ All implementations match\n") + else: + print("❌ Some implementations differ:") + print( + f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}" + f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}" + f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}" + ) + + def get_weight_shapes(tp_size): # cannot TP total = [ @@ -310,65 +362,6 @@ def create_benchmark_configs(tp_size): return configs -def calculate_diff(m: int, n: int, k: int, num_groups: int): - print(f"Shape (m={m}, n={n}, k={k}") - x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) - y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) - x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = ( - construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False) - ) - m_per_group = m // num_groups - out_deepgemm = out.clone() - m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int) - m_indices = ( - m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1) - ) - - fp8_gemm_group_deepgemm( - x_fp8_grouped, - y_fp8_grouped, - out_deepgemm, - m_indices, - ) - torch.cuda.synchronize() - - # Quantized x and y - out_triton = fp8_gemm_group_triton(x_fp8_flat, y_fp8_flat, num_groups) - torch.cuda.synchronize() - - diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item() - diff_torch_triton = torch.abs(out_torch - out_triton).mean().item() - diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item() - - print(f"Shape m={m}, n={n}, k={k}:") - print(f"Torch output: {out_torch[0, 0:5]}") - print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}") - print(f"Triton output: {out_triton[0, 0:5]}") - print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}") - print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}") - print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}") - - deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch) - triton_torch_diff = calc_diff(out_triton, out_torch) - deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton) - - DIFF_THRESHOLD = 0.001 - all_match = ( - deepgemm_torch_diff < DIFF_THRESHOLD - and triton_torch_diff < DIFF_THRESHOLD - and deepgemm_triton_diff < DIFF_THRESHOLD - ) - if all_match: - print("✅ All implementations match\n") - else: - print("❌ Some implementations differ:") - print( - f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}" - f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}" - f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}" - ) - - def get_benchmark(tp_size): all_configs = create_benchmark_configs(tp_size) @@ -416,10 +409,21 @@ def get_benchmark(tp_size): quantiles=quantiles, ) elif provider == 
"triton": + # Prepare inputs for Triton + # We did it outside of the lambda function to make it fair comparison like deepgemm + a, a_scale = x_fp8_flat + b, b_scale = y_fp8_flat + b = b.T.contiguous() + # Ensure scales are in the right format and contiguous + a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous() + M, _ = a.shape + _, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) ms, min_ms, max_ms = triton.testing.do_bench( lambda: fp8_gemm_group_triton( - x_fp8_flat, - y_fp8_flat, + (a, a_scale), + (b, b_scale), + c, num_groups, ), quantiles=quantiles, @@ -429,13 +433,8 @@ def get_benchmark(tp_size): flops = 2 * m * n * k # multiply-adds tflops = flops / (ms * 1e-3) / 1e12 - # Print shape-specific results with TFLOPS - print(f"Time: {ms:.2f} ms, TFLOPS: {tflops:.2f}") - return ( - ms, - max_ms, - min_ms, - ) # return in seconds for consistency with triton benchmark + print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}") + return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms return benchmark @@ -478,6 +477,7 @@ if __name__ == "__main__": calculate_diff(8192, 2048, 7168, 4) calculate_diff(4096, 7168, 4096, 8) calculate_diff(4096, 2048, 7168, 8) + calculate_diff(4096, 576, 7168, 8) # Get the benchmark function with the specified tp_size benchmark = get_benchmark(args.tp_size)