diff --git a/benchmark/fbgemm/README.md b/benchmark/fbgemm/README.md new file mode 100644 index 000000000..e51356d8a --- /dev/null +++ b/benchmark/fbgemm/README.md @@ -0,0 +1,29 @@ +## Benchmark FBGEMM Grouped GEMM + +Benchmark FBGEMM Grouped GEMM in both the Triton and CUDA versions, as well as SGLang Triton Grouped GEMM; the results are used to compare the bandwidth of the different implementations. + +### Requirements + +```shell +pip install fbgemm-gpu-genai +``` + +### Usage + +```bash +python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 +``` + +For example, on an H200, the Qwen2-57B-A14B-Instruct TP4 fp8w8a8 grouped gemm bandwidth result is as follows: + +```shell +grouped-gemm-performance: + batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8 +0 256.0 3704.841339 3042.626402 2254.725030 +1 512.0 3691.426346 3029.065684 2269.504543 +2 1024.0 3653.938629 2258.471467 2358.319020 +3 2048.0 3596.644313 2271.611904 2476.895397 +4 4096.0 3468.496435 2231.283986 2179.473910 +``` + +The theoretical peak bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, the bandwidth of FBGEMM Triton Grouped GEMM FP8 is 3704.841339 GB/s, the bandwidth of FBGEMM CUTLASS F8F8BF16 Rowwise is 3042.626402 GB/s, and the bandwidth of SGLang Grouped GEMM FP8 is 2254.725030 GB/s. Therefore, FBGEMM Triton Grouped GEMM FP8 achieves 77.2% of H200's theoretical peak bandwidth, FBGEMM CUTLASS F8F8BF16 Rowwise achieves 63.4% of H200's theoretical peak bandwidth, and SGLang Grouped GEMM FP8 achieves 46.9% of H200's theoretical peak bandwidth. 
diff --git a/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py b/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py index 4926bb214..6e8c8dcf2 100644 --- a/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py +++ b/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py @@ -1,10 +1,16 @@ -# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 +# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 import argparse import torch import triton -from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm -from fbgemm_grouped_gemm import ( +from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import ( + quantize_fp8_row, + triton_quantize_fp8_row, +) +from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( + grouped_gemm as fbgemm_grouped_gemm, +) +from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise, ) from transformers import AutoConfig @@ -29,12 +35,11 @@ def get_model_config(model_name: str, tp_size: int): elif config.architectures[0] == "Qwen3MoeForCausalLM": num_groups = config.num_experts intermediate_size = config.moe_intermediate_size - elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: - num_groups = ( - config.n_routed_experts + 1 - if config.architectures[0] in ["DeepseekV3ForCausalLM"] - else config.n_routed_experts - ) + elif config.architectures[0] in [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + ]: + num_groups = config.n_routed_experts intermediate_size = config.moe_intermediate_size elif config.architectures[0] == "Llama4ForConditionalGeneration": num_groups = config.text_config.num_local_experts @@ -65,7 +70,7 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size): tokens_per_group = batch_size // num_groups m_sizes = torch.full( - (num_groups,), 
tokens_per_group, dtype=torch.int64, device="cuda" + (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda" ) x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") @@ -84,11 +89,11 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size): batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda" ) - seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda") + seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda") for i in range(1, num_groups + 1): seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group - weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda") + weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda") return ( x, @@ -102,39 +107,144 @@ def create_test_data(batch_size, num_groups, hidden_size, intermediate_size): ) -def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size): +def create_fp8_test_data( + batch_size, num_groups, hidden_size, intermediate_size, backend="triton" +): + """ + Create test data for FP8 grouped GEMM operations. 
+ + Args: + batch_size: Total batch size + num_groups: Number of groups + hidden_size: Hidden dimension size + intermediate_size: Intermediate dimension size + backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM + + Returns: + For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale) + For cutlass: (x, wq, w_scale, m_sizes) + """ torch.manual_seed(42) tokens_per_group = batch_size // num_groups - m_sizes = torch.full( - (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda" - ) - x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda") - w_fp16 = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda" - ) + # Create weight matrices for each group + w_list = [] + for _ in range(num_groups): + w = torch.randn( + intermediate_size, hidden_size, dtype=torch.float16, device="cuda" + ) + w_list.append(w) - x_fp8 = x_fp16.to(torch.float8_e4m3fn) - w_fp8 = w_fp16.to(torch.float8_e4m3fn) + # Quantize weights using quantize_fp8_row for each group + wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list]) - x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4 - w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4 + if backend == "triton": + # Triton format: concatenated weights + w_fp8 = torch.concat(wq_list, dim=0).contiguous() + w_scale = torch.concat(w_scale_list, dim=0).contiguous() - return x_fp8, w_fp8, m_sizes, x_scale, w_scale + # Create m_sizes as int32 for triton + m_sizes = torch.full( + (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda" + ) + + # Create and quantize input + x_fp16 = torch.randn( + batch_size, hidden_size, dtype=torch.float16, device="cuda" + ) + x_fp8, x_scale = triton_quantize_fp8_row(x_fp16) + x_scale = x_scale.view(batch_size, -1) + + return x_fp8, w_fp8, m_sizes, x_scale, w_scale + + elif backend == "cutlass": + # CUTLASS format: stacked weights + wq = torch.stack(wq_list, 
dim=0).contiguous() + w_scale = torch.stack(w_scale_list, dim=0).contiguous() + + # Create m_sizes as int64 for cutlass + m_values = [tokens_per_group] * num_groups + m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda") + + # Create input data - separate for each group then concat + x_list = [] + for _ in range(num_groups): + x = torch.randn( + tokens_per_group, hidden_size, dtype=torch.float16, device="cuda" + ) + x_list.append(x) + + # Concatenate inputs into single tensor + x = torch.concat(x_list, dim=0).contiguous() + + return x, wq, w_scale, m_sizes + + else: + raise ValueError(f"Unsupported backend: {backend}") + + +def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype): + """ + Calculate memory bandwidth based on accessed expert weights. + + Args: + m_sizes: Tensor containing batch sizes for each group + hidden_size: Hidden dimension size + intermediate_size: Intermediate dimension size + dtype: Data type of weights + + Returns: + Memory size in bytes for accessed expert weights + """ + # Count non-zero groups (active experts) + if hasattr(m_sizes, "cpu"): + active_experts = torch.count_nonzero(m_sizes).item() + else: + active_experts = sum(1 for m in m_sizes if m > 0) + + # Calculate bytes per element based on dtype + if dtype in [torch.float16, torch.bfloat16]: + bytes_per_element = 2 + elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + bytes_per_element = 1 + elif dtype == torch.float32: + bytes_per_element = 4 + else: + # Default to 2 bytes for unknown dtypes + bytes_per_element = 2 + + # Memory per expert weight matrix + memory_per_expert = hidden_size * intermediate_size * bytes_per_element + + # Total memory for active experts + total_memory_bytes = active_experts * memory_per_expert + + return total_memory_bytes def get_benchmark_config(use_fp8_w8a8=False): if use_fp8_w8a8: return { - "line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"], - "line_names": ["FBGEMM Grouped GEMM FP8", "SGLang 
Grouped GEMM FP8"], - "styles": [("blue", "-"), ("red", "-")], + "line_vals": [ + "fbgemm_triton_grouped_gemm_fp8", + "fbgemm_cutlass_f8f8bf16_rowwise", + "sglang_grouped_gemm", + ], + "line_names": [ + "FBGEMM Triton Grouped GEMM FP8", + "FBGEMM CUTLASS F8F8BF16 Rowwise", + "SGLang Grouped GEMM FP8", + ], + "styles": [("blue", "-"), ("orange", "-"), ("red", "-")], } else: return { - "line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"], - "line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"], + "line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"], + "line_names": [ + "FBGEMM Triton Grouped GEMM BF16", + "SGLang Grouped GEMM BF16", + ], "styles": [("blue", "-"), ("green", "-")], } @@ -146,12 +256,12 @@ def run_benchmark( benchmark_config = triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + x_vals=[256, 512, 1024, 2048, 4096], line_arg="provider", line_vals=config["line_vals"], line_names=config["line_names"], styles=config["styles"], - ylabel="Time (ms)", + ylabel="Bandwidth (GB/s)", plot_name="grouped-gemm-performance", args={}, ) @@ -165,13 +275,22 @@ def run_benchmark( hidden_size = model_config["hidden_size"] intermediate_size = model_config["intermediate_size"] - if provider == "fbgemm_grouped_gemm_fp8": + if provider == "fbgemm_triton_grouped_gemm_fp8": try: test_data = create_fp8_test_data( - batch_size, num_groups, hidden_size, intermediate_size + batch_size, + num_groups, + hidden_size, + intermediate_size, + backend="triton", ) x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data + # Calculate memory bandwidth + memory_bytes = calculate_memory_bandwidth( + m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn + ) + def run_func(): return fbgemm_grouped_gemm_fp8_rowwise( x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True @@ -180,6 +299,38 @@ def run_benchmark( except Exception as e: print(f"FP8 not supported, skipping: {e}") return 
float("inf"), float("inf"), float("inf") + + elif provider == "fbgemm_cutlass_f8f8bf16_rowwise": + try: + test_data = create_fp8_test_data( + batch_size, + num_groups, + hidden_size, + intermediate_size, + backend="cutlass", + ) + x, wq, w_scale, m_sizes = test_data + + # Calculate memory bandwidth + memory_bytes = calculate_memory_bandwidth( + m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn + ) + + # Quantize input using triton_quantize_fp8_row + xq, x_scale = triton_quantize_fp8_row(x) + x_scale = x_scale.view(batch_size, -1) + + def run_func(): + return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked( + xq, wq, x_scale, w_scale, m_sizes + ) + + except Exception as e: + print( + f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, " + f"skipping: {e}" + ) + return float("inf"), float("inf"), float("inf") else: test_data = create_test_data( batch_size, num_groups, hidden_size, intermediate_size @@ -195,7 +346,12 @@ def run_benchmark( weight_indices, ) = test_data - if provider == "fbgemm_grouped_gemm": + # Calculate memory bandwidth for BF16 operations + memory_bytes = calculate_memory_bandwidth( + m_sizes, hidden_size, intermediate_size, torch.bfloat16 + ) + + if provider == "fbgemm_triton_grouped_gemm": def run_func(): return fbgemm_grouped_gemm( @@ -228,10 +384,19 @@ def run_benchmark( try: quantiles = [0.5, 0.2, 0.8] ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles) - return ms, min_ms, max_ms + + # Convert time (ms) to bandwidth (GB/s) + # Bandwidth = Memory (bytes) / Time (seconds) + # Convert ms to seconds and bytes to GB (1e9) + gb_per_s = (memory_bytes / 1e9) / (ms / 1000) + # min bandwidth = max time, max bandwidth = min time + min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000) + max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000) + + return gb_per_s, min_gb_per_s, max_gb_per_s except Exception as e: print(f"Error during benchmarking for {provider}: {e}") - return float("inf"), float("inf"), float("inf") 
+ return 0.0, 0.0, 0.0 dynamic_benchmark.run( show_plots=True, @@ -242,7 +407,7 @@ def run_benchmark( ) -def verify_correctness(model_config, use_fp8_w8a8): +def verify_correctness(model_config): print("Verifying correctness...") batch_size = 128 num_groups = model_config["num_groups"] @@ -250,54 +415,39 @@ def verify_correctness(model_config, use_fp8_w8a8): intermediate_size = model_config["intermediate_size"] test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size) - (x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = ( - test_data + ( + x, + w_fbgemm, + w_sglang, + c_fbgemm, + c_sglang, + m_sizes, + seg_indptr, + weight_indices, + ) = test_data + + result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True) + + result_sglang = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, ) - try: - result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True) - - result_sglang = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3): - print("✓ BF16 Correctness verification passed!") - else: - max_diff = torch.max(torch.abs(result_fbgemm - result_sglang)) - print(f"✗ BF16 Correctness verification failed! 
Max diff: {max_diff}") - return False - - if use_fp8_w8a8: - try: - fp8_data = create_fp8_test_data( - batch_size, num_groups, hidden_size, intermediate_size - ) - x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data - - result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( - x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True - ) - - assert result_fp8.shape == (batch_size, intermediate_size) - print("✓ FP8 functionality test passed!") - except Exception as e: - print(f"FP8 test failed (possibly unsupported): {e}") - return False - - return True - - except Exception as e: - print(f"✗ Error during correctness verification: {e}") + if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3): + print("✓ BF16 Correctness verification passed!") + else: + max_diff = torch.max(torch.abs(result_fbgemm - result_sglang)) + print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}") return False + return True + def main(): parser = argparse.ArgumentParser( @@ -348,7 +498,7 @@ def main(): print(f" use_fp8_w8a8: {args.use_fp8_w8a8}") if args.verify_correctness: - if not verify_correctness(model_config, args.use_fp8_w8a8): + if not verify_correctness(model_config): print("Correctness verification failed. Exiting...") return diff --git a/benchmark/fbgemm/fbgemm_grouped_gemm.py b/benchmark/fbgemm/fbgemm_grouped_gemm.py deleted file mode 100644 index 7d8568947..000000000 --- a/benchmark/fbgemm/fbgemm_grouped_gemm.py +++ /dev/null @@ -1,1294 +0,0 @@ -# Copy from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import functools -import inspect -import sys -import warnings -from typing import Optional - -import torch -import triton # @manual -import triton.language as tl # @manual -from triton.runtime import driver # @manual - - -def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: - """ - Maps torch dtype to triton dtype. - - Args: - dtype (torch.dtype): input dtype. - - Returns: - tl.dtype: triton dtype. - """ - if dtype == torch.float16: - return tl.float16 - elif dtype == torch.bfloat16: - return tl.bfloat16 - elif dtype == torch.float32: - return tl.float32 - elif dtype == torch.int32: - return tl.int32 - elif dtype == torch.float8_e4m3fn and torch.version.hip is None: - return tl.float8e4nv - else: - raise ValueError(f"Unsupported dtype {dtype}") - - -# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). -HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) - -if HAS_TMA_DESC: - print( - "TMA benchmarks will be running with experimental grid constant TMA descriptor.", - file=sys.stderr, - ) -else: - print( - "TMA benchmarks will be running without grid constant TMA descriptor.", - file=sys.stderr, - ) - - -class TmaAutoTuneHelper: - - # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 - class KernelParamWrapper: - def __init__(self, desc): - self.desc = desc - - def tma_desc_cpu_ptr(self): - return self.desc.data_ptr() - - TMA_SIZE = 128 - - def __init__(self): - self.fill_1d_tma_descriptor_inner = ( - triton.runtime.driver.active.utils.fill_1d_tma_descriptor - ) - self.fill_2d_tma_descriptor_inner = ( - triton.runtime.driver.active.utils.fill_2d_tma_descriptor - ) - if HAS_TMA_DESC: - self.descriptors = {} - else: - self.cuda_descriptors = {} - - # Call this method outside of the lambda function for grid size - def init_tma_descriptor(self, name): - if HAS_TMA_DESC: - self.descriptors[name] = torch.empty( - TmaAutoTuneHelper.TMA_SIZE, device="cpu", 
dtype=torch.int8 - ) - else: - self.cuda_descriptors[name] = torch.empty( - TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 - ) - - # Call this method inside the lambda function for grid size - def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): - if HAS_TMA_DESC: - desc_x = self.descriptors[name] - assert desc_x.data_ptr() % 64 == 0 - self.fill_1d_tma_descriptor_inner( - ptr, dim, block_dim, element_size, desc_x.data_ptr() - ) - else: - desc_x = self.cuda_descriptors[name] - buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) - self.fill_1d_tma_descriptor_inner( - ptr, dim, block_dim, element_size, buf_x.data_ptr() - ) - desc_x.copy_(buf_x, non_blocking=True) - - # Call this method inside the lambda function for grid size - def fill_2d_tma_descriptor( - self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size - ): - if HAS_TMA_DESC: - desc_x = self.descriptors[name] - assert desc_x.data_ptr() % 64 == 0 - self.fill_2d_tma_descriptor_inner( - ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() - ) - else: - desc_x = self.cuda_descriptors[name] - buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) - self.fill_2d_tma_descriptor_inner( - ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() - ) - desc_x.copy_(buf_x, non_blocking=True) - - def get_tma_descriptor_kernel_param(self, name): - if HAS_TMA_DESC: - assert self.descriptors[name] is not None - return self.KernelParamWrapper(self.descriptors[name]) - else: - assert self.cuda_descriptors[name] is not None - return self.cuda_descriptors[name] - - -_NV_CONFIGS = [ - triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "NUM_CONSUMER_GROUPS": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - num_ctas=num_ctas, - ) - for block_size_m in [64, 128] - for block_size_n in [64, 128, 256] - for block_size_k in [64, 128, 256] - for num_stages in [3, 
4] - for num_warps in [4, 8] - for num_ctas in [1] -] - -_HAS_WS_SUPPORT = None - - -def _check_ws_support(): - if not hasattr(tl, "async_task"): - return False - config_signature = inspect.signature(triton.Config).parameters - if ( - "num_consumer_groups" not in config_signature - or "num_buffers_warp_spec" not in config_signature - ): - return False - if not HAS_TMA_DESC: - return False - return True - - -def _set_ws_support(): - global _HAS_WS_SUPPORT - if _HAS_WS_SUPPORT is None: - _HAS_WS_SUPPORT = _check_ws_support() - - -_set_ws_support() - -if _HAS_WS_SUPPORT: - _NV_WS_CONFIGS = [ - triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups), - "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales, - "USE_TMA_STORE": use_tma_store, - }, - num_stages=num_stages, - num_warps=num_warps, - num_ctas=num_ctas, - num_consumer_groups=num_consumer_groups, - num_buffers_warp_spec=num_stages, - ) - for block_size_m in [64, 128, 256] - for block_size_n in [64, 128, 256] - for block_size_k in [64, 128, 256] - for num_stages in [2, 3, 4] - for num_warps in [4, 8, 16] - # TODO(shikaili): Resolve LLVM error. - for num_ctas in [1] - for num_consumer_groups in [0, 2] - for use_tma_load_on_scales in [True, False] - # TODO(shikaili): Resolve compatibility with ws. 
- for use_tma_store in [False] - ] -else: - _NV_WS_CONFIGS = _NV_CONFIGS - - -_AMD_CONFIGS = [ - triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "waves_per_eu": waves_per_cu, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - "NUM_CONSUMER_GROUPS": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - for block_size_m in [32, 64, 128] - for block_size_n in [32, 64, 128, 256] - for block_size_k in [128, 256] - for num_stages in [1, 2] - for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)] - for matrix_instr_nonkdim in [16] -] - - -def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): - device = torch.cuda.current_device() - # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages - if dtsize is None: - dtsize = named_args["c_ptr"].element_size() - if dtype is None: - dtype = named_args["c_ptr"].dtype - - pruned_configs = [] - for config in configs: - kw = config.kwargs - ( - BLOCK_M, - BLOCK_N, - BLOCK_K, - num_stages, - num_warps, - num_consumer_groups, - use_tma_load_on_scales, - ) = ( - kw["BLOCK_SIZE_M"], - kw["BLOCK_SIZE_N"], - kw["BLOCK_SIZE_K"], - config.num_stages, - config.num_warps, - config.num_consumer_groups, - kw.get("USE_TMA_LOAD_ON_SCALES", False), - ) - G, M, N, K = ( - named_args["G"], - named_args["M_BUCKET"], - named_args["N"], - named_args["K"], - ) - - # 1. make sure we have enough smem - max_shared_memory = driver.active.utils.get_device_properties(device)[ - "max_shared_mem" - ] - if torch.version.hip: - required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize - else: - required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize - if required_shared_memory > max_shared_memory: - continue - - use_warp_specialization = num_consumer_groups >= 1 - - M_PER_GROUP = M // G - MIN_M_TILES = 32 if torch.version.hip else 64 - # 2. 
make sure we don't load M tiles that are too big - if ( - not use_warp_specialization - and BLOCK_M > MIN_M_TILES - and BLOCK_M > (M_PER_GROUP * 2) - ): - continue - # 3. make sure we don't load N tiles that are too small - if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): - continue - - num_sm = driver.active.utils.get_device_properties(device)[ - "multiprocessor_count" - ] - N_TILES = N // BLOCK_N - MIN_N_TILES = 32 if torch.version.hip else 64 - # 4. make sure we don't load N tiles that are too big - if ( - not use_warp_specialization - and BLOCK_N > MIN_N_TILES - and M * N_TILES < num_sm - ): - continue - # 5. make sure we don't load N tiles that are too small - if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: - continue - - # 6. make sure K can be evenly divided - if K % BLOCK_K != 0: - continue - - # 7. make sure we can partition for ws - if use_warp_specialization: - if num_warps != 4: - continue - - # "tritongpu-warp-spec-data-partition" - m_slice = BLOCK_M // num_consumer_groups - n_slice = BLOCK_N // num_consumer_groups - if m_slice < 64 and n_slice < 256: - continue - - if dtsize >= 2: - if use_tma_load_on_scales: - continue - pruned_configs.append(config) - - return pruned_configs - - -@triton.autotune( - configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune}, - restore_value=["c_ptr"], # restore for scatter_add fusion -) -@triton.jit -def _fbgemm_grouped_gemm( - a_desc_ptr, - b_desc_ptr, - c_ptr, - workspace, - scatter_add_indices, - m_sizes, - # problem sizes - G: tl.constexpr, - M_BUCKET, - N: tl.constexpr, - K: tl.constexpr, - NUM_SMS: tl.constexpr, - FUSE_SCATTER_ADD: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - USE_FAST_ACCUM: tl.constexpr, - # tile sizes - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - NUM_CONSUMER_GROUPS: tl.constexpr, -) -> None: - tl.static_assert( - not 
(FUSE_SCATTER_ADD and USE_TMA_STORE), - "Cannot fuse scatter add with TMA store!", - ) - - tidx = tl.program_id(0) - - dtype: tl.dtype = c_ptr.dtype.element_ty - TMA_SIZE: tl.constexpr = tl.constexpr(128) - if USE_TMA_STORE: - c_desc_ptr = workspace + tidx * TMA_SIZE - else: - c_desc_ptr = None - - M_end_offset = 0 - M_end_offset = M_end_offset.to(tl.int64) - iterated_tiles = 0 - iterated_tiles = iterated_tiles.to(tl.int64) - for g in tl.range(G): - # Move across groups - m_size = tl.load(m_sizes + g) - - if m_size > 0: - M_start_offset = M_end_offset - M_end_offset = M_start_offset + m_size - N_start_offset = g.to(tl.int64) * N - n_size = N - - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - num_tiles = num_m_tiles * num_n_tiles - - if USE_TMA_STORE: - # pyre-ignore - tl.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start_offset * N, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, n_size], - element_ty=c_ptr.dtype.element_ty, - ) - # pyre-ignore - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) - - # Move across tiles - while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: - gidx = tidx - iterated_tiles - # Split M first and N second. 
- tile_m_idx = gidx % num_m_tiles - tile_n_idx = gidx // num_m_tiles - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - tl.static_assert(K % BLOCK_SIZE_K == 0) - if USE_TMA_LOAD: - m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - for k_offset in range(0, K, BLOCK_SIZE_K): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - else: - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = ( - a_desc_ptr - + (M_start_offset + offs_am[:, None]) * K - + offs_k[None, :] - ) - b_ptrs = ( - b_desc_ptr - + (N_start_offset + offs_bn[:, None]) * K - + offs_k[None, :] - ) - for k_offset in range(0, K, BLOCK_SIZE_K): - a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) - b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) - accumulator += tl.dot(a, b.T) - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - - if USE_TMA_STORE: - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - tl._experimental_descriptor_store( - c_desc_ptr, - accumulator.to(c_ptr.dtype.element_ty), - [m_offset, n_offset], - ) - elif FUSE_SCATTER_ADD: - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + 
offs_bn[None, :], - c, - mask=mask[:, None] and offs_bn[None, :] < n_size, - sem="relaxed", - ) - else: - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, - ) - tidx += NUM_SMS - - iterated_tiles += num_tiles - - -# TODO(shikaili): Too much code duplication. Need to refactor. -@triton.autotune( - configs=_NV_WS_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune}, - restore_value=["c_ptr"], # restore for scatter_add fusion -) -@triton.jit -def _fbgemm_grouped_gemm_ws( - a_desc_ptr, - b_desc_ptr, - c_ptr, - workspace, - scatter_add_indices, - m_sizes, - # problem sizes - G: tl.constexpr, - M_BUCKET: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - NUM_SMS: tl.constexpr, - FUSE_SCATTER_ADD: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_FAST_ACCUM: tl.constexpr, - # tile sizes - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - NUM_CONSUMER_GROUPS: tl.constexpr, - USE_TMA_LOAD_ON_SCALES: tl.constexpr, - USE_TMA_STORE: tl.constexpr, -) -> None: - tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!") - tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!") - tl.static_assert( - not (FUSE_SCATTER_ADD and USE_TMA_STORE), - "Cannot fuse scatter add with TMA store!", - ) - - tidx = tl.program_id(0) - - dtype: tl.dtype = c_ptr.dtype.element_ty - TMA_SIZE: tl.constexpr = tl.constexpr(128) - if USE_TMA_STORE: - c_desc_ptr = workspace + tidx * TMA_SIZE - else: - c_desc_ptr = None - - M_end_offset = 0 - M_end_offset = M_end_offset.to(tl.int64) - iterated_tiles = 0 - iterated_tiles = iterated_tiles.to(tl.int64) - for g in tl.range(G): - # Move across groups - m_size = 
tl.load(m_sizes + g, cache_modifier=".ca") - - if m_size > 0: - M_start_offset = M_end_offset - M_end_offset = M_start_offset + m_size - N_start_offset = g.to(tl.int64) * N - - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - tl.static_assert(N % BLOCK_SIZE_N == 0) - NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N - num_tiles = num_m_tiles * NUM_N_TILES - - if USE_TMA_STORE: - with tl.async_task([0]): - # pyre-ignore - tl.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start_offset * N, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, N], - element_ty=c_ptr.dtype.element_ty, - ) - # pyre-ignore - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) - - # Move across tiles - next_iterated_tiles = iterated_tiles + num_tiles - if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles): - for i in range(tidx, next_iterated_tiles, NUM_SMS): - gidx = i - iterated_tiles - # Split M first and N second. - tile_m_idx = gidx % num_m_tiles - tile_n_idx = gidx // num_m_tiles - - accumulator = tl.zeros( - (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 - ) - tl.static_assert(K % BLOCK_SIZE_K == 0) - - m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - 
tl._experimental_descriptor_store( - c_desc_ptr, - accumulator.to(c_ptr.dtype.element_ty), - [m_offset, n_offset], - ) - elif FUSE_SCATTER_ADD: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) - else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) - tidx += NUM_SMS - - iterated_tiles += num_tiles - - -TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv - - -# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument. 
-@triton.autotune( - configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={ - "early_config_prune": functools.partial( - early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 - ) - }, - restore_value=["c_ptr"], # restore for scatter_add fusion -) -@triton.jit -def _fbgemm_grouped_gemm_fp8_rowwise( - a_desc_ptr, - a_scale_ptr, - b_desc_ptr, - b_scale_ptr, - b_scale_desc_ptr, - c_ptr, - workspace, - scatter_add_indices, - m_sizes, - # problem sizes - G: tl.constexpr, - M_BUCKET, - N: tl.constexpr, - K: tl.constexpr, - NUM_SMS: tl.constexpr, - FUSE_SCATTER_ADD: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_TMA_STORE: tl.constexpr, - USE_FAST_ACCUM: tl.constexpr, - # tile sizes - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - NUM_CONSUMER_GROUPS: tl.constexpr, -) -> None: - tl.static_assert( - not (FUSE_SCATTER_ADD and USE_TMA_STORE), - "Cannot fuse scatter add with TMA store!", - ) - - tidx = tl.program_id(0) - - dtype = TT_FP8_DTYPE - TMA_SIZE: tl.constexpr = tl.constexpr(128) - if USE_TMA_STORE: - c_desc_ptr = workspace + tidx * TMA_SIZE - else: - c_desc_ptr = None - - M_end_offset = 0 - M_end_offset = M_end_offset.to(tl.int64) - iterated_tiles = 0 - iterated_tiles = iterated_tiles.to(tl.int64) - for g in tl.range(G): - # Move across groups - m_size = tl.load(m_sizes + g) - - if m_size > 0: - M_start_offset = M_end_offset - M_end_offset = M_start_offset + m_size - N_start_offset = g.to(tl.int64) * N - n_size = N - - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) - num_tiles = num_m_tiles * num_n_tiles - - if USE_TMA_STORE: - # pyre-ignore - tl.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start_offset * N, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, n_size], - element_ty=c_ptr.dtype.element_ty, - ) - # pyre-ignore - 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) - - # Move across tiles - while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: - gidx = tidx - iterated_tiles - # Split M first and N second. - tile_m_idx = gidx % num_m_tiles - tile_n_idx = gidx // num_m_tiles - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - tl.static_assert(K % BLOCK_SIZE_K == 0) - if USE_TMA_LOAD: - m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - for k_offset in range(0, K, BLOCK_SIZE_K): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - else: - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = ( - a_desc_ptr - + (M_start_offset + offs_am[:, None]) * K - + offs_k[None, :] - ) - b_ptrs = ( - b_desc_ptr - + (N_start_offset + offs_bn[:, None]) * K - + offs_k[None, :] - ) - for k_offset in range(0, K, BLOCK_SIZE_K): - a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) - b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) - accumulator += tl.dot(a, b.T) - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - ) - b_scale = tl.load( - b_scale_ptr + N_start_offset + offs_bn[None, :], - mask=offs_bn[None, :] < n_size, - ) - c = accumulator.to(tl.float32) * a_scale * b_scale - - if 
USE_TMA_STORE: - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - tl._experimental_descriptor_store( - c_desc_ptr, - c.to(c_ptr.dtype.element_ty), - [m_offset, n_offset], - ) - elif FUSE_SCATTER_ADD: - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c.to(c_ptr.dtype.element_ty), - mask=mask[:, None] and offs_bn[None, :] < n_size, - sem="relaxed", - ) - else: - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, - ) - tidx += NUM_SMS - - iterated_tiles += num_tiles - - -# TODO(shikaili): Too much code duplication. Need to refactor. -@triton.autotune( - configs=_NV_WS_CONFIGS, - key=["G", "M_BUCKET", "N", "K"], - prune_configs_by={ - "early_config_prune": functools.partial( - early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 - ) - }, - restore_value=["c_ptr"], # restore for scatter_add fusion -) -@triton.jit -def _fbgemm_grouped_gemm_fp8_rowwise_ws( - a_desc_ptr, - a_scale_ptr, - b_desc_ptr, - b_scale_ptr, - b_scale_desc_ptr, - c_ptr, - workspace, - scatter_add_indices, - m_sizes, - # problem sizes - G: tl.constexpr, - M_BUCKET: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - NUM_SMS: tl.constexpr, - FUSE_SCATTER_ADD: tl.constexpr, - USE_TMA_LOAD: tl.constexpr, - USE_FAST_ACCUM: tl.constexpr, - # tile sizes - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - NUM_CONSUMER_GROUPS: tl.constexpr, - USE_TMA_LOAD_ON_SCALES: tl.constexpr, - USE_TMA_STORE: tl.constexpr, -) -> None: - tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!") - tl.static_assert( - 
not (FUSE_SCATTER_ADD and USE_TMA_STORE), - "Cannot fuse scatter add with TMA store!", - ) - - tidx = tl.program_id(0) - - dtype = TT_FP8_DTYPE - TMA_SIZE: tl.constexpr = tl.constexpr(128) - if USE_TMA_STORE: - c_desc_ptr = workspace + tidx * TMA_SIZE - else: - c_desc_ptr = None - - M_end_offset = 0 - M_end_offset = M_end_offset.to(tl.int64) - iterated_tiles = 0 - iterated_tiles = iterated_tiles.to(tl.int64) - for g in tl.range(G): - # Move across groups - m_size = tl.load(m_sizes + g, cache_modifier=".ca") - - if m_size > 0: - M_start_offset = M_end_offset - M_end_offset = M_start_offset + m_size - N_start_offset = g.to(tl.int64) * N - - num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) - tl.static_assert(N % BLOCK_SIZE_N == 0) - NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N - num_tiles = num_m_tiles * NUM_N_TILES - - if USE_TMA_STORE: - with tl.async_task([0]): - # pyre-ignore - tl.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=c_desc_ptr, - global_address=c_ptr + M_start_offset * N, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], - global_size=[m_size, N], - element_ty=c_ptr.dtype.element_ty, - ) - # pyre-ignore - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) - - # Move across tiles - next_iterated_tiles = iterated_tiles + num_tiles - if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles): - for i in range(tidx, next_iterated_tiles, NUM_SMS): - gidx = i - iterated_tiles - # Split M first and N second. 
- tile_m_idx = gidx % num_m_tiles - tile_n_idx = gidx // num_m_tiles - - accumulator = tl.zeros( - (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 - ) - tl.static_assert(K % BLOCK_SIZE_K == 0) - - m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if USE_TMA_LOAD_ON_SCALES: - with tl.async_task([0]): - b_scale = tl._experimental_descriptor_load( - b_scale_desc_ptr, - [n_offset], - [BLOCK_SIZE_N], - tl.float32, - ) - - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] - else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - b_scale = tl.load( - b_scale_ptr + N_start_offset + offs_bn[None, :], - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale - - if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - 
tl._experimental_descriptor_store( - c_desc_ptr, - c.to(c_ptr.dtype.element_ty), - [m_offset, n_offset], - ) - elif FUSE_SCATTER_ADD: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) - else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) - tidx += NUM_SMS - - iterated_tiles += num_tiles - - -warnings.simplefilter("once") - - -def _grouped_gemm( - *, - x: torch.Tensor, - w: torch.Tensor, - m_sizes: torch.Tensor, - x_scale: Optional[torch.Tensor], - w_scale: Optional[torch.Tensor], - use_fast_accum: bool, - use_warp_specialization: bool, - output_tensor: Optional[torch.Tensor], - scatter_add_indices: Optional[torch.Tensor], -) -> torch.Tensor: - - USE_TMA_LOAD = not torch.version.hip - USE_TMA_STORE = False - - if USE_TMA_LOAD and not HAS_TMA_DESC: - USE_TMA_LOAD = False - warnings.warn("TMA load is disabled as there is no TMA descriptor support!") - - if USE_TMA_STORE and not HAS_TMA_DESC: - USE_TMA_STORE = False - warnings.warn("TMA store is disabled as there is no TMA descriptor support!") - - # TODO(shikaili): Check the readniess of WS on ROCm side in Meta's Triton. 
- if use_warp_specialization and torch.version.hip: - warnings.warn("Warp specialization is disabled as it is not supported on ROCm.") - use_warp_specialization = False - - if use_warp_specialization and not _HAS_WS_SUPPORT: - warnings.warn( - "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs." - ) - use_warp_specialization = False - - if use_warp_specialization: - assert HAS_TMA_DESC - USE_TMA_STORE = True # Tuning decision - - G = m_sizes.shape[0] - - assert x.is_contiguous() - assert w.is_contiguous() - assert m_sizes.is_contiguous() - - M, K = x.shape - N = w.shape[0] // G - assert K == w.shape[1] - - if output_tensor is None: - FUSE_SCATTER_ADD = False - assert scatter_add_indices is None - y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) - else: - FUSE_SCATTER_ADD = True - assert scatter_add_indices is not None - assert scatter_add_indices.is_contiguous() - assert scatter_add_indices.shape == (M,) - y = output_tensor - if M == 0 or N == 0: - return y - - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - - desc_helper = None - desc_x = x - desc_w = w - desc_ws = w_scale - workspace = None - - if USE_TMA_LOAD: - desc_helper = TmaAutoTuneHelper() - desc_helper.init_tma_descriptor("x") - desc_helper.init_tma_descriptor("w") - desc_x = desc_helper.get_tma_descriptor_kernel_param("x") - desc_w = desc_helper.get_tma_descriptor_kernel_param("w") - if use_warp_specialization and w_scale is not None: - desc_helper.init_tma_descriptor("ws") - desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws") - - if USE_TMA_STORE: - workspace = torch.empty( - NUM_SMS * TmaAutoTuneHelper.TMA_SIZE, - device=x.device, - dtype=torch.uint8, - ) - - def grid(META): - if USE_TMA_LOAD: - nonlocal desc_helper # noqa: F824 - desc_helper.fill_2d_tma_descriptor( - 
"x", - x.data_ptr(), - M, - K, - META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], - META["BLOCK_SIZE_K"], - x.element_size(), - ) - - desc_helper.fill_2d_tma_descriptor( - "w", - w.data_ptr(), - N * G, - K, - META["BLOCK_SIZE_N"], - META["BLOCK_SIZE_K"], - w.element_size(), - ) - - if META.get("USE_TMA_LOAD_ON_SCALES", False): - desc_helper.fill_1d_tma_descriptor( - "ws", - w_scale.data_ptr(), - N * G, - META["BLOCK_SIZE_N"], - w_scale.element_size(), - ) - - return (NUM_SMS,) - - M_BUCKET_CAP = 16384 - M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP) - if x_scale is not None and w_scale is not None: - assert x_scale.is_contiguous() - assert w_scale.is_contiguous() - fn = ( - _fbgemm_grouped_gemm_fp8_rowwise_ws - if use_warp_specialization - else _fbgemm_grouped_gemm_fp8_rowwise - ) - args = ( - desc_x, - x_scale, - desc_w, - w_scale, - desc_ws, - y, - workspace, - scatter_add_indices, - m_sizes, - G, - M_BUCKET, - N, - K, - NUM_SMS, - FUSE_SCATTER_ADD, - USE_TMA_LOAD, - ) - if use_warp_specialization: - args += (use_fast_accum,) - else: - args += (USE_TMA_STORE, use_fast_accum) - fn[grid](*args) - else: - assert x_scale is None - assert w_scale is None - fn = ( - _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm - ) - args = ( - desc_x, - desc_w, - y, - workspace, - scatter_add_indices, - m_sizes, - G, - M_BUCKET, - N, - K, - NUM_SMS, - FUSE_SCATTER_ADD, - USE_TMA_LOAD, - ) - if use_warp_specialization: - args += (use_fast_accum,) - else: - args += (USE_TMA_STORE, use_fast_accum) - fn[grid](*args) - - return y - - -def grouped_gemm( - x: torch.Tensor, - w: torch.Tensor, - m_sizes: torch.Tensor, - use_fast_accum: bool = True, - *, - _use_warp_specialization: bool = True, - _output_tensor: Optional[torch.Tensor] = None, - _scatter_add_indices: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return _grouped_gemm( - x=x, - w=w, - m_sizes=m_sizes, - x_scale=None, - w_scale=None, - use_fast_accum=use_fast_accum, - 
use_warp_specialization=_use_warp_specialization, - output_tensor=_output_tensor, - scatter_add_indices=_scatter_add_indices, - ) - - -def grouped_gemm_fp8_rowwise( - x: torch.Tensor, - w: torch.Tensor, - m_sizes: torch.Tensor, - x_scale: torch.Tensor, - w_scale: torch.Tensor, - use_fast_accum: bool = True, - *, - _use_warp_specialization: bool = True, - _output_tensor: Optional[torch.Tensor] = None, - _scatter_add_indices: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return _grouped_gemm( - x=x, - w=w, - m_sizes=m_sizes, - x_scale=x_scale, - w_scale=w_scale, - use_fast_accum=use_fast_accum, - use_warp_specialization=_use_warp_specialization, - output_tensor=_output_tensor, - scatter_add_indices=_scatter_add_indices, - ) diff --git a/benchmark/fbgemm/test_grouped_gemm.py b/benchmark/fbgemm/test_grouped_gemm.py deleted file mode 100644 index 92d1b213e..000000000 --- a/benchmark/fbgemm/test_grouped_gemm.py +++ /dev/null @@ -1,323 +0,0 @@ -import os -import sys - -import pytest -import torch - -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -try: - from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm - from fbgemm_grouped_gemm import ( - grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise, - ) - - FBGEMM_AVAILABLE = True - print("✓ Successfully imported FBGEMM grouped GEMM") -except ImportError as e: - print(f"✗ Failed to import FBGEMM grouped GEMM: {e}") - FBGEMM_AVAILABLE = False - -try: - from sglang.srt.layers.moe.ep_moe.kernels import ( - grouped_gemm_triton as sglang_grouped_gemm, - ) - - SGLANG_AVAILABLE = True - print("✓ Successfully imported SGLang grouped GEMM") -except ImportError as e: - print(f"✗ Failed to import SGLang grouped GEMM: {e}") - SGLANG_AVAILABLE = False - - -def create_uniform_groups(batch_size, num_groups, device): - tokens_per_group = batch_size // num_groups - return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device) - - -def create_non_uniform_groups(batch_size, 
num_groups, device): - remaining = batch_size - m_sizes = [] - - for i in range(num_groups - 1): - if remaining <= 1: - size = 1 - else: - max_size = remaining - (num_groups - i - 1) + 1 - size = torch.randint(1, max_size, (1,)).item() - m_sizes.append(size) - remaining -= size - - m_sizes.append(remaining) - return torch.tensor(m_sizes, dtype=torch.int64, device=device) - - -def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device): - batch_size = x.shape[0] - - c_sglang = torch.empty( - batch_size, intermediate_size, dtype=torch.bfloat16, device=device - ) - - seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device) - current_pos = 0 - for i, size in enumerate(m_sizes): - current_pos += size - seg_indptr[i + 1] = current_pos - - weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device) - w_sglang = w.view(num_groups, intermediate_size, -1) - - return c_sglang, seg_indptr, weight_indices, w_sglang - - -def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device): - torch.manual_seed(42) - - x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device) - w_fp16 = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device - ) - - x_fp8 = x_fp16.to(torch.float8_e4m3fn) - w_fp8 = w_fp16.to(torch.float8_e4m3fn) - - x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4 - w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4 - - return x_fp8, w_fp8, x_scale, w_scale - - -@pytest.fixture -def device(): - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - return torch.device("cuda") - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("num_groups", [2, 4, 8]) 
-@pytest.mark.parametrize("hidden_size", [512, 1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 2048]) -def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device): - if batch_size % num_groups != 0: - pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}") - - torch.manual_seed(42) - - m_sizes = create_uniform_groups(batch_size, num_groups, device) - - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - w = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device - ) - - result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) - - c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( - x, w, m_sizes, num_groups, intermediate_size, device - ) - - result_sglang = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") -@pytest.mark.parametrize("batch_size", [63, 100, 127]) -@pytest.mark.parametrize("num_groups", [3, 5, 7]) -@pytest.mark.parametrize("hidden_size", [512, 1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 2048]) -def test_non_uniform_groups( - batch_size, num_groups, hidden_size, intermediate_size, device -): - torch.manual_seed(42) - - m_sizes = create_non_uniform_groups(batch_size, num_groups, device) - - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - w = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device - ) - - result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) - - c_sglang, seg_indptr, weight_indices, w_sglang = 
create_sglang_inputs( - x, w, m_sizes, num_groups, intermediate_size, device - ) - - result_sglang = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") -@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)]) -@pytest.mark.parametrize("hidden_size", [768, 2048, 4096]) -@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192]) -def test_large_dimensions( - batch_size, num_groups, hidden_size, intermediate_size, device -): - torch.manual_seed(42) - - m_sizes = create_uniform_groups(batch_size, num_groups, device) - - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - w = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device - ) - - result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) - - c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( - x, w, m_sizes, num_groups, intermediate_size, device - ) - - result_sglang = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -@pytest.mark.parametrize("batch_size", [32, 64]) -@pytest.mark.parametrize("num_groups", [2, 4]) -@pytest.mark.parametrize("hidden_size", [512, 1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 2048]) -def test_fp8_uniform_groups( - batch_size, num_groups, hidden_size, intermediate_size, device -): - if batch_size % 
num_groups != 0: - pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}") - - torch.manual_seed(42) - - m_sizes = create_uniform_groups(batch_size, num_groups, device) - x_fp8, w_fp8, x_scale, w_scale = create_fp8_data( - batch_size, num_groups, hidden_size, intermediate_size, device - ) - - try: - result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( - x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True - ) - assert result_fp8.shape == (batch_size, intermediate_size) - assert result_fp8.dtype == torch.bfloat16 - except Exception as e: - pytest.skip(f"FP8 test failed (possibly unsupported): {e}") - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -@pytest.mark.parametrize("batch_size", [63, 100]) -@pytest.mark.parametrize("num_groups", [3, 5]) -@pytest.mark.parametrize("hidden_size", [512, 1024]) -@pytest.mark.parametrize("intermediate_size", [1024, 2048]) -def test_fp8_non_uniform_groups( - batch_size, num_groups, hidden_size, intermediate_size, device -): - torch.manual_seed(42) - - m_sizes = create_non_uniform_groups(batch_size, num_groups, device) - x_fp8, w_fp8, x_scale, w_scale = create_fp8_data( - batch_size, num_groups, hidden_size, intermediate_size, device - ) - - try: - result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( - x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True - ) - assert result_fp8.shape == (batch_size, intermediate_size) - assert result_fp8.dtype == torch.bfloat16 - except Exception as e: - pytest.skip(f"FP8 test failed (possibly unsupported): {e}") - - -@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") -def test_fbgemm_only_uniform(device): - torch.manual_seed(42) - - batch_size, num_groups = 64, 4 - hidden_size, intermediate_size = 512, 1024 - - m_sizes = create_uniform_groups(batch_size, num_groups, device) - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - w = torch.randn( - num_groups * intermediate_size, hidden_size, 
dtype=torch.bfloat16, device=device - ) - - result = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) - - assert result.shape == (batch_size, intermediate_size) - assert result.dtype == torch.bfloat16 - - -@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") -def test_sglang_only_uniform(device): - torch.manual_seed(42) - - batch_size, num_groups = 64, 4 - hidden_size, intermediate_size = 512, 1024 - - m_sizes = create_uniform_groups(batch_size, num_groups, device) - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - w = torch.randn( - num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device - ) - - c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( - x, w, m_sizes, num_groups, intermediate_size, device - ) - - result = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - assert result.shape == (batch_size, intermediate_size) - assert result.dtype == torch.bfloat16 - - -def test_imports(): - assert ( - FBGEMM_AVAILABLE or SGLANG_AVAILABLE - ), "Neither FBGEMM nor SGLang is available" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"])