Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,457 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
const Element* const* ptr_row_array = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
int group, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, group(group)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
int group;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row_array[group]));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
l,
params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
const Element* const* ptr_col_array = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
int group,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
group(group),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
int group;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col_array[group]));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
l,
params
);
}
};
}

View File

@@ -0,0 +1,497 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
using namespace cute;
using namespace detail;
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrZeroBroadcast {
// This struct has been modified to remove null_default (because it's always 0)
struct Arguments {
Element const* ptr_row = nullptr;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->ptr_row != nullptr) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are broadcasting 0
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage { };
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gCol,
RTensor&& tC_rCol,
CTensor&& tC_cCol,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gCol(cute::forward<GTensor>(tC_gCol)),
tC_rCol(cute::forward<RTensor>(tC_rCol)),
tC_cCol(cute::forward<CTensor>(tC_cCol)),
m(get<0>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gCol;
RTensor tC_rCol;
CTensor tC_cCol;
Params const* params_ptr;
int m;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);
Tensor pred = make_tensor<bool>(shape(tC_gCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Array<Element, FragmentSize> frg_col;
frg_col.fill(tC_rCol(row_idx,iter_idx));
return frg_col;
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mCol = make_tensor(
make_gmem_ptr(params_ptr->ptr_col),
problem_shape,
params_ptr->dCol);
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
Tensor tC_gCol = group_modes<1,4>(
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
Tensor tC_rCol = make_tensor_like(tC_gCol);
// Generate the pred tensor
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tC_cCol = group_modes<1,4>(
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
return Callbacks<
decltype(tC_gCol), decltype(tC_rCol),
decltype(tC_cCol), ProblemShape>(
cute::move(tC_gCol),
cute::move(tC_rCol),
cute::move(tC_cCol),
problem_shape,
params_ptr
);
}
};
}

View File

@@ -0,0 +1,447 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graphs
// breaks when moving scales to the CPU.
//
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
params
);
}
};
}

View File

@@ -0,0 +1,321 @@
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
}; // namespace vllm::c2x

View File

@@ -0,0 +1,450 @@
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c3x {
using namespace cute;
template <typename T>
struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};
template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;
template <typename... Args>
static ArgumentType prepare_args(Args... args) {
return {};
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
template <typename T>
using ColOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
template <typename Descriptor, typename T>
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
using Arguments = typename Descriptor::Arguments;
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
return Arguments{data_ptr, do_broadcast};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
to arrays containing different scales used in group gemm. The number of
pointers in ScaleA and the number of pointers in ScaleB are equal to the
group size.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
static ArgumentType prepare_args(float const* const* a_scales_ptr,
float const* const* b_scales_ptr,
bool a_col_broadcast, bool b_row_broadcast) {
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
a_scales_ptr, a_col_broadcast);
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
b_scales_ptr, b_row_broadcast);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
}; // namespace vllm::c3x