# Copy from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import functools
import inspect
import sys
import warnings
from typing import Optional

import torch
import triton  # @manual
import triton.language as tl  # @manual
from triton.runtime import driver  # @manual


def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
    """
    Maps torch dtype to triton dtype.

    Args:
        dtype (torch.dtype): input dtype.

    Returns:
        tl.dtype: triton dtype.
    """
    if dtype == torch.float16:
        return tl.float16
    elif dtype == torch.bfloat16:
        return tl.bfloat16
    elif dtype == torch.float32:
        return tl.float32
    elif dtype == torch.int32:
        return tl.int32
    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
        return tl.float8e4nv
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


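# Example (illustrative only): map_dtype_to_triton(torch.bfloat16) returns
# tl.bfloat16, which can be passed as the dtype argument of descriptor loads
# such as tl._experimental_descriptor_load in the kernels below. Any dtype not
# handled by the branches above raises ValueError.
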
# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
    print(
        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
        file=sys.stderr,
    )
else:
    print(
        "TMA benchmarks will be running without grid constant TMA descriptor.",
        file=sys.stderr,
    )


class TmaAutoTuneHelper:

    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
    class KernelParamWrapper:
        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()

    TMA_SIZE = 128

    def __init__(self):
        self.fill_1d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
        )
        self.fill_2d_tma_descriptor_inner = (
            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
        )
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    # Call this method outside of the lambda function for grid size
    def init_tma_descriptor(self, name):
        if HAS_TMA_DESC:
            self.descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
            )
        else:
            self.cuda_descriptors[name] = torch.empty(
                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
            )

    # Call this method inside the lambda function for grid size
    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_1d_tma_descriptor_inner(
                ptr, dim, block_dim, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    # Call this method inside the lambda function for grid size
    def fill_2d_tma_descriptor(
        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
    ):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
            )
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            self.fill_2d_tma_descriptor_inner(
                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
            )
            desc_x.copy_(buf_x, non_blocking=True)

    def get_tma_descriptor_kernel_param(self, name):
        if HAS_TMA_DESC:
            assert self.descriptors[name] is not None
            return self.KernelParamWrapper(self.descriptors[name])
        else:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]


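# Typical use of TmaAutoTuneHelper (an illustrative sketch mirroring how
# `_grouped_gemm` below drives it; the tensor and size names here are assumptions
# for the example): descriptors are allocated once outside the autotuner's grid
# lambda and filled inside it, because the lambda is re-evaluated for every
# autotuning config.
#
#   helper = TmaAutoTuneHelper()
#   helper.init_tma_descriptor("x")                  # allocate once, outside grid()
#
#   def grid(META):
#       helper.fill_2d_tma_descriptor(               # fill per config, inside grid()
#           "x", x.data_ptr(), M, K,
#           META["BLOCK_SIZE_M"], META["BLOCK_SIZE_K"], x.element_size(),
#       )
#       return (NUM_SMS,)
#
#   desc_x = helper.get_tma_descriptor_kernel_param("x")  # pass this to the kernel
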
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [64, 128]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]

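# Size of the search space above (simple arithmetic, for orientation):
# 2 (BLOCK_M) * 3 (BLOCK_N) * 3 (BLOCK_K) * 2 (stages) * 2 (warps) * 1 (ctas)
# = 72 candidate configs, which `early_config_prune` below then filters per
# problem shape before autotuning.
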
_HAS_WS_SUPPORT = None


def _check_ws_support():
    if not hasattr(tl, "async_task"):
        return False
    config_signature = inspect.signature(triton.Config).parameters
    if (
        "num_consumer_groups" not in config_signature
        or "num_buffers_warp_spec" not in config_signature
    ):
        return False
    if not HAS_TMA_DESC:
        return False
    return True


def _set_ws_support():
    global _HAS_WS_SUPPORT
    if _HAS_WS_SUPPORT is None:
        _HAS_WS_SUPPORT = _check_ws_support()


_set_ws_support()

if _HAS_WS_SUPPORT:
    _NV_WS_CONFIGS = [
        triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
                "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
                "USE_TMA_STORE": use_tma_store,
            },
            num_stages=num_stages,
            num_warps=num_warps,
            num_ctas=num_ctas,
            num_consumer_groups=num_consumer_groups,
            num_buffers_warp_spec=num_stages,
        )
        for block_size_m in [64, 128, 256]
        for block_size_n in [64, 128, 256]
        for block_size_k in [64, 128, 256]
        for num_stages in [2, 3, 4]
        for num_warps in [4, 8, 16]
        # TODO(shikaili): Resolve LLVM error.
        for num_ctas in [1]
        for num_consumer_groups in [0, 2]
        for use_tma_load_on_scales in [True, False]
        # TODO(shikaili): Resolve compatibility with ws.
        for use_tma_store in [False]
    ]
else:
    _NV_WS_CONFIGS = _NV_CONFIGS


_AMD_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
            "waves_per_eu": waves_per_cu,
            "matrix_instr_nonkdim": matrix_instr_nonkdim,
            "NUM_CONSUMER_GROUPS": 1,
        },
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for block_size_m in [32, 64, 128]
    for block_size_n in [32, 64, 128, 256]
    for block_size_k in [128, 256]
    for num_stages in [1, 2]
    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
    for matrix_instr_nonkdim in [16]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    device = torch.cuda.current_device()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    if dtsize is None:
        dtsize = named_args["c_ptr"].element_size()
    if dtype is None:
        dtype = named_args["c_ptr"].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        (
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_stages,
            num_warps,
            num_consumer_groups,
            use_tma_load_on_scales,
        ) = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
            config.num_warps,
            config.num_consumer_groups,
            kw.get("USE_TMA_LOAD_ON_SCALES", False),
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]
        if torch.version.hip:
            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
        else:
            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        use_warp_specialization = num_consumer_groups >= 1

        M_PER_GROUP = M // G
        MIN_M_TILES = 32 if torch.version.hip else 64
        # 2. make sure we don't load M tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_M > MIN_M_TILES
            and BLOCK_M > (M_PER_GROUP * 2)
        ):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 32 if torch.version.hip else 64
        # 4. make sure we don't load N tiles that are too big
        if (
            not use_warp_specialization
            and BLOCK_N > MIN_N_TILES
            and M * N_TILES < num_sm
        ):
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        # 7. make sure we can partition for ws
        if use_warp_specialization:
            if num_warps != 4:
                continue

            # "tritongpu-warp-spec-data-partition"
            m_slice = BLOCK_M // num_consumer_groups
            n_slice = BLOCK_N // num_consumer_groups
            if m_slice < 64 and n_slice < 256:
                continue

            if dtsize >= 2:
                if use_tma_load_on_scales:
                    continue
        pruned_configs.append(config)

    return pruned_configs


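# Worked example for pruning rule 1 above (numbers are illustrative, not from the
# source): on NVIDIA, a config with BLOCK_SIZE_M = BLOCK_SIZE_N = BLOCK_SIZE_K = 128,
# num_stages = 4 and a 2-byte output dtype needs roughly
# (128 + 128) * 128 * 4 * 2 = 262144 bytes of shared memory, which exceeds the
# per-SM "max_shared_mem" reported by the driver on current GPUs (on the order of
# 100-230 KiB depending on architecture), so that config is dropped before tuning.
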
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                if USE_TMA_STORE:
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                elif FUSE_SCATTER_ADD:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c,
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    c = accumulator.to(c_ptr.dtype.element_ty)
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles


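# Scheduling note for the persistent kernel above (and its _ws/fp8 variants below);
# this is an illustrative walk-through, not extra logic: the grid is launched with
# exactly NUM_SMS programs, and each program claims global tile index `tidx`,
# striding by NUM_SMS after finishing a tile. For example, with NUM_SMS = 4 and two
# groups producing 3 and 5 tiles (8 tiles total), program 1 computes tile 1
# (group 0: iterated_tiles = 0, gidx = 1) and tile 5 (group 1: iterated_tiles = 3,
# gidx = 2), where each gidx is split as tile_m_idx = gidx % num_m_tiles and
# tile_n_idx = gidx // num_m_tiles.
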
# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_ws(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
    tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype: tl.dtype = c_ptr.dtype.element_ty
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                with tl.async_task([0]):
                    # pyre-ignore
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=c_ptr + M_start_offset * N,
                        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                        global_size=[m_size, N],
                        element_ty=c_ptr.dtype.element_ty,
                    )
                    # pyre-ignore
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        with tl.async_task([0]):
                            a = tl._experimental_descriptor_load(
                                a_desc_ptr,
                                [m_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                dtype,
                            )
                            b = tl._experimental_descriptor_load(
                                b_desc_ptr,
                                [n_offset, k_offset],
                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
                                dtype,
                            )
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            if USE_FAST_ACCUM:
                                accumulator = tl.dot(a, b.T, accumulator)
                            else:
                                accumulator += tl.dot(a, b.T)

                    if USE_TMA_STORE:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                            tl._experimental_descriptor_store(
                                c_desc_ptr,
                                accumulator.to(c_ptr.dtype.element_ty),
                                [m_offset, n_offset],
                            )
                    elif FUSE_SCATTER_ADD:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            mask = offs_am < m_size
                            m_offsets = tl.load(
                                scatter_add_indices + M_start_offset + offs_am,
                                mask=mask,
                                cache_modifier=".ca",
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            c = accumulator.to(c_ptr.dtype.element_ty)
                            tl.atomic_add(
                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                                c,
                                mask=mask[:, None],
                                sem="relaxed",
                            )
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            c = accumulator.to(c_ptr.dtype.element_ty)
                            tl.store(
                                c_ptr
                                + (M_start_offset + offs_am[:, None]) * N
                                + offs_bn[None, :],
                                c,
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".cs",
                            )
                    tidx += NUM_SMS

            iterated_tiles += num_tiles


TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv


# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise(
    a_desc_ptr,
    a_scale_ptr,
    b_desc_ptr,
    b_scale_ptr,
    b_scale_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g)

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N
            n_size = N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            num_tiles = num_m_tiles * num_n_tiles

            if USE_TMA_STORE:
                # pyre-ignore
                tl.extra.cuda.experimental_device_tensormap_create2d(
                    desc_ptr=c_desc_ptr,
                    global_address=c_ptr + M_start_offset * N,
                    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                    global_size=[m_size, n_size],
                    element_ty=c_ptr.dtype.element_ty,
                )
                # pyre-ignore
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                tl.static_assert(K % BLOCK_SIZE_K == 0)
                if USE_TMA_LOAD:
                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl._experimental_descriptor_load(
                            a_desc_ptr,
                            [m_offset, k_offset],
                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
                            dtype,
                        )
                        b = tl._experimental_descriptor_load(
                            b_desc_ptr,
                            [n_offset, k_offset],
                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
                            dtype,
                        )
                        if USE_FAST_ACCUM:
                            accumulator = tl.dot(a, b.T, accumulator)
                        else:
                            accumulator += tl.dot(a, b.T)
                else:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    offs_k = tl.arange(0, BLOCK_SIZE_K)
                    a_ptrs = (
                        a_desc_ptr
                        + (M_start_offset + offs_am[:, None]) * K
                        + offs_k[None, :]
                    )
                    b_ptrs = (
                        b_desc_ptr
                        + (N_start_offset + offs_bn[:, None]) * K
                        + offs_k[None, :]
                    )
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
                        accumulator += tl.dot(a, b.T)
                        a_ptrs += BLOCK_SIZE_K
                        b_ptrs += BLOCK_SIZE_K

                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                a_scale = tl.load(
                    a_scale_ptr + M_start_offset + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                )
                b_scale = tl.load(
                    b_scale_ptr + N_start_offset + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                )
                c = accumulator.to(tl.float32) * a_scale * b_scale

                if USE_TMA_STORE:
                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        c.to(c_ptr.dtype.element_ty),
                        [m_offset, n_offset],
                    )
                elif FUSE_SCATTER_ADD:
                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                    mask = offs_am < m_size
                    m_offsets = tl.load(
                        scatter_add_indices + M_start_offset + offs_am,
                        mask=mask,
                        cache_modifier=".ca",
                    )
                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                    tl.atomic_add(
                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                        c.to(c_ptr.dtype.element_ty),
                        mask=mask[:, None] and offs_bn[None, :] < n_size,
                        sem="relaxed",
                    )
                else:
                    tl.store(
                        c_ptr
                        + (M_start_offset + offs_am[:, None]) * N
                        + offs_bn[None, :],
                        c,
                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
                    )
                tidx += NUM_SMS

            iterated_tiles += num_tiles


# TODO(shikaili): Too much code duplication. Need to refactor.
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
def _fbgemm_grouped_gemm_fp8_rowwise_ws(
    a_desc_ptr,
    a_scale_ptr,
    b_desc_ptr,
    b_scale_ptr,
    b_scale_desc_ptr,
    c_ptr,
    workspace,
    scatter_add_indices,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    FUSE_SCATTER_ADD: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_TMA_LOAD_ON_SCALES: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
) -> None:
    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
    tl.static_assert(
        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
        "Cannot fuse scatter add with TMA store!",
    )

    tidx = tl.program_id(0)

    dtype = TT_FP8_DTYPE
    TMA_SIZE: tl.constexpr = tl.constexpr(128)
    if USE_TMA_STORE:
        c_desc_ptr = workspace + tidx * TMA_SIZE
    else:
        c_desc_ptr = None

    M_end_offset = 0
    M_end_offset = M_end_offset.to(tl.int64)
    iterated_tiles = 0
    iterated_tiles = iterated_tiles.to(tl.int64)
    for g in tl.range(G):
        # Move across groups
        m_size = tl.load(m_sizes + g, cache_modifier=".ca")

        if m_size > 0:
            M_start_offset = M_end_offset
            M_end_offset = M_start_offset + m_size
            N_start_offset = g.to(tl.int64) * N

            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            tl.static_assert(N % BLOCK_SIZE_N == 0)
            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
            num_tiles = num_m_tiles * NUM_N_TILES

            if USE_TMA_STORE:
                with tl.async_task([0]):
                    # pyre-ignore
                    tl.extra.cuda.experimental_device_tensormap_create2d(
                        desc_ptr=c_desc_ptr,
                        global_address=c_ptr + M_start_offset * N,
                        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                        global_size=[m_size, N],
                        element_ty=c_ptr.dtype.element_ty,
                    )
                    # pyre-ignore
                    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Move across tiles
            next_iterated_tiles = iterated_tiles + num_tiles
            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
                for i in range(tidx, next_iterated_tiles, NUM_SMS):
                    gidx = i - iterated_tiles
                    # Split M first and N second.
                    tile_m_idx = gidx % num_m_tiles
                    tile_n_idx = gidx // num_m_tiles

                    accumulator = tl.zeros(
                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
                    )
                    tl.static_assert(K % BLOCK_SIZE_K == 0)

                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                    for k_offset in range(0, K, BLOCK_SIZE_K):
                        with tl.async_task([0]):
                            a = tl._experimental_descriptor_load(
                                a_desc_ptr,
                                [m_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                dtype,
                            )
                            b = tl._experimental_descriptor_load(
                                b_desc_ptr,
                                [n_offset, k_offset],
                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
                                dtype,
                            )
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            if USE_FAST_ACCUM:
                                accumulator = tl.dot(a, b.T, accumulator)
                            else:
                                accumulator += tl.dot(a, b.T)

                    if USE_TMA_LOAD_ON_SCALES:
                        with tl.async_task([0]):
                            b_scale = tl._experimental_descriptor_load(
                                b_scale_desc_ptr,
                                [n_offset],
                                [BLOCK_SIZE_N],
                                tl.float32,
                            )

                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            a_scale = tl.load(
                                a_scale_ptr + M_start_offset + offs_am[:, None],
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".ca",
                            )
                            c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            a_scale = tl.load(
                                a_scale_ptr + M_start_offset + offs_am[:, None],
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".ca",
                            )
                            b_scale = tl.load(
                                b_scale_ptr + N_start_offset + offs_bn[None, :],
                                cache_modifier=".ca",
                            )
                            c = accumulator.to(tl.float32) * a_scale * b_scale

                    if USE_TMA_STORE:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                            tl._experimental_descriptor_store(
                                c_desc_ptr,
                                c.to(c_ptr.dtype.element_ty),
                                [m_offset, n_offset],
                            )
                    elif FUSE_SCATTER_ADD:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            mask = offs_am < m_size
                            m_offsets = tl.load(
                                scatter_add_indices + M_start_offset + offs_am,
                                mask=mask,
                                cache_modifier=".ca",
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            tl.atomic_add(
                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
                                c,
                                mask=mask[:, None],
                                sem="relaxed",
                            )
                    else:
                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
                                0, BLOCK_SIZE_M
                            )
                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
                                0, BLOCK_SIZE_N
                            )
                            tl.store(
                                c_ptr
                                + (M_start_offset + offs_am[:, None]) * N
                                + offs_bn[None, :],
                                c,
                                mask=offs_am[:, None] < m_size,
                                cache_modifier=".cs",
                            )
                    tidx += NUM_SMS

            iterated_tiles += num_tiles


warnings.simplefilter("once")


def _grouped_gemm(
    *,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: Optional[torch.Tensor],
    w_scale: Optional[torch.Tensor],
    use_fast_accum: bool,
    use_warp_specialization: bool,
    output_tensor: Optional[torch.Tensor],
    scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:

    USE_TMA_LOAD = not torch.version.hip
    USE_TMA_STORE = False

    if USE_TMA_LOAD and not HAS_TMA_DESC:
        USE_TMA_LOAD = False
        warnings.warn("TMA load is disabled as there is no TMA descriptor support!")

    if USE_TMA_STORE and not HAS_TMA_DESC:
        USE_TMA_STORE = False
        warnings.warn("TMA store is disabled as there is no TMA descriptor support!")

    # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
    if use_warp_specialization and torch.version.hip:
        warnings.warn("Warp specialization is disabled as it is not supported on ROCm.")
        use_warp_specialization = False

    if use_warp_specialization and not _HAS_WS_SUPPORT:
        warnings.warn(
            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
        )
        use_warp_specialization = False

    if use_warp_specialization:
        assert HAS_TMA_DESC
        USE_TMA_STORE = True  # Tuning decision

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1]

    if output_tensor is None:
        FUSE_SCATTER_ADD = False
        assert scatter_add_indices is None
        y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    else:
        FUSE_SCATTER_ADD = True
        assert scatter_add_indices is not None
        assert scatter_add_indices.is_contiguous()
        assert scatter_add_indices.shape == (M,)
        y = output_tensor
    if M == 0 or N == 0:
        return y

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    desc_helper = None
    desc_x = x
    desc_w = w
    desc_ws = w_scale
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaAutoTuneHelper()
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")
        if use_warp_specialization and w_scale is not None:
            desc_helper.init_tma_descriptor("ws")
            desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws")

    if USE_TMA_STORE:
        workspace = torch.empty(
            NUM_SMS * TmaAutoTuneHelper.TMA_SIZE,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        if USE_TMA_LOAD:
            nonlocal desc_helper  # noqa: F824
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M,
                K,
                META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N * G,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )

            if META.get("USE_TMA_LOAD_ON_SCALES", False):
                desc_helper.fill_1d_tma_descriptor(
                    "ws",
                    w_scale.data_ptr(),
                    N * G,
                    META["BLOCK_SIZE_N"],
                    w_scale.element_size(),
                )

        return (NUM_SMS,)

    M_BUCKET_CAP = 16384
    M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
    if x_scale is not None and w_scale is not None:
        assert x_scale.is_contiguous()
        assert w_scale.is_contiguous()
        fn = (
            _fbgemm_grouped_gemm_fp8_rowwise_ws
            if use_warp_specialization
            else _fbgemm_grouped_gemm_fp8_rowwise
        )
        args = (
            desc_x,
            x_scale,
            desc_w,
            w_scale,
            desc_ws,
            y,
            workspace,
            scatter_add_indices,
            m_sizes,
            G,
            M_BUCKET,
            N,
            K,
            NUM_SMS,
            FUSE_SCATTER_ADD,
            USE_TMA_LOAD,
        )
        if use_warp_specialization:
            args += (use_fast_accum,)
        else:
            args += (USE_TMA_STORE, use_fast_accum)
        fn[grid](*args)
    else:
        assert x_scale is None
        assert w_scale is None
        fn = (
            _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm
        )
        args = (
            desc_x,
            desc_w,
            y,
            workspace,
            scatter_add_indices,
            m_sizes,
            G,
            M_BUCKET,
            N,
            K,
            NUM_SMS,
            FUSE_SCATTER_ADD,
            USE_TMA_LOAD,
        )
        if use_warp_specialization:
            args += (use_fast_accum,)
        else:
            args += (USE_TMA_STORE, use_fast_accum)
        fn[grid](*args)

    return y


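# Note on the launches above (explanatory comment, no behavioral change): the
# warp-specialized kernels receive USE_TMA_STORE and USE_TMA_LOAD_ON_SCALES from
# their autotune configs (_NV_WS_CONFIGS), so only `use_fast_accum` is appended to
# `args`; the non-WS kernels take USE_TMA_STORE as an explicit positional argument,
# hence the two-element tail (USE_TMA_STORE, use_fast_accum).
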
def grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=None,
        w_scale=None,
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
    )


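# Expected layout for grouped_gemm (a hedged summary derived from the asserts in
# _grouped_gemm above; the concrete numbers are illustrative only):
#   x:       (M, K)      activations for all groups, stacked along dim 0
#   w:       (G * N, K)  per-group weights, stacked along dim 0
#   m_sizes: (G,)        rows of x owned by each group; must sum to M
# The result is an (M, N) bfloat16 tensor unless `_output_tensor` is supplied for
# the fused scatter-add path.
#
#   G, M, N, K = 4, 512, 256, 128
#   x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
#   w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)
#   m_sizes = torch.full((G,), M // G, device="cuda", dtype=torch.int32)
#   y = grouped_gemm(x, w, m_sizes)  # -> (512, 256)
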
def grouped_gemm_fp8_rowwise(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    use_fast_accum: bool = True,
    *,
    _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _grouped_gemm(
        x=x,
        w=w,
        m_sizes=m_sizes,
        x_scale=x_scale,
        w_scale=w_scale,
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
        scatter_add_indices=_scatter_add_indices,
    )
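

# Minimal smoke test (an illustrative sketch, not part of the upstream FBGEMM file).
# It assumes a CUDA device and a Triton build compatible with the kernels above, and
# checks the bf16 path against a per-group torch.matmul reference with loose bf16
# tolerances. For the fp8 rowwise path, the kernels index x_scale per row of x
# (shape (M,)) and w_scale per row of w (shape (G * N,)).
if __name__ == "__main__":
    if torch.cuda.is_available():
        G, M, N, K = 2, 64, 256, 256
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
        w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)
        m_sizes = torch.tensor([M // 2, M // 2], device="cuda", dtype=torch.int32)

        y = grouped_gemm(x, w, m_sizes)

        # Reference: rows [start, start + m) of x hit group g's (N, K) weight slice.
        ref = torch.empty_like(y)
        start = 0
        for g, m in enumerate(m_sizes.tolist()):
            ref[start : start + m] = x[start : start + m] @ w[g * N : (g + 1) * N].T
            start += m
        torch.testing.assert_close(y, ref, atol=3e-1, rtol=3e-2)
        print("grouped_gemm smoke test passed")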