Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,45 @@
# Machete (Mixed Precision Cutlass-Based GEMM)
Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
## Overview
Machete effectively performs
```python
scale_type = w_s.dtype
compute_type = a.dtype
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
```
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
`w_z` is the quantization zeropoints.
> **_NOTE:_** `w_z` is subtracted after the scales are applied so we can
use FMA operations, but this means the zeropoints must have the scales
pre-applied if the supplied zeropoints assume that they will be subtracted
before the scales are applied.
## API
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
```python
from vllm import _custom_ops as ops
...
W_q_packed = ops.machete_prepack_B(w_q, wtype)
output = ops.machete_gemm(
a,
b_q=W_q_packed,
b_type=wtype,
b_scales=w_s,
b_group_size=group_size
)
```
## Code Generation
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.

View File

@@ -0,0 +1,694 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
import jinja2
from vllm_cutlass_library_extension import (
DataType,
EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType,
VLLMDataType,
VLLMDataTypeNames,
VLLMDataTypeSize,
VLLMDataTypeTag,
VLLMDataTypeTorchDataTypeTag,
VLLMDataTypeVLLMScalarTypeTag,
VLLMKernelScheduleTag,
)
#
# Generator templating
#
# Jinja2 template for the generated "machete_mm_dispatch" source.
#
# For every ImplConfig it emits:
#   * `extern` declarations for each per-schedule `impl_<type_sig>_sch_<sig>`
#     function (defined by IMPL_TEMPLATE), and
#   * `mm_dispatch_<type_sig>`, which walks `impl_config.heuristic` in order
#     when `args.maybe_schedule` is unset (a `None` condition renders as the
#     final `else`), and otherwise matches the schedule name string.
# It then emits `mm_dispatch`, which selects a `mm_dispatch_<type_sig>` from
# the runtime a/b/out/scale/zero-point types, and
# `supported_schedules_dispatch`, which returns the schedule names for a type
# combination.
# NOTE(review): `supported_schedules_dispatch` only matches on b/a/out and the
# group scale / group zero-point types, unlike `mm_dispatch` which also checks
# channel and token scale types - confirm this is intended.
#
# This is a regular (non-raw) string, so `\\n` / `\\t` below end up as the
# two-character sequences `\n` / `\t` inside the generated C++ string literals.
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}
torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
if (!args.maybe_schedule) {
{%- for cond, s in impl_config.heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
}
{%- for s in impl_config.schedules %}
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
{%- endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.maybe_schedule);
}
{%- endfor %}
static inline std::optional<at::ScalarType> maybe_scalartype(
std::optional<at::Tensor> const& t) {
if (!t) {
return std::nullopt;
} else {
return t->scalar_type();
};
}
torch::Tensor mm_dispatch(MMArgs args) {
auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
auto a_type = args.A.scalar_type();
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set type_sig = gen_type_sig(t) -%}
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
&& a_type == {{TorchTypeTag[t.a]}}
&& out_type == {{TorchTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
{%- else %}!maybe_g_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void -%}
maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
{%- else %}!maybe_g_zeros_type{%endif%}
&& {%if t.b_channel_scale != void -%}
maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
{%- else %}!maybe_ch_scales_type{%endif%}
&& {%if t.a_token_scale != void -%}
maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
{%- else %}!maybe_tok_scales_type{%endif%}
) {
return mm_dispatch_{{type_sig}}(args);
}
{%- endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(
false, "machete_mm(..) is not implemented for "
"a_type=", args.A.scalar_type(),
", b_type=", args.b_type.str(),
", out_type=", out_type,
", with_group_scale_type=", maybe_g_scales_type
? toString(*maybe_g_scales_type) : "None",
", with_group_zeropoint_type=", maybe_g_zeros_type
? toString(*maybe_g_zeros_type) : "None",
", with_channel_scale_type=", maybe_ch_scales_type
? toString(*maybe_ch_scales_type) : "None",
", with_token_scale_type=", maybe_tok_scales_type
? toString(*maybe_tok_scales_type) : "None",
"; implemented types are: \\n",
{%- for impl_config in impl_configs %}
{% set t = impl_config.types -%}
"\\t{{gen_type_option_name(t)}}\\n",
{%- endfor %}
"");
}
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args) {
auto out_type = args.maybe_out_type.value_or(args.a_type);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
&& args.a_type == {{TorchTypeTag[t.a]}}
&& out_type == {{TorchTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
{%- else %}!args.maybe_group_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void-%}
args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
{%- else %}!args.maybe_group_zeros_type{%endif%}
) {
return {
{%- for s in impl_config.schedules %}
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
{%- endfor %}
};
}
{%- endfor %}
return {};
};
}; // namespace machete
"""
# Jinja2 template for one generated "machete_mm_impl_part<N>" source.
#
# It emits a `sch_<sig>` struct for every unique schedule among the given
# ImplConfigs, a `Kernel_<type_sig>` alias of `MacheteKernelTemplate` for every
# ImplConfig, and one `impl_<type_sig>_sch_<sig>` function per (type, schedule)
# pair that forwards to `run_impl`.
# The kernel schedule is hard-coded to `KernelTmaWarpSpecializedCooperative` in
# the alias; the per-schedule `KernelSchedule` alias is rendered inside a C++
# comment (see the TODO in the template body).
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for sch in unique_schedules(impl_configs) %}
{% set sch_sig = gen_sch_sig(sch) -%}
struct sch_{{sch_sig}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
{% endfor %}
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
{% set type_sig = gen_type_sig(t) -%}
template<typename Sch>
using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[t.b]}}, // ElementB
{{DataTypeTag[t.out]}}, // ElementD
{{DataTypeTag[t.accumulator]}}, // Accumulator
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
{% set sch_sig = gen_sch_sig(sch) -%}
torch::Tensor
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
}
{%- endfor %}
{%- endfor %}
}; // namespace machete
"""
# Jinja2 template for the generated "machete_prepack" source.
#
# It emits `prepack_B_dispatch`, which picks a `PrepackedLayoutBTemplate`
# instantiation for each PrepackTypeConfig in `types` by matching the runtime
# `a_type`, the bit width of `b_type`, and the conversion type (the group-scale
# type when provided, otherwise `a_type`). B is always packed using the
# unsigned type of the matching bit width (`unsigned_type_with_bitwidth`).
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
{%- for t in types %}
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
if (args.a_type == {{TorchTypeTag[t.a]}}
&& args.b_type.size_bits() == {{t.b_num_bits}}
&& convert_type == {{TorchTypeTag[t.convert]}}) {
return prepack_impl<
PrepackedLayoutBTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[b_type]}}, // ElementB
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false,
"prepack_B_dispatch(..) is not implemented for "
"atype = ", args.a_type,
", b_type = ", args.b_type.str(),
", with_group_scales_type= ", args.maybe_group_scales_type ?
toString(*args.maybe_group_scales_type) : "None");
}
}; // namespace machete
"""
# Short aliases for the kernel (mainloop) and epilogue schedule enum members
# that generate() passes to every ScheduleConfig it builds.
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass(frozen=True)
class ScheduleConfig:
    """One kernel schedule: tile/cluster shapes plus scheduling policies.

    Frozen (and therefore hashable) because unique_schedules() uses instances
    as dict keys to de-duplicate them.
    """

    # Tile extents rendered into `TileShapeNM` by IMPL_TEMPLATE.
    tile_shape_mn: tuple[int, int]
    # Cluster extents rendered into `ClusterShape` by IMPL_TEMPLATE.
    cluster_shape_mnk: tuple[int, int, int]
    # Part of the schedule signature; IMPL_TEMPLATE currently only renders it
    # inside a C++ comment.
    kernel_schedule: MixedInputKernelScheduleType
    # Rendered into `EpilogueSchedule` by IMPL_TEMPLATE.
    epilogue_schedule: EpilogueScheduleType
    # Rendered into `TileScheduler` by IMPL_TEMPLATE.
    tile_scheduler: TileSchedulerType
@dataclass(frozen=True)
class TypeConfig:
    """Element types of one generated kernel family.

    The optional scale / zero-point slots use `DataType.void` to mean "not
    present"; the dispatch template then requires the corresponding runtime
    tensor to be absent. Field order matters: generate_type_signature() and
    generate_type_option_name() iterate the fields in declaration order.
    """

    # ElementA of the kernel template.
    a: DataType
    # ElementB of the kernel template (the type compared against `args.b_type`).
    b: DataType | VLLMDataType
    # Group scale type, or DataType.void when there are no group scales.
    b_group_scale: DataType
    # Group zero-point type, or DataType.void when there are none.
    b_group_zeropoint: DataType
    # Channel scale type, or DataType.void when there are none.
    b_channel_scale: DataType
    # Token scale type, or DataType.void when there are none.
    a_token_scale: DataType
    # ElementD (output) of the kernel template.
    out: DataType
    # Accumulator type of the kernel template.
    accumulator: DataType
@dataclass(frozen=True)
class PrepackTypeConfig:
    """Type parameters for one `PrepackedLayoutBTemplate` instantiation
    emitted by PREPACK_TEMPLATE."""

    # ElementA of the prepacked layout.
    a: DataType
    # Bit width of the B element; mapped to an unsigned type when rendering.
    b_num_bits: int
    # ElementConvert: create_sources() uses the group-scale type, or `a` when
    # the group-scale type is void.
    convert: DataType
    # Accumulator type of the prepacked layout.
    accumulator: DataType
@dataclass
class ImplConfig:
    """A TypeConfig together with the schedules to instantiate for it.

    Deliberately not frozen: create_sources() reassigns `schedules` on deep
    copies while distributing implementations across files.
    """

    # Element types of this kernel family.
    types: TypeConfig
    # Every schedule to generate an implementation for.
    schedules: list[ScheduleConfig]
    # Ordered (C++ condition, schedule) pairs used when the caller gives no
    # schedule; a `None` condition renders as the final `else` branch.
    heuristic: list[tuple[str | None, ScheduleConfig]]
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Build the full signature string identifying a schedule.

    The result has the form
    ``<M>x<N>_<Cm>x<Cn>x<Ck>_<kernel>_<epilogue>_<tile scheduler>`` where the
    last three parts are the final ``::``-separated component of the
    corresponding tag strings.
    """
    shape_mn = schedule_config.tile_shape_mn
    cluster_mnk = schedule_config.cluster_shape_mnk

    parts = [
        f"{shape_mn[0]}x{shape_mn[1]}",
        f"{cluster_mnk[0]}x{cluster_mnk[1]}x{cluster_mnk[2]}",
    ]

    # Look the tags up in the same order as they appear in the signature.
    tag_lookups = (
        (VLLMKernelScheduleTag, schedule_config.kernel_schedule),
        (EpilogueScheduleTag, schedule_config.epilogue_schedule),
        (TileSchedulerTag, schedule_config.tile_scheduler),
    )
    for tag_table, key in tag_lookups:
        # Keep only the unqualified name, e.g. "a::b::Name" -> "Name".
        parts.append(tag_table[key].split("::")[-1])

    return "_".join(parts)
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Return a shorter, mostly unique variant of generate_sch_sig().

    The abbreviations are applied one after another in the order listed, so an
    earlier replacement can affect what a later pattern matches.
    """
    abbreviations = (
        ("KernelTmaWarpSpecializedCooperative", "TmaMI_"),
        ("TmaWarpSpecializedCooperative_", "TmaCoop_"),
        ("StreamKScheduler", "streamK"),
    )
    terse_sig = generate_sch_sig(schedule_config)
    for long_name, short_name in abbreviations:
        terse_sig = terse_sig.replace(long_name, short_name)
    return terse_sig
def generate_type_signature(kernel_types: TypeConfig):
    """Return a unique name for a TypeConfig.

    The name is the concatenation (no separator) of the VLLMDataTypeNames entry
    of every TypeConfig field, in field declaration order.
    """
    type_names = (
        VLLMDataTypeNames[getattr(kernel_types, type_field.name)]
        for type_field in fields(TypeConfig)
    )
    return "".join(type_names)
def generate_type_option_name(kernel_types: TypeConfig):
    """Describe a TypeConfig as ``key=type`` pairs for the dispatch error text.

    The keys follow the naming used by the "not implemented for ..." part of
    the message in DISPATCH_TEMPLATE: plain fields become ``<field>_type``
    (``a_type``, ``b_type``, ``out_type``, ...) and the optional
    scale/zero-point fields become ``with_<what>_type``. Both the ``a_`` and
    ``b_`` prefixes are normalised, so ``a_token_scale`` is reported as
    ``with_token_scale_type`` just like the query side of the message.
    """

    def _option_key(field_name: str) -> str:
        # "b_group_scale" -> "with_group_scale", "a_token_scale" ->
        # "with_token_scale"; the bare "a" / "b" fields are left untouched.
        if field_name.startswith(("a_", "b_")):
            field_name = "with_" + field_name[2:]
        return field_name + "_type"

    return ", ".join(
        [
            f"{_option_key(type_field.name)}="
            + VLLMDataTypeNames[getattr(kernel_types, type_field.name)]
            for type_field in fields(TypeConfig)
        ]
    )
def is_power_of_two(n):
    """Return True when ``n`` is non-zero and has exactly one bit set."""
    if n == 0:
        return False
    # Clearing the lowest set bit leaves zero only for a single-bit value.
    return (n & (n - 1)) == 0
def to_cute_constant(value: list[int]):
    """Render integer(s) as CuTe compile-time constant spellings.

    A power of two becomes ``_<n>``; anything else becomes ``Int<<n>>``.
    An iterable input yields a list of rendered strings, a non-iterable input
    yields a single string.
    """

    def _render(dim: int):
        return f"_{dim}" if is_power_of_two(dim) else f"Int<{dim}>"

    if not isinstance(value, Iterable):
        return _render(value)
    return [_render(dim) for dim in value]
def unique_schedules(impl_configs: list[ImplConfig]):
    """Return every distinct schedule used by ``impl_configs``.

    Schedules are listed in order of first appearance; a dict (rather than a
    set) is used so that the ordering is deterministic.
    """
    first_seen: dict = {}
    for impl_config in impl_configs:
        for schedule in impl_config.schedules:
            first_seen.setdefault(schedule, None)
    return list(first_seen)
def unsigned_type_with_bitwidth(num_bits):
    """Map a bit width (4, 8, 16, 32 or 64) to the unsigned DataType of that
    width. Any other width raises KeyError."""
    unsigned_by_width = {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }
    return unsigned_by_width[num_bits]
# Names made available inside every Jinja2 template created by
# create_template(): the `void` sentinel used to test for absent optional
# types, the tag lookup tables that turn enum members into C++ spellings, and
# the helper functions the templates call.
template_globals = {
    "void": DataType.void,
    "DataTypeTag": VLLMDataTypeTag,
    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    # The templates use the terse schedule signature, not the full one.
    "gen_sch_sig": generate_terse_sch_sig,
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name,
}
def create_template(template_str):
    """Compile ``template_str`` into a Jinja2 template that can see every
    entry of ``template_globals``."""
    compiled = jinja2.Template(template_str)
    # Register the shared helpers / lookup tables on this template.
    compiled.globals.update(template_globals)
    return compiled
# Compiled once at import time; create_sources() renders these.
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def _split_impls_across_files(
    impl_configs: list[ImplConfig], num_impl_files: int
) -> list[list[ImplConfig]]:
    """Distribute (type, schedule) implementations over at most
    ``num_impl_files`` groups of roughly equal size.

    An ImplConfig whose schedules do not fit in the current group is split:
    a copy carrying the schedules that fit goes into the current group and the
    remaining schedules carry over to the next one. The input list and its
    elements are not modified.
    """
    num_impls = sum(len(impl_config.schedules) for impl_config in impl_configs)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: list[list[ImplConfig]] = [[]]
    num_impls_assigned = 0
    num_impls_in_file = 0
    # Reversed so that the next config to place is always at the end and can
    # be consumed with pop(); deep-copied because `schedules` is reassigned.
    pending = deepcopy(list(reversed(impl_configs)))

    while num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - num_impls_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            num_impls_in_file = 0

        curr_ic = pending[-1]
        if len(curr_ic.schedules) > room_left_in_file:
            # Break apart the current impl config: place what fits, keep the
            # rest pending for the next file.
            head = deepcopy(curr_ic)
            head.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(head)
        else:
            # Fits entirely (including the exact-fit case), so no empty
            # remainder is left behind to be emitted into the next file.
            files_impls[-1].append(curr_ic)
            pending.pop()

        num_placed = len(files_impls[-1][-1].schedules)
        num_impls_assigned += num_placed
        num_impls_in_file += num_placed

    return files_impls


def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
    """Render all generated sources for ``impl_configs``.

    Args:
        impl_configs: the kernel families (types + schedules + heuristic) to
            generate code for.
        num_impl_files: upper bound on the number of ``machete_mm_impl_part<N>``
            sources the implementations are spread over; must be >= 1.

    Returns:
        A list of ``(name, code)`` tuples: the dispatch source, the prepack
        source, then one entry per implementation part.

    Raises:
        ValueError: if ``num_impl_files`` is less than 1.
    """
    if num_impl_files < 1:
        raise ValueError(f"num_impl_files must be >= 1, got {num_impl_files}")

    sources = [
        (
            "machete_mm_dispatch",
            mm_dispatch_template.render(impl_configs=impl_configs),
        )
    ]

    # For now, we can just use the first accumulator type seen since
    # the tensor core shapes/layouts don't vary based on accumulator
    # type so we can generate less code this way. Hence the key below
    # deliberately leaves the accumulator out.
    unique_prepack_types: dict[tuple, PrepackTypeConfig] = {}
    for impl_config in impl_configs:
        types = impl_config.types
        # Without group scales the weights are converted to the activation
        # type; otherwise to the group-scale type.
        convert_type = (
            types.a if types.b_group_scale == DataType.void else types.b_group_scale
        )
        prepack_type = PrepackTypeConfig(
            a=types.a,
            b_num_bits=VLLMDataTypeSize[types.b],
            convert=convert_type,
            accumulator=types.accumulator,
        )
        key = (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
        unique_prepack_types.setdefault(key, prepack_type)

    sources.append(
        (
            "machete_prepack",
            prepack_dispatch_template.render(
                types=list(unique_prepack_types.values()),
            ),
        )
    )

    # Split up impls across files
    files_impls = _split_impls_across_files(impl_configs, num_impl_files)
    for part, file_impls in enumerate(files_impls):
        sources.append(
            (
                f"machete_mm_impl_part{part + 1}",
                mm_impl_template.render(impl_configs=file_impls),
            )
        )

    return sources
def generate():
    """Build the list of ImplConfigs and write the generated `.cu` sources.

    The `generated` directory next to this script is deleted if it exists,
    recreated, and filled with one `<name>.cu` file per entry returned by
    create_sources().
    """
    # See the "Code Generation" section of csrc/quantization/machete/Readme.md
    # for more info about how this works
    SCRIPT_DIR = os.path.dirname(__file__)

    # Schedule settings shared by every ScheduleConfig built below.
    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # Insertion order matters: the conditions are rendered in this order and
    # each branch returns, so the first matching condition wins. The `None`
    # key is rendered as the final `else`.
    # NOTE(review): the `4069` thresholds in the "M = 65-128" group look like
    # a typo for 4096 (K = N = 4096 currently falls through to "M > 64");
    # confirm against the tuning data before changing them.
    default_tile_heuristic_config = {
        #### M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        (cond, ScheduleConfig(*tile_config, **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
    ]

    def get_unique_schedules(
        heuristic: list[tuple[str | None, ScheduleConfig]],
    ) -> list[ScheduleConfig]:
        """Return the distinct schedules of `heuristic` in first-seen order."""
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel file
        # will be non-deterministic and causes ccache miss.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules

    impl_configs = []

    # GPTQ-style configs: group scales of the activation type, no zero-points.
    GPTQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        )
        for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
        for a in (DataType.f16, DataType.bf16)
    )

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(
            GPTQ_kernel_type_configs,
            itertools.repeat(get_unique_schedules(default_heuristic)),
            itertools.repeat(default_heuristic),
        )
    ]

    # AWQ-style configs: group scales and group zero-points of the activation
    # type.
    AWQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=a,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        )
        for b in (DataType.u4, DataType.u8)
        for a in (DataType.f16, DataType.bf16)
    )

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(
            AWQ_kernel_type_configs,
            itertools.repeat(get_unique_schedules(default_heuristic)),
            itertools.repeat(default_heuristic),
        )
    ]

    # TODO: Support W4A8 when ready
    # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # # TODO (LucasWilkinson): Further tuning required
    # qqq_tile_heuristic_config = {
    # #### M = 257+
    # # ((128, 256), (2, 1, 1)) Broken for QQQ types
    # # TODO (LucasWilkinson): Investigate further
    # # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
    # # "M > 256": ((128, 256), (2, 1, 1)),
    # "M > 256": ((128, 128), (2, 1, 1)),
    # #### M = 129-256
    # "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
    # "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
    # # ((128, 256), (2, 1, 1)) Broken for QQQ types
    # # TODO (LucasWilkinson): Investigate further
    # # "M > 128": ((128, 256), (2, 1, 1)),
    # "M > 128": ((128, 128), (2, 1, 1)),
    # #### M = 65-128
    # "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
    # "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
    # "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
    # "M > 64": ((128, 128), (2, 1, 1)),
    # #### M = 33-64
    # "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
    # # Broken for QQQ types
    # # TODO (LucasWilkinson): Investigate further
    # #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
    # "M > 32": ((128, 64), (2, 1, 1)),
    # #### M = 17-32
    # "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
    # "M > 16": ((256, 32), (2, 1, 1)),
    # #### M = 1-16
    # "N >= 26624": ((256, 16), (1, 1, 1)),
    # None: ((128, 16), (1, 1, 1)),
    # }
    # # For now we use the same heuristic for all types
    # # Heuristic is currently tuned for H100s
    # qqq_heuristic = [
    # (cond, ScheduleConfig(*tile_config,
    # **sch_common_params)) # type: ignore
    # for cond, tile_config in qqq_tile_heuristic_config.items()
    # ]
    # QQQ_kernel_types = [
    # *(TypeConfig(
    # a=DataType.s8,
    # b=VLLMDataType.u4b8,
    # b_group_scale=b_group_scale,
    # b_group_zeropoint=DataType.void,
    # b_channel_scale=DataType.f32,
    # a_token_scale=DataType.f32,
    # out=DataType.f16,
    # accumulator=DataType.s32,
    # ) for b_group_scale in (DataType.f16, DataType.void)),
    # *(TypeConfig(
    # a=DataType.e4m3,
    # b=VLLMDataType.u4b8,
    # b_group_scale=b_group_scale,
    # b_group_zeropoint=DataType.void,
    # b_channel_scale=DataType.f32,
    # a_token_scale=DataType.f32,
    # out=DataType.f16,
    # accumulator=DataType.f32,
    # ) for b_group_scale in (DataType.f16, DataType.void)),
    # ]
    # impl_configs += [
    # ImplConfig(x[0], x[1], x[2])
    # for x in zip(QQQ_kernel_types,
    # itertools.repeat(get_unique_schedules(qqq_heuristic)),
    # itertools.repeat(qqq_heuristic))
    # ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")

    # Delete the "generated" directory if it exists
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    # Create the "generated" directory
    os.makedirs(output_dir)

    # Render each group of configurations into separate files
    for filename, code in create_sources(impl_configs):
        filepath = os.path.join(output_dir, f"{filename}.cu")
        with open(filepath, "w") as output_file:
            output_file.write(code)
        print(f"Rendered template to {filepath}")


if __name__ == "__main__":
    generate()

View File

@@ -0,0 +1,31 @@
#pragma once
#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;

// Tag type used as the first template argument of VLLMCollectiveBuilder to
// select the Machete mainloop specialization below.
struct MacheteKernelTag {};

// Partial specialization of VLLMCollectiveBuilder for MacheteKernelTag on
// Sm90 tensor-op kernels. It only participates when KernelScheduleType is one
// of KernelTmaWarpSpecialized, KernelTmaWarpSpecializedPingpong or
// KernelTmaWarpSpecializedCooperative (enforced by the trailing enable_if_t).
// NOTE(review): the primary VLLMCollectiveBuilder template is presumably
// declared in cutlass_extensions/vllm_collective_builder.cuh - confirm.
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
          class ElementPairB_, class GmemLayoutB_, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
    GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative>)>> {
  // The collective mainloop type: all builder parameters are forwarded
  // unchanged to machete::MacheteCollectiveMma.
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
      AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

};  // namespace cutlass::gemm::collective

View File

@@ -0,0 +1,35 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {

using namespace cute;

// Returns the layout of one interleaved block of `blk_bit_width` bits holding
// elements of type T, where consecutive elements are `bit_stride` bits apart.
// When an element is exactly `bit_stride` bits wide the block is a plain
// contiguous run; otherwise it is a 2-mode layout of shape
// (strides per block, elements per stride) with strides
// (elements per stride, 1).
// examples:
//  size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
//  size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
//  size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
//  size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
  constexpr auto elem_bits = cute::sizeof_bits_v<T>;

  // The block must hold a whole number of strides, and each stride a whole
  // number of elements.
  static_assert(blk_bit_width % bit_stride == 0);
  static_assert(bit_stride % elem_bits == 0);

  constexpr auto blk_elems = blk_bit_width / elem_bits;

  if constexpr (elem_bits != bit_stride) {
    constexpr auto stride_elems = bit_stride / elem_bits;
    constexpr auto stride_count = blk_elems / stride_elems;
    return Layout<Shape<Int<stride_count>, Int<stride_elems>>,
                  Stride<Int<stride_elems>, Int<1>>>{};
  } else {
    // identity layout
    return Layout<Shape<Int<blk_elems>>>{};
  }
}

};  // namespace machete

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,309 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/torch_utils.hpp"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, we want to upconvert/decompress the quantized operand in
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
static constexpr bool with_group_zeropoints =
!std::is_same_v<GroupZeroT, void>;
static constexpr bool with_channel_scales =
!std::is_same_v<ChannelScaleT, void>;
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementAccumulator = AccumulatorT;
using ElementCompute = AccumulatorT; // For Epilogue
// Use dummy values when we don't have scales or zeropoints
using ElementZGroup =
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
using ElementSGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementConvertGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementSChannel =
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
using ElementSToken =
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
using BTypeTuple = cute::conditional_t<
with_group_scales,
cute::conditional_t<with_group_zeropoints,
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
cute::tuple<ElementB, ElementSGroup>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZGroup = StrideSGroup;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
static_assert(
(!with_channel_scales && !with_token_scales) ||
((with_channel_scales && with_token_scales) &&
std::is_same_v<ElementSChannel, ElementSToken>),
"Currently token and channel scales (if present) must be the same type");
// Currently only supports float scales
using ChTokScalesEpilogue =
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
TileShape>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
"Currently token and channel scales (if present) must be float "
"(and if one is present the other must be too)");
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using EVTCompute =
std::conditional_t<with_channel_scales || with_token_scales,
typename ChTokScalesEpilogue::EVTCompute,
StoreEpilogueCompute>;
// EVTCompute
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
// Builds the CUTLASS `Arguments` for one GEMM launch of this kernel.
//
// The kernel is launched on the transposed problem: the problem shape handed
// to CUTLASS is {N, M, K, 1}, the prepacked B is the first mainloop operand
// and A the second, and D is written through a transposed stride.
//
// Parameters:
//   stream            - unused here; kept for interface compatibility.
//   A                 - MxK activation matrix.
//   B                 - KxN prepacked weight matrix.
//   D                 - MxN output matrix (must already be allocated).
//   maybe_g_scales    - group scales, (ceil(K / group_size) x N); required iff
//                       `with_group_scales`.
//   maybe_g_zeros     - group zero-points, same layout as the group scales;
//                       required iff `with_group_zeropoints`.
//   maybe_group_size  - group size along K; `-1` or nullopt means one group
//                       spanning all of K.
//   maybe_ch_scales   - per-channel scales (N elements or 1 element).
//   maybe_tok_scales  - per-token scales (M elements or 1 element).
//
// Raises (via TORCH_CHECK) if shapes are inconsistent, if a tensor is supplied
// for a feature this instantiation was not compiled with, if a required
// tensor is missing, or if the resolved group size is not positive.
static Arguments create_arguments(
    cudaStream_t stream,
    torch::Tensor const& A,  // MxK matrix
    torch::Tensor const& B,  // KxN prepacked matrix
    torch::Tensor& D,        // MxN matrix
    std::optional<torch::Tensor> const& maybe_g_scales,  // scale_KxN matrix
    std::optional<torch::Tensor> const& maybe_g_zeros,   // scale_KxN matrix
    std::optional<int64_t> maybe_group_size,
    std::optional<torch::Tensor> const& maybe_ch_scales,   // len N vector
    std::optional<torch::Tensor> const& maybe_tok_scales)  // len M vector
{
  static_assert(!with_group_zeropoints || with_group_scales);

  int M = A.size(0), N = B.size(1), K = A.size(1);
  TORCH_CHECK(D.size(0) == M && D.size(1) == N);

  auto layout_A = make_cute_layout<StrideA>(A, "A");
  auto layout_D = make_cute_layout<StrideD>(D, "D");
  auto layout_S_group =
      maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
  auto layout_Z_group =
      maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");

  // Raw data pointer of an optional tensor, or nullptr when it is absent.
  auto unwrap = [](auto const& t) {
    return t ? t->const_data_ptr() : nullptr;
  };
  auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
  auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
  auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
  auto S_group_ptr =
      static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
  auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));

  // -1 and "not provided" both mean a single group covering the whole K dim.
  int const group_size =
      maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
  // Guard the ceil-division below against division by zero / negative sizes.
  TORCH_CHECK(group_size > 0,
              "group_size must be positive once -1/None has been resolved to "
              "K, got ",
              group_size);
  int const scale_k = (K + group_size - 1) / group_size;

  TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
  TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);

  if constexpr (with_group_scales) {
    TORCH_CHECK(S_group_ptr && layout_S_group);
    TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
                 size<1>(*layout_S_group) == N));
  } else {
    TORCH_CHECK(!S_group_ptr, "Scales not supported");
  }

  if constexpr (with_group_zeropoints) {
    TORCH_CHECK(Z_group_ptr && layout_Z_group);
    TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
                 size<1>(*layout_Z_group) == N));
    TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
                "Scales and zeros must have the same layout");
  } else {
    TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
  }

  if constexpr (with_channel_scales || with_token_scales) {
    // Both optionals are dereferenced below (and again when preparing the
    // epilogue arguments), so make sure they actually hold tensors first.
    TORCH_CHECK(maybe_ch_scales.has_value() && maybe_tok_scales.has_value(),
                "Channel and token scales must both be provided");
    TORCH_CHECK(
        (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
        (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
  }

  // Transpose A and D
  // A doesn't need to be transposed since cutlass expects a NxK matrix
  //  for B (which is At)
  auto stride_At = layout_A.stride();
  auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();

  MainloopArguments mainloop_arguments{};
  // {Accum, C, C_layout, D, D}
  EpilogueArguments epilogue_arguments{};

  if constexpr (with_channel_scales || with_token_scales) {
    epilogue_arguments =
        EpilogueArguments{ChTokScalesEpilogue::prepare_args(
                              *maybe_ch_scales, *maybe_tok_scales),
                          nullptr,
                          {},
                          D_ptr,
                          stride_Dt};
  } else {
    epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
  }

  if constexpr (with_group_scales && with_group_zeropoints) {
    auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
    mainloop_arguments = MainloopArguments{
        B_ptr,       _StrideB{},     A_ptr,      stride_At,
        S_group_ptr, stride_S_group, group_size, Z_group_ptr};
  } else if constexpr (with_group_scales) {
    auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
    mainloop_arguments =
        MainloopArguments{B_ptr,       _StrideB{},     A_ptr,     stride_At,
                          S_group_ptr, stride_S_group, group_size};
  } else {
    mainloop_arguments =
        MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
  }

  // Problem shape is {N, M, K, L=1} because the transposed GEMM is computed.
  return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                   {N, M, K, 1},
                   mainloop_arguments,
                   epilogue_arguments};
}
// Size of the scratch workspace the underlying Gemm requires for `args`.
static size_t get_workspace_size(Arguments const& args) {
  size_t const workspace_size = Gemm::get_workspace_size(args);
  return workspace_size;
}
// True when the underlying Gemm reports that it can run with `args`.
static bool can_implement(Arguments const& args) {
  cutlass::Status const verdict = Gemm::can_implement(args);
  return verdict == cutlass::Status::kSuccess;
}
// Initializes a Gemm instance with `args` / `workspace` and runs it on
// `stream`. Raises via TORCH_CHECK if either step does not report kSuccess.
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
  Gemm op;

  auto const init_status = op.initialize(args, workspace, stream);
  TORCH_CHECK(init_status == cutlass::Status::kSuccess,
              "Machete kernel failed to initialize workspace");

  auto const run_status = op.run(stream);
  TORCH_CHECK(run_status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete

View File

@@ -0,0 +1,75 @@
#pragma once
#include <torch/all.h>
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete {
// Argument bundle for a Machete matmul, consumed by `mm_dispatch` and
// `run_impl`.
//
// Most members are references: the struct does not own the tensors/optionals
// it points at, so the referenced objects must outlive every use of an MMArgs
// instance. Field order matters to callers that use designated initializers.
struct MMArgs {
// Activation matrix (forwarded to the kernel as `A`).
torch::Tensor const& A;
// Weight matrix (forwarded to the kernel as `B`; presumably already
// prepacked - confirm against callers).
torch::Tensor const& B;
// Scalar type of the elements stored in B.
vllm::ScalarType const& b_type;
// Requested output dtype, if any.
std::optional<at::ScalarType> const& maybe_out_type;
// Optional per-group scales / zero-points and the group size they refer to.
std::optional<torch::Tensor> const& maybe_group_scales;
std::optional<torch::Tensor> const& maybe_group_zeros;
std::optional<int64_t> maybe_group_size;
// Optional per-channel / per-token scales.
std::optional<torch::Tensor> const& maybe_channel_scales;
std::optional<torch::Tensor> const& maybe_token_scales;
// Optional explicit schedule name (held by value).
std::optional<std::string> maybe_schedule;
};
// Type-only description of a matmul configuration, passed to
// `supported_schedules_dispatch`. All members are held by value; field order
// matters to callers that use designated initializers.
struct SupportedSchedulesArgs {
// dtype of the A operand.
at::ScalarType a_type;
// Scalar type of the B operand.
vllm::ScalarType b_type;
// dtypes of the optional scale / zero-point tensors (nullopt when absent).
std::optional<at::ScalarType> maybe_group_scales_type;
std::optional<at::ScalarType> maybe_group_zeros_type;
std::optional<at::ScalarType> maybe_channel_scales_type;
std::optional<at::ScalarType> maybe_token_scales_type;
// Requested output dtype, if any.
std::optional<at::ScalarType> maybe_out_type;
};
// Dispatch entry points. They are only declared in this header; the
// definitions live elsewhere (presumably in the sources produced by
// generate.py - confirm).
//
// Runs the matmul described by `args` and returns the result tensor.
torch::Tensor mm_dispatch(MMArgs args);
// Returns a list of schedule names for the type combination in `args`.
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args);
// Runs one concrete `MacheteKernel` instantiation for `args`.
//
// Steps: make A's device current, allocate the (M x N) output with the
// kernel's ElementD dtype, build the kernel arguments, verify the kernel can
// implement them, allocate a byte workspace of the size the kernel asks for,
// and launch on the current CUDA stream of A's device.
//
// Returns the freshly allocated output tensor D.
// Raises (via TORCH_CHECK) if the kernel cannot be run with these arguments,
// or if argument validation / kernel launch fails inside MacheteKernel.
template <typename MacheteKernel>
torch::Tensor run_impl(MMArgs args) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));

  auto device = args.A.device();
  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  int M = args.A.size(0);
  int N = args.B.size(1);

  // Allocate output
  torch::Tensor D = torch::empty(
      {M, N},
      torch::TensorOptions()
          .dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
          .device(device));

  auto arguments = MacheteKernel::create_arguments(
      stream,  //
      args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
      args.maybe_group_size, args.maybe_channel_scales,
      args.maybe_token_scales);
  TORCH_CHECK(MacheteKernel::can_implement(arguments),
              "Machete kernel cannot be run with these arguments");

  // The workspace is a flat byte buffer; make the size_t -> int64_t shape
  // conversion explicit instead of relying on an implicit one.
  size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(
      {static_cast<int64_t>(workspace_size)},
      torch::TensorOptions().dtype(torch::kU8).device(device));

  MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);

  return D;
}
}; // namespace machete

View File

@@ -0,0 +1,76 @@
#pragma once
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
// Kernel that rewrites one prepacked block of B per CUDA block.
//
// Template parameters:
//   threads           - number of threads per CUDA block; must divide the
//                       number of elements in a prepacked block.
//   PrepackedLayoutB  - provides PPBlockShape_NK and the layout functions
//                       describing the prepacked storage order.
// Parameters:
//   B_in       - tensor view of the input B; tiled here by (PPBlockShape_NK, 1).
//   B_out_ptr  - start of the output buffer.
//
// The CUDA block index (x, y, z) is used as the prepacked-block coordinate.
// Each thread copies `block_size / threads` elements.
template <int threads, typename PrepackedLayoutB, typename BInTensor,
typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
// Number of elements in one prepacked block and the share handled per thread.
auto constexpr block_size =
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
auto constexpr eles_per_thread = Int<block_size / threads>{};
static_assert(block_size % threads == 0,
"block_size must be divisible by the number of threads");
// Which pre-packed are we responsible for
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
auto tB_in = local_tile(
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
blk_coord);
// Find the start offset in the output for this pre-packed block
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
// Tensor representing a 1:1 mapping to the output space in 1D
auto tB_out_linear =
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
make_layout(make_shape(block_size)));
// Mapping from output space (1D) to input space
// (input-tile layout composed with the right inverse of the block's
// (N,K) -> storage-offset layout, then flattened to 1D)
auto tB_in_linear = make_tensor(
tB_in.data(),
tB_in.layout()
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
.with_shape(make_shape(block_size)));
// Tile for this specific thread (could have used a TiledCopy but these work
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
// we are also not that concerned with performance for this kernel)
auto thr_tB_in_linear =
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
auto thr_tB_out_linear =
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
// Stage through the fragment: input -> fragment -> output.
copy(thr_tB_in_linear, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
}
// Launches `prepack_B_kernel` over every prepacked block of B.
//
// Parameters:
//   stream     - CUDA stream the kernel is launched on.
//   B_in_ptr   - input B data.
//   B_layout   - (N, K, L) layout describing `B_in_ptr`; N and K must be
//                multiples of the prepacked block shape.
//   B_out_ptr  - output buffer that receives the prepacked data.
//
// The grid is (N_tiles, K_tiles, L_tiles): one CUDA block per prepacked block.
// Raises (via TORCH_CHECK) if N or K is not a multiple of the block shape.
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B_template(
    cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
    InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
  using TileShapeNKL =
      decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));

  // Threads per CUDA block; used for both the kernel's template argument and
  // the launch configuration so the two can never disagree.
  constexpr int kThreadsPerBlock = 128;

  TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
  TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);

  auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
  auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
  auto L_tiles = size<2>(B_layout);

  auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);

  prepack_B_kernel<kThreadsPerBlock, PrepackedLayoutB>
      <<<dim3(N_tiles, K_tiles, L_tiles), kThreadsPerBlock, 0, stream>>>(
          B_in, B_out_ptr);
}
}; // namespace machete

View File

@@ -0,0 +1,74 @@
#pragma once
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete {
// Argument bundle for `prepack_B_dispatch`. `B` is held by reference, so the
// tensor must outlive this struct; the remaining members are held by value.
// Field order matters to callers that use designated initializers.
struct PrepackBArgs {
// Weight tensor to prepack.
torch::Tensor const& B;
// dtype of the A operand B will later be multiplied with.
at::ScalarType a_type;
// Scalar type of the elements stored in B.
vllm::ScalarType b_type;
// dtype of the group scales, if group scales will be used.
std::optional<at::ScalarType> maybe_group_scales_type;
};
// Prepacks `B` into the storage order described by `PrepackedLayoutB`.
//
// `B` is expected to be a (packed_K, N) tensor whose storage items may each
// hold several ElementB values. The function derives an (N, K, L) layout over
// the unpacked elements, allocates a contiguous tensor with B's shape and
// dtype, and launches the prepack kernel on the current CUDA stream of B's
// device.
//
// Returns the newly allocated prepacked tensor.
// Raises (via TORCH_CHECK) if the unpacked K or N extent is not a multiple of
// the prepacked block shape, or if the transposed tensor does not have unit
// stride along packed_K.
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
auto device = B.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
// elements per storage item for B
auto eles_per_storage =
(B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
// torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
auto Bt_packed = B.t();
TORCH_CHECK(
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
// clang-format off
// Scale only the K extent (mode 1) of the shape, and every stride except the
// K stride, by eles_per_storage.
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
// Dispatch entry point for prepacking; only declared here, defined elsewhere
// (presumably in the sources produced by generate.py - confirm).
torch::Tensor prepack_B_dispatch(PrepackBArgs args);
}; // namespace machete

View File

@@ -0,0 +1,253 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// Tag type: the default for PrepackedLayoutBTemplate's IlvBlkLayout_ parameter,
// meaning "let the template choose the interleaved block layout itself".
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
//
// Naming used by the member functions below: "TV" is a (thread id, value)
// coordinate, "NK" is a coordinate inside one prepacked block, "bN/bK" index
// blocks along N/K, and "L" is the trailing third mode of the shape.
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementAccumulator = AccumulatorT;
using ElementMma = MmaType;
// Interleave 4-bit (or narrower) types when we are not upconverting to fp8 or
// int8. In those cases we use a LUT via prmt instructions to upconvert, which
// is more efficient if the data is not interleaved. For 8-bit+ types, prmt
// instructions make non-interleaved layouts efficient enough that we don't
// need interleaved layouts (and can reuse more of the existing cutlass
// converts).
static constexpr bool should_interleave =
sizeof_bits_v<ElementB> <= 4 &&
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
!std::is_same_v<ElementConvert_, int8_t>;
// Only use interleaved layouts for subbyte weights,
// `void` means "no interleaving"; an explicitly supplied IlvBlkLayout_
// (anything other than IlvBlkLayoutAuto) is used as-is.
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<
should_interleave,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. the amount of data associated with each thread within a
// prepacked block is a multiple of 128bits. When using a cooperative schedule
// we have 256 threads working on a single block at a time, this means each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
// for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel,
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B) must be the lhs operand so they flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
// Return the non-interleaved layout: an ordered layout over the TV shape in
// which the value mode is ordered before the thread mode.
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(IlvdFrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
// No interleaving requested: same layout as ppblock_TV_to_offset().
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, E, G, F, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % size(ilvdBlk) == 0,
"FrgV must be divisible by size(ilvdBlk)");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
// Return interleaved layout
// (thread mode unchanged; FrgV replaced by ilvd_FrgV, rest modes kept)
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (N,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (N,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
// Uses the non-interleaved per-block layout; each block occupies
// size(block_layout) consecutive storage slots. Note the parameter is named
// `shape_mkl` but is used as an (N, K, L) shape.
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// Variant of TVbNbKL_to_offset whose first mode is replaced by a compact
// (256, size<0>(layout) / 256) layout; the block and L modes are reused.
// NOTE(review): the previous header comment described the codomain as
// (N, K, L); the result is assembled from the storage-offset layout above -
// confirm the intended description.
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
Shape_NKL shape_mkl) {
auto layout = TVbNbKL_to_offset(shape_mkl);
// for 4-bit elements, having >= 64 values per column
// allows TMA to load full 32-byte sectors
auto inner_layout =
make_layout(make_shape(_256{}, size<0>(layout) / _256{}));
return make_layout(inner_layout, get<1>(layout), get<2>(layout));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
// Same construction as TVbNbKL_to_offset but with the (possibly interleaved)
// (N,K) -> offset per-block layout.
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((BlockN, BlockK), (BlocksN, BlocksK, L))
// => ((BlockN, BlockK), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// (BlocksN, BlocksK, L) -> (storage_idx)
// Offset of the first element of each prepacked block; consecutive blocks are
// size(PPBlockShape_NK) elements apart.
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto stride = size(PPBlockShape_NK{});
// (BlocksN, BlocksK, L) -> (storage_idx)
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
// Re-index the within-block mode by (athrid, val); block mode untouched.
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete

View File

@@ -0,0 +1,73 @@
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace machete {
using namespace vllm;
// Op entry point: lists the schedules available for a given combination of
// operand / scale / output types. `b_type_id` is decoded into a ScalarType via
// ScalarType::from_id before dispatching.
std::vector<std::string> supported_schedules(
    at::ScalarType a_type, int64_t b_type_id,
    std::optional<at::ScalarType> maybe_group_scales_type,
    std::optional<at::ScalarType> maybe_group_zeros_type,
    std::optional<at::ScalarType> maybe_channel_scales_type,
    std::optional<at::ScalarType> maybe_token_scales_type,
    std::optional<at::ScalarType> maybe_out_type) {
  ScalarType const weight_type = ScalarType::from_id(b_type_id);

  SupportedSchedulesArgs const query{
      .a_type = a_type,
      .b_type = weight_type,
      .maybe_group_scales_type = maybe_group_scales_type,
      .maybe_group_zeros_type = maybe_group_zeros_type,
      .maybe_channel_scales_type = maybe_channel_scales_type,
      .maybe_token_scales_type = maybe_token_scales_type,
      .maybe_out_type = maybe_out_type,
  };
  return supported_schedules_dispatch(query);
}
// Op entry point for the Machete matmul: decodes `b_type_id` into a ScalarType
// and forwards every argument, unchanged, to mm_dispatch.
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
                 int64_t b_type_id,
                 std::optional<at::ScalarType> const& maybe_out_type,
                 std::optional<torch::Tensor> const& maybe_group_scales,
                 std::optional<torch::Tensor> const& maybe_group_zeros,
                 std::optional<int64_t> maybe_group_size,
                 std::optional<torch::Tensor> const& maybe_channel_scales,
                 std::optional<torch::Tensor> const& maybe_token_scales,
                 std::optional<std::string> maybe_schedule) {
  // MMArgs stores b_type by reference, so the decoded type must be a named
  // local that stays alive for the duration of the dispatch call.
  ScalarType const weight_type = ScalarType::from_id(b_type_id);

  return mm_dispatch(MMArgs{
      .A = A,
      .B = B,
      .b_type = weight_type,
      .maybe_out_type = maybe_out_type,
      .maybe_group_scales = maybe_group_scales,
      .maybe_group_zeros = maybe_group_zeros,
      .maybe_group_size = maybe_group_size,
      .maybe_channel_scales = maybe_channel_scales,
      .maybe_token_scales = maybe_token_scales,
      .maybe_schedule = maybe_schedule,
  });
}
// Op entry point for weight prepacking: decodes `b_type_id` into a ScalarType
// and forwards to prepack_B_dispatch.
torch::Tensor prepack_B(
    torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
    std::optional<at::ScalarType> const& maybe_group_scales_type) {
  ScalarType const weight_type = ScalarType::from_id(b_type_id);

  PrepackBArgs const prepack_args{
      .B = B,
      .a_type = a_type,
      .b_type = weight_type,
      .maybe_group_scales_type = maybe_group_scales_type,
  };
  return prepack_B_dispatch(prepack_args);
}
// Register the tensor-taking ops under the CUDA dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_mm", &mm);
}
// use CatchAll since supported_schedules has no tensor arguments
// (so there is no tensor from which a device-specific dispatch key could be
// taken)
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("machete_supported_schedules", &supported_schedules);
}
}; // namespace machete