Sync from v0.13
This commit is contained in:
45
csrc/quantization/machete/Readme.md
Normal file
45
csrc/quantization/machete/Readme.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Machete (Mixed Precision Cutlass-Based GEMM)
|
||||
|
||||
Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
|
||||
|
||||
## Overview
|
||||
|
||||
Machete effectively performs
|
||||
|
||||
```python
|
||||
scale_type = w_s.dtype
|
||||
compute_type = a.dtype
|
||||
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
|
||||
```
|
||||
|
||||
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
|
||||
`w_z` is the quantization zeropoints.
|
||||
|
||||
> **_NOTE:_** `w_z` is added after the scales so we can
use FMA operations; this means the zeropoints must have the scales
pre-applied if the supplied zeropoints assume they will be subtracted
before the scales are applied.
|
||||
|
||||
## API
|
||||
|
||||
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
|
||||
|
||||
```python
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
...
|
||||
W_q_packed = ops.machete_prepack_B(w_q, wtype)
|
||||
output = ops.machete_gemm(
|
||||
a,
|
||||
b_q=W_q_packed,
|
||||
b_type=wtype,
|
||||
b_scales=w_s,
|
||||
b_group_size=group_size
|
||||
)
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
|
||||
|
||||
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
|
||||
694
csrc/quantization/machete/generate.py
Normal file
694
csrc/quantization/machete/generate.py
Normal file
@@ -0,0 +1,694 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import reduce
|
||||
|
||||
import jinja2
|
||||
from vllm_cutlass_library_extension import (
|
||||
DataType,
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
MixedInputKernelScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType,
|
||||
VLLMDataType,
|
||||
VLLMDataTypeNames,
|
||||
VLLMDataTypeSize,
|
||||
VLLMDataTypeTag,
|
||||
VLLMDataTypeTorchDataTypeTag,
|
||||
VLLMDataTypeVLLMScalarTypeTag,
|
||||
VLLMKernelScheduleTag,
|
||||
)
|
||||
|
||||
#
|
||||
# Generator templating
|
||||
#
|
||||
|
||||
DISPATCH_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set type_sig = gen_type_sig(impl_config.types) -%}
|
||||
{% for s in impl_config.schedules %}
|
||||
extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
|
||||
{%- endfor %}
|
||||
|
||||
torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
[[maybe_unused]] auto M = args.A.size(0);
|
||||
[[maybe_unused]] auto N = args.B.size(1);
|
||||
[[maybe_unused]] auto K = args.A.size(1);
|
||||
|
||||
if (!args.maybe_schedule) {
|
||||
{%- for cond, s in impl_config.heuristic %}
|
||||
{%if cond is not none%}if ({{cond}})
|
||||
{%- else %}else
|
||||
{%- endif %}
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
|
||||
}
|
||||
|
||||
{%- for s in impl_config.schedules %}
|
||||
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
|
||||
{%- endfor %}
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
|
||||
"schedule = ", *args.maybe_schedule);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
|
||||
static inline std::optional<at::ScalarType> maybe_scalartype(
|
||||
std::optional<at::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return t->scalar_type();
|
||||
};
|
||||
}
|
||||
|
||||
torch::Tensor mm_dispatch(MMArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.A.scalar_type());
|
||||
auto a_type = args.A.scalar_type();
|
||||
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
|
||||
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
|
||||
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
|
||||
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
|
||||
&& a_type == {{TorchTypeTag[t.a]}}
|
||||
&& out_type == {{TorchTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!maybe_g_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void -%}
|
||||
maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!maybe_g_zeros_type{%endif%}
|
||||
&& {%if t.b_channel_scale != void -%}
|
||||
maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
|
||||
{%- else %}!maybe_ch_scales_type{%endif%}
|
||||
&& {%if t.a_token_scale != void -%}
|
||||
maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
|
||||
{%- else %}!maybe_tok_scales_type{%endif%}
|
||||
) {
|
||||
return mm_dispatch_{{type_sig}}(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "machete_mm(..) is not implemented for "
|
||||
"a_type=", args.A.scalar_type(),
|
||||
", b_type=", args.b_type.str(),
|
||||
", out_type=", out_type,
|
||||
", with_group_scale_type=", maybe_g_scales_type
|
||||
? toString(*maybe_g_scales_type) : "None",
|
||||
", with_group_zeropoint_type=", maybe_g_zeros_type
|
||||
? toString(*maybe_g_zeros_type) : "None",
|
||||
", with_channel_scale_type=", maybe_ch_scales_type
|
||||
? toString(*maybe_ch_scales_type) : "None",
|
||||
", with_token_scale_type=", maybe_tok_scales_type
|
||||
? toString(*maybe_tok_scales_type) : "None",
|
||||
"; implemented types are: \\n",
|
||||
{%- for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
"\\t{{gen_type_option_name(t)}}\\n",
|
||||
{%- endfor %}
|
||||
"");
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.a_type);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
|
||||
&& args.a_type == {{TorchTypeTag[t.a]}}
|
||||
&& out_type == {{TorchTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!args.maybe_group_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void-%}
|
||||
args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!args.maybe_group_zeros_type{%endif%}
|
||||
) {
|
||||
return {
|
||||
{%- for s in impl_config.schedules %}
|
||||
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
};
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
return {};
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
IMPL_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for sch in unique_schedules(impl_configs) %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
struct sch_{{sch_sig}} {
|
||||
using TileShapeNM = Shape<{{
|
||||
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
||||
using ClusterShape = Shape<{{
|
||||
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
|
||||
// TODO: Reimplement
|
||||
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
|
||||
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
|
||||
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
};
|
||||
{% endfor %}
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
|
||||
template<typename Sch>
|
||||
using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[t.b]}}, // ElementB
|
||||
{{DataTypeTag[t.out]}}, // ElementD
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
torch::Tensor
|
||||
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
|
||||
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
PREPACK_TEMPLATE = """
|
||||
#include "../machete_prepack_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
|
||||
{%- for t in types %}
|
||||
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
|
||||
if (args.a_type == {{TorchTypeTag[t.a]}}
|
||||
&& args.b_type.size_bits() == {{t.b_num_bits}}
|
||||
&& convert_type == {{TorchTypeTag[t.convert]}}) {
|
||||
return prepack_impl<
|
||||
PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[b_type]}}, // ElementB
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"prepack_B_dispatch(..) is not implemented for "
|
||||
"atype = ", args.a_type,
|
||||
", b_type = ", args.b_type.str(),
|
||||
", with_group_scales_type= ", args.maybe_group_scales_type ?
|
||||
toString(*args.maybe_group_scales_type) : "None");
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
# Shorthand aliases for the only kernel/epilogue schedules currently emitted
# (TMA warp-specialized cooperative).
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
# One kernel schedule: tile/cluster shapes plus kernel, epilogue and
# tile-scheduler choices. Frozen (hashable) so instances can be deduplicated
# and used as dict keys (see `unique_schedules`).
@dataclass(frozen=True)
class ScheduleConfig:
    tile_shape_mn: tuple[int, int]  # CTA tile shape (M, N)
    cluster_shape_mnk: tuple[int, int, int]  # thread-block cluster shape
    kernel_schedule: MixedInputKernelScheduleType  # mainloop schedule
    epilogue_schedule: EpilogueScheduleType
    tile_scheduler: TileSchedulerType
|
||||
|
||||
|
||||
# Full element-type signature for one GEMM instantiation. A field set to
# DataType.void means that operand (zeropoints, channel/token scales, ...)
# is absent for this configuration.
@dataclass(frozen=True)
class TypeConfig:
    a: DataType  # activation/input element type
    b: DataType | VLLMDataType  # quantized weight element type
    b_group_scale: DataType  # group-wise weight scale type (or void)
    b_group_zeropoint: DataType  # group-wise weight zeropoint type (or void)
    b_channel_scale: DataType  # per-channel weight scale type (or void)
    a_token_scale: DataType  # per-token activation scale type (or void)
    out: DataType  # output element type
    accumulator: DataType  # accumulator element type
|
||||
|
||||
|
||||
# Type signature for a weight-prepack instantiation: the prepack layout
# depends only on the activation type, the weight bit-width, and the type
# the weights are converted to in-register.
@dataclass(frozen=True)
class PrepackTypeConfig:
    a: DataType  # activation/input element type
    b_num_bits: int  # bit-width of the quantized weight elements
    convert: DataType  # type weights are upconverted to
    accumulator: DataType  # accumulator element type
|
||||
|
||||
|
||||
# Bundles a type signature with the schedules to instantiate for it and a
# heuristic: a list of (C++ condition string, schedule) pairs, where a None
# condition is the final fallback branch. The heuristic is used when the
# caller does not pass an explicit schedule.
@dataclass
class ImplConfig:
    types: TypeConfig
    schedules: list[ScheduleConfig]
    heuristic: list[tuple[str | None, ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Build the full (verbose) schedule signature used in generated names.

    Format: <tileM>x<tileN>_<clM>x<clN>x<clK>_<kernel>_<epilogue>_<scheduler>,
    where the last three components are the final `::`-separated piece of the
    corresponding C++ tag.
    """
    tile_m, tile_n = schedule_config.tile_shape_mn
    tile = f"{tile_m}x{tile_n}"
    cluster = "x".join(str(d) for d in schedule_config.cluster_shape_mnk)

    def _last_component(tag: str) -> str:
        # Tags are fully qualified C++ names; keep only the trailing name.
        return tag.split("::")[-1]

    kernel = _last_component(VLLMKernelScheduleTag[schedule_config.kernel_schedule])
    epilogue = _last_component(EpilogueScheduleTag[schedule_config.epilogue_schedule])
    scheduler = _last_component(TileSchedulerTag[schedule_config.tile_scheduler])

    return f"{tile}_{cluster}_{kernel}_{epilogue}_{scheduler}"
|
||||
|
||||
|
||||
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    """Shorten the verbose schedule signature with fixed name substitutions."""
    replacements = (
        ("KernelTmaWarpSpecializedCooperative", "TmaMI_"),
        ("TmaWarpSpecializedCooperative_", "TmaCoop_"),
        ("StreamKScheduler", "streamK"),
    )

    sig = generate_sch_sig(schedule_config)
    for verbose, terse in replacements:
        sig = sig.replace(verbose, terse)
    return sig
|
||||
|
||||
|
||||
# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
    """Concatenate the short name of every TypeConfig field into a signature
    that uniquely identifies this type combination."""
    names = (
        VLLMDataTypeNames[getattr(kernel_types, f.name)] for f in fields(TypeConfig)
    )
    return "".join(names)
|
||||
|
||||
|
||||
def generate_type_option_name(kernel_types: TypeConfig):
    """Human-readable "with_x_type=Name, ..." summary for error messages."""
    entries = []
    for f in fields(TypeConfig):
        label = f.name.replace("b_", "with_") + "_type"
        entries.append(label + "=" + VLLMDataTypeNames[getattr(kernel_types, f.name)])
    return ", ".join(entries)
|
||||
|
||||
|
||||
def is_power_of_two(n):
    """True iff n has exactly one set bit (bit trick: for a power of two,
    n & (n - 1) clears the single set bit, yielding zero)."""
    if n == 0:
        return False
    return n & (n - 1) == 0
|
||||
|
||||
|
||||
def to_cute_constant(value: list[int] | int):
    """Render an int (or iterable of ints) as cute compile-time constant(s).

    Powers of two use cute's predefined ``_N`` shorthand aliases (e.g.
    ``_128``); other values fall back to the generic ``Int<N>`` spelling.
    Returns a list when given an iterable, a single string otherwise.

    Fixes: the inner comprehension previously shadowed the ``value``
    parameter with its own loop variable; the annotation also claimed
    ``list[int]`` even though scalars are accepted.
    """

    def _to_cute_constant(v: int) -> str:
        # cute only predefines `_N` aliases for powers of two.
        if is_power_of_two(v):
            return f"_{v}"
        return f"Int<{v}>"

    if isinstance(value, Iterable):
        return [_to_cute_constant(v) for v in value]
    return _to_cute_constant(value)
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
    """Deduplicated schedules across all configs, first-seen order preserved.

    A dict stands in for a set so the resulting order is deterministic
    (keeps generated files stable for ccache).
    """
    ordered: dict = {}
    for cfg in impl_configs:
        for sch in cfg.schedules:
            ordered[sch] = None
    return list(ordered)
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
    """Map a supported bit-width (4/8/16/32/64) to the matching unsigned
    DataType; raises KeyError for anything else."""
    bitwidth_to_unsigned = {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }
    return bitwidth_to_unsigned[num_bits]
|
||||
|
||||
|
||||
# Helpers and type-tag lookup tables injected into every jinja template
# (see `create_template`).
template_globals = {
    "void": DataType.void,  # sentinel meaning "operand not present"
    "DataTypeTag": VLLMDataTypeTag,
    "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag,
    "TorchTypeTag": VLLMDataTypeTorchDataTypeTag,
    "KernelScheduleTag": VLLMKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    "gen_sch_sig": generate_terse_sch_sig,  # terse form is used in names
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name,
}
|
||||
|
||||
|
||||
def create_template(template_str):
    """Compile a jinja template and expose the shared helper globals to it."""
    compiled = jinja2.Template(template_str)
    compiled.globals.update(template_globals)
    return compiled
|
||||
|
||||
|
||||
# Pre-compiled templates for the three generated source flavors.
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
    """Render all generated sources as (basename, code) pairs.

    Produces one dispatch file, one prepack file, and up to `num_impl_files`
    implementation files with the kernel instantiations spread across them to
    parallelize compilation.
    """
    sources = []

    # Single dispatch TU covering every impl config.
    sources.append(
        (
            "machete_mm_dispatch",
            mm_dispatch_template.render(impl_configs=impl_configs),
        )
    )

    # Derive the prepack type for each impl config: weights are converted to
    # the group-scale type when scales are present, otherwise to the
    # activation type.
    prepack_types = []
    for impl_config in impl_configs:
        convert_type = (
            impl_config.types.a
            if impl_config.types.b_group_scale == DataType.void
            else impl_config.types.b_group_scale
        )
        prepack_types.append(
            PrepackTypeConfig(
                a=impl_config.types.a,
                b_num_bits=VLLMDataTypeSize[impl_config.types.b],
                convert=convert_type,
                accumulator=impl_config.types.accumulator,
            )
        )

    def prepacked_type_key(prepack_type: PrepackTypeConfig):
        # For now, we can just use the first accumulator type seen since
        # the tensor core shapes/layouts don't vary based on accumulator
        # type so we can generate less code this way
        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)

    # Deduplicate prepack types while preserving first-seen order.
    unique_prepack_types = []
    prepack_types_seen = set()
    for prepack_type in prepack_types:
        key = prepacked_type_key(prepack_type)
        if key not in prepack_types_seen:
            unique_prepack_types.append(prepack_type)
            prepack_types_seen.add(key)

    sources.append(
        (
            "machete_prepack",
            prepack_dispatch_template.render(
                types=unique_prepack_types,
            ),
        )
    )

    # Split up impls across files
    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: list[list[ImplConfig]] = [[]]

    curr_num_impls_assigned = 0
    curr_impl_in_file = 0
    # Reversed so we can pop cheaply from the end; deepcopy because the
    # schedule lists are sliced/mutated in place below.
    curr_impl_configs = deepcopy(list(reversed(impl_configs)))

    # Greedily pack (type-config, schedule) pairs into files, splitting an
    # impl config's schedule list across file boundaries when needed.
    while curr_num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - curr_impl_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            curr_impl_in_file = 0

        curr_ic = curr_impl_configs[-1]
        if len(curr_ic.schedules) >= room_left_in_file:
            # Break apart the current impl config
            tmp_ic = deepcopy(curr_ic)
            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(tmp_ic)
        else:
            files_impls[-1].append(curr_ic)
            curr_impl_configs.pop()
        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
        curr_impl_in_file += len(files_impls[-1][-1].schedules)

    for part, file_impls in enumerate(files_impls):
        sources.append(
            (
                f"machete_mm_impl_part{part + 1}",
                mm_impl_template.render(impl_configs=file_impls),
            )
        )

    return sources
|
||||
|
||||
|
||||
def generate():
    """Render all machete kernel sources into `<script dir>/generated/`.

    Builds the GPTQ- and AWQ-style type configurations with the shared
    H100-tuned schedule heuristic, renders them through the templates via
    `create_sources`, and writes one `.cu` file per rendered source.
    """
    # See csrc/quantization/machete/Readme.md, the "Code Generation" section,
    # for more info about how this works.
    SCRIPT_DIR = os.path.dirname(__file__)

    # Schedule parameters shared by every heuristic entry; only the tile and
    # cluster shapes vary.
    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)).
    # Conditions are C++ expressions over M/N/K emitted into the generated
    # dispatch code; the first matching branch wins and `None` is the final
    # `else` fallback.
    default_tile_heuristic_config = {
        #### M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        #### M = 65-128
        # NOTE(review): "4069" below looks like a typo for 4096 — confirm
        # before changing, since it alters which kernels get selected.
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        (cond, ScheduleConfig(*tile_config, **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
    ]

    def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel file
        # will be non-deterministic and causes ccache miss.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules

    impl_configs = []

    # GPTQ-style types: no explicit zeropoints (the u4b8/u8b128 encodings bake
    # in the bias), group scales in the activation dtype.
    GPTQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        )
        for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
        for a in (DataType.f16, DataType.bf16)
    )

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(
            GPTQ_kernel_type_configs,
            itertools.repeat(get_unique_schedules(default_heuristic)),
            itertools.repeat(default_heuristic),
        )
    ]

    # AWQ-style types: plain unsigned weights with explicit group zeropoints.
    AWQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=a,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        )
        for b in (DataType.u4, DataType.u8)
        for a in (DataType.f16, DataType.bf16)
    )

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(
            AWQ_kernel_type_configs,
            itertools.repeat(get_unique_schedules(default_heuristic)),
            itertools.repeat(default_heuristic),
        )
    ]

    # TODO: Support W4A8 when ready
    # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # # TODO (LucasWilkinson): Further tuning required
    # qqq_tile_heuristic_config = {
    #     #### M = 257+
    #     # ((128, 256), (2, 1, 1)) Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
    #     # "M > 256": ((128, 256), (2, 1, 1)),
    #     "M > 256": ((128, 128), (2, 1, 1)),
    #     #### M = 129-256
    #     "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
    #     "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
    #     # ((128, 256), (2, 1, 1)) Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     # "M > 128": ((128, 256), (2, 1, 1)),
    #     "M > 128": ((128, 128), (2, 1, 1)),
    #     #### M = 65-128
    #     "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
    #     "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
    #     "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
    #     "M > 64": ((128, 128), (2, 1, 1)),
    #     #### M = 33-64
    #     "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
    #     # Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     # "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
    #     "M > 32": ((128, 64), (2, 1, 1)),
    #     #### M = 17-32
    #     "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
    #     "M > 16": ((256, 32), (2, 1, 1)),
    #     #### M = 1-16
    #     "N >= 26624": ((256, 16), (1, 1, 1)),
    #     None: ((128, 16), (1, 1, 1)),
    # }

    # # For now we use the same heuristic for all types
    # # Heuristic is currently tuned for H100s
    # qqq_heuristic = [
    #     (cond, ScheduleConfig(*tile_config,
    #                           **sch_common_params))  # type: ignore
    #     for cond, tile_config in qqq_tile_heuristic_config.items()
    # ]

    # QQQ_kernel_types = [
    #     *(TypeConfig(
    #         a=DataType.s8,
    #         b=VLLMDataType.u4b8,
    #         b_group_scale=b_group_scale,
    #         b_group_zeropoint=DataType.void,
    #         b_channel_scale=DataType.f32,
    #         a_token_scale=DataType.f32,
    #         out=DataType.f16,
    #         accumulator=DataType.s32,
    #     ) for b_group_scale in (DataType.f16, DataType.void)),
    #     *(TypeConfig(
    #         a=DataType.e4m3,
    #         b=VLLMDataType.u4b8,
    #         b_group_scale=b_group_scale,
    #         b_group_zeropoint=DataType.void,
    #         b_channel_scale=DataType.f32,
    #         a_token_scale=DataType.f32,
    #         out=DataType.f16,
    #         accumulator=DataType.f32,
    #     ) for b_group_scale in (DataType.f16, DataType.void)),
    # ]

    # impl_configs += [
    #     ImplConfig(x[0], x[1], x[2])
    #     for x in zip(QQQ_kernel_types,
    #                  itertools.repeat(get_unique_schedules(qqq_heuristic)),
    #                  itertools.repeat(qqq_heuristic))
    # ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")

    # Delete the "generated" directory if it exists
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    # Create the "generated" directory
    os.makedirs(output_dir)

    # Render each group of configurations into separate files.
    # BUG FIX: the loop previously wrote every source to a single hard-coded
    # path, clobbering all but the last rendered file; use the basename
    # returned by create_sources() instead.
    for filename, code in create_sources(impl_configs):
        filepath = os.path.join(output_dir, f"{filename}.cu")
        with open(filepath, "w") as output_file:
            output_file.write(code)
        print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
# Entry point: regenerate all machete kernel sources.
if __name__ == "__main__":
    generate()
|
||||
31
csrc/quantization/machete/machete_collective_builder.cuh
Normal file
31
csrc/quantization/machete/machete_collective_builder.cuh
Normal file
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/vllm_collective_builder.cuh"
|
||||
#include "machete_mainloop.cuh"
|
||||
|
||||
namespace cutlass::gemm::collective {
using namespace cute;

// Tag type used to select the machete specialization of VLLMCollectiveBuilder.
struct MacheteKernelTag {};

// Specialization routing collective-mainloop construction for
// MacheteKernelTag on SM90 tensor-op GEMMs to machete's custom mixed-input
// mainloop (MacheteCollectiveMma). Enabled only for the three TMA
// warp-specialized kernel schedules listed in the enable_if.
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
          class ElementPairB_, class GmemLayoutB_, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
    GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative>)>> {
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
      AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

};  // namespace cutlass::gemm::collective
|
||||
35
csrc/quantization/machete/machete_interleaving_utils.cuh
Normal file
35
csrc/quantization/machete/machete_interleaving_utils.cuh
Normal file
@@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/layout.hpp"
|
||||
|
||||
namespace machete {

using namespace cute;

// Get an interleaved block layout where each consecutive element has a
// stride of `bit_stride` bits and the block is `blk_bit_width` bits wide.
// Examples:
//   size_bits<T> = 8, bit_stride = 8,  blk_bit_width = 32 -> 4:1
//   size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 8,  blk_bit_width = 32 -> (4, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
  // The stride must evenly divide the block, and the element size must
  // evenly divide the stride, for the reshape below to be exact.
  static_assert(blk_bit_width % bit_stride == 0);
  static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);

  constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;

  if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
    // identity layout
    return Layout<Shape<Int<elems_per_blk>>>{};
  } else {
    constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
    constexpr auto num_strides = elems_per_blk / elems_per_stride;
    return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
                  Stride<Int<elems_per_stride>, Int<1>>>{};
  }
}

};  // namespace machete
|
||||
1473
csrc/quantization/machete/machete_mainloop.cuh
Normal file
1473
csrc/quantization/machete/machete_mainloop.cuh
Normal file
File diff suppressed because it is too large
Load Diff
309
csrc/quantization/machete/machete_mm_kernel.cuh
Normal file
309
csrc/quantization/machete/machete_mm_kernel.cuh
Normal file
@@ -0,0 +1,309 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_prepacked_layout.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
|
||||
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
|
||||
// instructions only support sourcing from registers for the left-hand
|
||||
// operand, we want to upconvert/decompress the quantized operand in
|
||||
// register. Since the primary use case we want to support is Y = XW^t where
|
||||
// W is quantized, in this situation our right-hand operand is quantized so
|
||||
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
|
||||
static constexpr bool with_group_zeropoints =
|
||||
!std::is_same_v<GroupZeroT, void>;
|
||||
static constexpr bool with_channel_scales =
|
||||
!std::is_same_v<ChannelScaleT, void>;
|
||||
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
|
||||
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementC = cute::conditional_t<with_C, ElementD, void>;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementCompute = AccumulatorT; // For Epilogue
|
||||
// Use dummy values when we don't have scales or zeropoints
|
||||
using ElementZGroup =
|
||||
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
|
||||
using ElementSGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementConvertGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementSChannel =
|
||||
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
|
||||
using ElementSToken =
|
||||
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
|
||||
|
||||
using BTypeTuple = cute::conditional_t<
|
||||
with_group_scales,
|
||||
cute::conditional_t<with_group_zeropoints,
|
||||
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
|
||||
cute::tuple<ElementB, ElementSGroup>>,
|
||||
ElementB>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
// not actually used since B has the prepacked layout, but required by cutlass
|
||||
using _LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Interface strides expected by create_arguments (will get transposed)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
|
||||
using StrideZGroup = StrideSGroup;
|
||||
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using PrepackedLayoutB =
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
|
||||
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
|
||||
|
||||
static int constexpr TileShapeK =
|
||||
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
|
||||
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
|
||||
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
|
||||
static int constexpr AlignmentC =
|
||||
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
|
||||
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
|
||||
|
||||
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
|
||||
cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape;
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using TileScheduler = typename ScheduleConfig::TileScheduler;
|
||||
|
||||
static_assert(
|
||||
(!with_channel_scales && !with_token_scales) ||
|
||||
((with_channel_scales && with_token_scales) &&
|
||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
"Currently token and channel scales (if present) must be float "
|
||||
"(and if one is present the other must be too)");
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using EVTCompute =
|
||||
std::conditional_t<with_channel_scales || with_token_scales,
|
||||
typename ChTokScalesEpilogue::EVTCompute,
|
||||
StoreEpilogueCompute>;
|
||||
|
||||
// EVTCompute
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
|
||||
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
|
||||
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// stride_B is unused (since B is prepacked), but still required by cutlass
|
||||
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
|
||||
|
||||
using Arguments = typename Gemm::Arguments;
|
||||
using MainloopArguments = typename GemmKernel::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
|
||||
|
||||
static Arguments create_arguments(
|
||||
cudaStream_t stream,
|
||||
torch::Tensor const& A, // MxK matrix
|
||||
torch::Tensor const& B, // KxN prepacked matrix
|
||||
torch::Tensor& D, // MxN matrix
|
||||
std::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
std::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
std::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
int M = A.size(0), N = B.size(1), K = A.size(1);
|
||||
TORCH_CHECK(D.size(0) == M && D.size(1) == N);
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(A, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(D, "D");
|
||||
auto layout_S_group =
|
||||
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
|
||||
auto layout_Z_group =
|
||||
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
|
||||
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
|
||||
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
|
||||
|
||||
auto unwrap = [](auto const& t) {
|
||||
return t ? t->const_data_ptr() : nullptr;
|
||||
};
|
||||
auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
|
||||
auto S_group_ptr =
|
||||
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
|
||||
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
|
||||
auto S_channel_ptr =
|
||||
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
|
||||
auto S_token_ptr =
|
||||
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
|
||||
|
||||
int const group_size =
|
||||
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
|
||||
int const scale_k = (K + group_size - 1) / group_size;
|
||||
|
||||
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
|
||||
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
|
||||
|
||||
if constexpr (with_group_scales) {
|
||||
TORCH_CHECK(S_group_ptr && layout_S_group);
|
||||
TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
|
||||
size<1>(*layout_S_group) == N));
|
||||
} else {
|
||||
TORCH_CHECK(!S_group_ptr, "Scales not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_group_zeropoints) {
|
||||
TORCH_CHECK(Z_group_ptr && layout_Z_group);
|
||||
TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
|
||||
size<1>(*layout_Z_group) == N));
|
||||
TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
|
||||
"Scales and zeros must have the same layout");
|
||||
} else {
|
||||
TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
TORCH_CHECK(
|
||||
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
|
||||
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
|
||||
}
|
||||
|
||||
// Transpose A and D
|
||||
// A doesn't need to be transposed since cutlass expects a NxK matrix
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
|
||||
MainloopArguments mainloop_arguments{};
|
||||
// {Accum, C, C_layout, D, D}
|
||||
EpilogueArguments epilogue_arguments{};
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
epilogue_arguments =
|
||||
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
|
||||
*maybe_ch_scales, *maybe_tok_scales),
|
||||
nullptr,
|
||||
{},
|
||||
D_ptr,
|
||||
stride_Dt};
|
||||
} else {
|
||||
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
|
||||
}
|
||||
|
||||
if constexpr (with_group_scales && with_group_zeropoints) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments = MainloopArguments{
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
|
||||
} else if constexpr (with_group_scales) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size};
|
||||
} else {
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
|
||||
}
|
||||
|
||||
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{N, M, K, 1},
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
};
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
return Gemm::get_workspace_size(args);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Machete kernel failed to initialize workspace");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
75
csrc/quantization/machete/machete_mm_launcher.cuh
Normal file
75
csrc/quantization/machete/machete_mm_launcher.cuh
Normal file
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <Python.h>
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
// Argument bundle for one machete mixed-precision GEMM call (consumed by
// mm_dispatch / run_impl).
// NOTE: members are (const) references into caller-owned objects; an MMArgs
// must not outlive the tensors/optionals it was built from.
struct MMArgs {
  torch::Tensor const& A;          // MxK activation matrix
  torch::Tensor const& B;          // prepacked quantized weight matrix
  vllm::ScalarType const& b_type;  // quantized element type of B
  std::optional<at::ScalarType> const& maybe_out_type;  // output dtype
  std::optional<torch::Tensor> const& maybe_group_scales;
  std::optional<torch::Tensor> const& maybe_group_zeros;
  std::optional<int64_t> maybe_group_size;  // quantization group size along K
  std::optional<torch::Tensor> const& maybe_channel_scales;  // len N vector
  std::optional<torch::Tensor> const& maybe_token_scales;    // len M vector
  std::optional<std::string> maybe_schedule;  // explicit schedule override
};
|
||||
|
||||
// Type-combination query used by supported_schedules_dispatch to enumerate
// which kernel schedules were generated for the given element types.
struct SupportedSchedulesArgs {
  at::ScalarType a_type;    // activation dtype
  vllm::ScalarType b_type;  // quantized weight type
  std::optional<at::ScalarType> maybe_group_scales_type;
  std::optional<at::ScalarType> maybe_group_zeros_type;
  std::optional<at::ScalarType> maybe_channel_scales_type;
  std::optional<at::ScalarType> maybe_token_scales_type;
  std::optional<at::ScalarType> maybe_out_type;
};
||||
|
||||
torch::Tensor mm_dispatch(MMArgs args);
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args);
|
||||
|
||||
// Allocates the MxN output tensor, builds kernel arguments, allocates the
// cutlass workspace, and runs MacheteKernel on the current CUDA stream.
// Throws (TORCH_CHECK) when the kernel cannot implement the problem.
template <typename MacheteKernel>
torch::Tensor run_impl(MMArgs args) {
  // ensure all work happens on A's device
  const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));

  auto device = args.A.device();
  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  int M = args.A.size(0);
  int N = args.B.size(1);
  int K = args.A.size(1);

  // Allocate output with the kernel's configured output element type
  torch::Tensor D = torch::empty(
      {M, N},
      torch::TensorOptions()
          .dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
          .device(device));

  auto arguments = MacheteKernel::create_arguments(
      stream,  //
      args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
      args.maybe_group_size, args.maybe_channel_scales,
      args.maybe_token_scales);
  TORCH_CHECK(MacheteKernel::can_implement(arguments),
              "Machete kernel cannot be run with these arguments");

  // cutlass-managed scratch space (e.g. for the tile scheduler)
  size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(
      workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));

  MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);

  return D;
};
|
||||
|
||||
}; // namespace machete
|
||||
76
csrc/quantization/machete/machete_prepack_kernel.cuh
Normal file
76
csrc/quantization/machete/machete_prepack_kernel.cuh
Normal file
@@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
// One CUDA block prepacks one PPBlockShape_NK tile of B: each thread copies
// its contiguous slice of the tile from the (N,K,L) input ordering into the
// interleaved/prepacked output ordering.
template <int threads, typename PrepackedLayoutB, typename BInTensor,
          typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
  auto constexpr block_size =
      Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
  auto constexpr eles_per_thread = Int<block_size / threads>{};
  static_assert(block_size % threads == 0,
                "block_size must be divisible by the number of threads");

  // Which pre-packed block are we responsible for
  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
  auto tB_in = local_tile(
      B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
      blk_coord);

  // Find the start offset in the output for this pre-packed block
  auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));

  // Tensor representing a 1:1 mapping to the output space in 1D
  auto tB_out_linear =
      make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
                  make_layout(make_shape(block_size)));
  // Mapping from output space (1D) to input space
  auto tB_in_linear = make_tensor(
      tB_in.data(),
      tB_in.layout()
          .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
          .with_shape(make_shape(block_size)));

  // Tile for this specific thread (could have used a TiledCopy but these work
  // best with 2d layouts, this is a simple 1d layout so local_tile is enough,
  // we are also not that concerned with performance for this kernel)
  auto thr_tB_in_linear =
      local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
  auto thr_tB_out_linear =
      local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);

  // Construct a register-backed Tensor with the same shape as each thread's
  // partition
  auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));

  // gather (possibly strided) input values into registers, then write them
  // out contiguously
  copy(thr_tB_in_linear, fragment);
  copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
}
|
||||
|
||||
template <typename PrepackedLayoutB, typename InLayout>
|
||||
static void prepack_B_template(
|
||||
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||
using TileShapeNKL =
|
||||
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
||||
auto ilvd_NKbNbKL_to_offset =
|
||||
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
|
||||
|
||||
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
|
||||
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
|
||||
|
||||
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
|
||||
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
|
||||
auto L_tiles = size<2>(B_layout);
|
||||
|
||||
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
|
||||
|
||||
prepack_B_kernel<128, PrepackedLayoutB>
|
||||
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
74
csrc/quantization/machete/machete_prepack_launcher.cuh
Normal file
74
csrc/quantization/machete/machete_prepack_launcher.cuh
Normal file
@@ -0,0 +1,74 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_prepack_kernel.cuh"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
// Argument bundle for prepacking a quantized weight matrix (consumed by
// prepack_B_dispatch).
// NOTE: B is a reference into a caller-owned tensor; do not let a
// PrepackBArgs outlive it.
struct PrepackBArgs {
  torch::Tensor const& B;   // quantized weight matrix to prepack
  at::ScalarType a_type;    // activation dtype the GEMM will use
  vllm::ScalarType b_type;  // quantized element type of B
  std::optional<at::ScalarType> maybe_group_scales_type;
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
torch::Tensor prepack_impl(torch::Tensor const B) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
|
||||
using ElementB = typename PrepackedLayoutB::ElementB;
|
||||
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
|
||||
|
||||
auto device = B.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
|
||||
// elements per storage item for B
|
||||
auto eles_per_storage =
|
||||
(B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
|
||||
|
||||
// torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
|
||||
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
|
||||
auto Bt_packed = B.t();
|
||||
|
||||
TORCH_CHECK(
|
||||
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
|
||||
size<1>(PPBlockShape_NK{}));
|
||||
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
|
||||
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
|
||||
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
|
||||
|
||||
// convert (N,packed_K,L) layout to (N,K,L) layout
|
||||
// in effect we want to do: blocked_product(layout_Bt_packed,
|
||||
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
|
||||
// Step<_1, _0, _2>{}));
|
||||
// but blocked_product does not support dynamic strides so we implement the
|
||||
// equivalent manually,
|
||||
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
|
||||
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
|
||||
// when s1 == 1
|
||||
TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
|
||||
// clang-format off
|
||||
auto const layout_Bt = make_layout(
|
||||
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
|
||||
return idx == 1 ? ele * eles_per_storage : ele;
|
||||
}),
|
||||
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
|
||||
return idx != 1 ? ele * eles_per_storage : ele;
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
// Allocate output
|
||||
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
|
||||
|
||||
prepack_B_template<PrepackedLayoutB>(
|
||||
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
|
||||
|
||||
return D;
|
||||
};
|
||||
|
||||
torch::Tensor prepack_B_dispatch(PrepackBArgs args);
|
||||
|
||||
}; // namespace machete
|
||||
253
csrc/quantization/machete/machete_prepacked_layout.cuh
Normal file
253
csrc/quantization/machete/machete_prepacked_layout.cuh
Normal file
@@ -0,0 +1,253 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Tag type requesting that the interleaved block layout be derived
// automatically from the element/convert types (see IlvdBlkLayout in
// PrepackedLayoutBTemplate).
struct IlvBlkLayoutAuto {};
|
||||
|
||||
// This defines a prepacked layout for the B matrix, where the matrix is broken
|
||||
// up into PPBlockShape_NK blocks. The data within each block is then compactly
|
||||
// stored in memory such that when performing a TiledMMA operation with the same
|
||||
// shape as prepacked block, all the data for a given thread is contiguous in
|
||||
// memory. This allows us to use wider shared memory loads when loading B from
|
||||
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
|
||||
//
|
||||
// The contract here is that the `TiledMma` determined below matches the one
|
||||
// ultimately used in the kernel. (this is also why the other element types are
|
||||
// required along with the kernel schedule)
|
||||
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
|
||||
typename AccumulatorT, class LayoutB, class KernelSchedule,
|
||||
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
|
||||
// clang-format on
|
||||
struct PrepackedLayoutBTemplate {
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementMma = MmaType;
|
||||
|
||||
// Interleave for 4bit types when we are not upconverting to fp8 or int8;
// in those cases we use a LUT using prmt instructions to upconvert and it
// is more efficient if the data is not interleaved. For 8bit+ types, prmt
// instructions make non-interleaved layouts efficient enough that we don't
// need interleaved layouts (and can reuse more of the existing cutlass
// converts)
|
||||
static constexpr bool should_interleave =
|
||||
sizeof_bits_v<ElementB> <= 4 &&
|
||||
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
|
||||
!std::is_same_v<ElementConvert_, int8_t>;
|
||||
|
||||
// Only use interleaved layouts for subbyte weights,
|
||||
using IlvdBlkLayout = std::conditional_t<
|
||||
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
|
||||
std::conditional_t<
|
||||
should_interleave,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
|
||||
void>,
|
||||
IlvBlkLayout_>;
|
||||
|
||||
// TODO (LucasWilkinson): compare the performance for other sizes
|
||||
// Prepacked block shape, smallest layout atom for loading into registers
|
||||
// (can contain multiple wgmma instructions worth of data in one block)
|
||||
// We ideally want this to be configured such that a thread can perform 128bit
|
||||
// loads, i.e. the amount of data associated with each thread within a
// prepacked block is a multiple of 128bits, when using a cooperative schedule
|
||||
// we have 256 threads working a single block at a time, this means each
|
||||
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
|
||||
// for a 4bit type this would be 128bits
|
||||
using PPBlockShape_NK = Shape<_128, _64>;
|
||||
|
||||
// Create the shape of the tile anticipated to be used by the GEMM kernel,
|
||||
// when the kernel executes we will compute `Ct = Bt * At` since the
|
||||
// quantized weights (B), must be the lhs operand so the flow through
|
||||
// registers.
|
||||
// The _128 here doesn't actually impact the shape of the stored tile directly
|
||||
// but may impact the op selected by rs_op_selector
|
||||
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
|
||||
size<1>(PPBlockShape_NK{})));
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
gmma_rs_tag_to_major_B<LayoutB>();
|
||||
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
|
||||
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
// Prepacked block, (athrid, val) -> (N,K)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
|
||||
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
|
||||
}
|
||||
|
||||
// Prepacked block, (N,K) -> (athrid, val)
|
||||
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
|
||||
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
|
||||
// Return interleaved layout
|
||||
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
|
||||
auto layout_no_interleave =
|
||||
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
|
||||
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
|
||||
return layout_no_interleave;
|
||||
} else {
|
||||
// interleave by transforming FrgV into interleaved blocks where each
|
||||
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
|
||||
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
|
||||
// if FrgV is {A, B, C, D, E, F, G, H}
|
||||
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
|
||||
auto frgV = get<1, 0>(layout_no_interleave);
|
||||
auto ilvdBlk = IlvdBlkLayout{};
|
||||
static_assert(size(frgV) % size(ilvdBlk) == 0,
|
||||
"FrgV must be divisible by size(ilvdBlk)");
|
||||
auto ilvd_FrgV = make_layout(
|
||||
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
|
||||
make_stride(stride(ilvdBlk), size(ilvdBlk)));
|
||||
|
||||
// Return interleaved layout
|
||||
return make_layout(
|
||||
get<0>(layout_no_interleave),
|
||||
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
|
||||
}
|
||||
}
|
||||
|
||||
// Prepacked block, (M,K) -> (storage_offset)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
|
||||
// do (M,K) -> (athrid, val) -> (storage_idx)
|
||||
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_TV_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L))
|
||||
// => ((athrid, val), (BlocksN, BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  // NOTE(review): the codomain here is a storage offset (this simply re-shapes
  // the first mode of TVbNbKL_to_offset, whose codomain is storage_idx), not
  // (N, K, L) as the original comment claimed.
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
      Shape_NKL shape_mkl) {
    auto layout = TVbNbKL_to_offset(shape_mkl);
    // Re-shape the (athrid, val) mode into contiguous chunks of 256 values so
    // the copy sees wide contiguous vectors;
    // for 4-bit elements, having >= 64 values per column
    // allows TMA to load full 32-byte sectors
    auto inner_layout =
        make_layout(make_shape(_256{}, size<0>(layout) / _256{}));

    return make_layout(inner_layout, get<1>(layout), get<2>(layout));
  }
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
|
||||
// BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto stride = size(PPBlockShape_NK{});
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
|
||||
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
|
||||
make_layout(size<1>(PPBlockShape_NK{})));
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
|
||||
return tiled_A.compose(ppblock_TV_to_NK(), _);
|
||||
}
|
||||
|
||||
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
|
||||
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
|
||||
return blocked_product(ppblock_NK_to_TV(),
|
||||
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
73
csrc/quantization/machete/machete_pytorch.cu
Normal file
73
csrc/quantization/machete/machete_pytorch.cu
Normal file
@@ -0,0 +1,73 @@
|
||||
#include "machete_mm_launcher.cuh"
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
// Returns the kernel-schedule names supported for the given combination of
// activation (a), quantized-weight (b), scale/zero-point, and output dtypes.
// b_type is passed as an int64 id (decoded via ScalarType::from_id) since the
// custom-op schema cannot carry a vllm::ScalarType directly — TODO confirm.
std::vector<std::string> supported_schedules(
    at::ScalarType a_type, int64_t b_type_id,
    std::optional<at::ScalarType> maybe_group_scales_type,
    std::optional<at::ScalarType> maybe_group_zeros_type,
    std::optional<at::ScalarType> maybe_channel_scales_type,
    std::optional<at::ScalarType> maybe_token_scales_type,
    std::optional<at::ScalarType> maybe_out_type) {
  ScalarType const b_type = ScalarType::from_id(b_type_id);
  // Forward the full dtype combination to the generated dispatch logic.
  return supported_schedules_dispatch({
      .a_type = a_type,
      .b_type = b_type,
      .maybe_group_scales_type = maybe_group_scales_type,
      .maybe_group_zeros_type = maybe_group_zeros_type,
      .maybe_channel_scales_type = maybe_channel_scales_type,
      .maybe_token_scales_type = maybe_token_scales_type,
      .maybe_out_type = maybe_out_type,
  });
}
|
||||
|
||||
// Runs the Machete mixed-precision GEMM: A (activations) x B (prepacked
// quantized weights), with optional group/channel/token scales and group
// zero-points. All arguments are forwarded unchanged to mm_dispatch, which
// selects the kernel instantiation; maybe_schedule, when provided, overrides
// the built-in schedule heuristic.
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
                 int64_t b_type_id,
                 std::optional<at::ScalarType> const& maybe_out_type,
                 std::optional<torch::Tensor> const& maybe_group_scales,
                 std::optional<torch::Tensor> const& maybe_group_zeros,
                 std::optional<int64_t> maybe_group_size,
                 std::optional<torch::Tensor> const& maybe_channel_scales,
                 std::optional<torch::Tensor> const& maybe_token_scales,
                 std::optional<std::string> maybe_schedule) {
  // Decode the weight dtype from its integer id (schema-friendly encoding).
  ScalarType const b_type = ScalarType::from_id(b_type_id);
  return mm_dispatch({.A = A,
                      .B = B,
                      .b_type = b_type,
                      .maybe_out_type = maybe_out_type,
                      .maybe_group_scales = maybe_group_scales,
                      .maybe_group_zeros = maybe_group_zeros,
                      .maybe_group_size = maybe_group_size,
                      .maybe_channel_scales = maybe_channel_scales,
                      .maybe_token_scales = maybe_token_scales,
                      .maybe_schedule = maybe_schedule});
}
|
||||
|
||||
// Repacks the quantized weight matrix B into Machete's tensor-core-friendly
// layout (must be done once before calling mm). a_type and the optional
// group-scales dtype participate in selecting the prepack kernel via
// prepack_B_dispatch.
torch::Tensor prepack_B(
    torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
    std::optional<at::ScalarType> const& maybe_group_scales_type) {
  // Decode the weight dtype from its integer id (schema-friendly encoding).
  ScalarType const b_type = ScalarType::from_id(b_type_id);
  return prepack_B_dispatch(
      {.B = B,
       .a_type = a_type,
       .b_type = b_type,
       .maybe_group_scales_type = maybe_group_scales_type});
}
|
||||
|
||||
// Register the CUDA implementations of the machete custom ops (the op
// schemas themselves are declared elsewhere in the extension).
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("machete_prepack_B", &prepack_B);
  m.impl("machete_mm", &mm);
}
|
||||
|
||||
// use CatchAll since supported_schedules has no tensor arguments
// (dispatch-key selection needs a tensor input, so a pure-metadata query
// cannot be registered under the CUDA key).
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
  m.impl("machete_supported_schedules", &supported_schedules);
}
|
||||
|
||||
}; // namespace machete
|
||||
Reference in New Issue
Block a user