add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)

Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
Yineng Zhang
2025-01-30 23:04:41 +08:00
committed by GitHub
parent 468d23cff9
commit 222ce6f1da
86 changed files with 23201 additions and 0 deletions

.clang-format-ignore Normal file

@@ -0,0 +1 @@
sgl-kernel/3rdparty/tensorrt_llm/*

sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt Normal file

@@ -0,0 +1,22 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)
add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp Normal file

@@ -0,0 +1,34 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
namespace
{
bool initCheckDebug()
{
auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE";
auto const debugEnabled = std::getenv(kDebugEnabled);
return debugEnabled && debugEnabled[0] == '1';
}
} // namespace
bool DebugConfig::isCheckDebugEnabled()
{
static bool const debugEnabled = initCheckDebug();
return debugEnabled;
}
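Note: the flag above is read from the TLLM_DEBUG_MODE environment variable once and cached for the lifetime of the process. A minimal call-site sketch (the maybeSync helper is hypothetical, not part of this commit):

// Hypothetical helper: gate extra stream synchronization on the cached debug flag.
#include "tensorrt_llm/common/assert.h"
#include <cuda_runtime_api.h>

void maybeSync(cudaStream_t stream)
{
    if (DebugConfig::isCheckDebugEnabled()) // true iff TLLM_DEBUG_MODE starts with '1'
    {
        cudaStreamSynchronize(stream); // surface asynchronous kernel errors immediately
    }
}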

sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp Normal file

@@ -0,0 +1,360 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace tensorrt_llm
{
namespace common
{
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
: mCublasHandle(cublasHandle)
, mCublasLtHandle(cublasltHandle)
, mStream(stream)
, mCublasWorkspace(workspace)
{
}
CublasMMWrapper::~CublasMMWrapper() {}
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
: mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream)
{
}
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
{
// --------------------------------------
// Create descriptors for the original matrices
check_cuda_error(
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
check_cuda_error(
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
check_cuda_error(cublasLtMatmulDescSetAttribute(
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
}
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
{
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
check_cuda_error(
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
}
void CublasMMWrapper::destroyDescriptors()
{
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
mOperationDesc = NULL;
mADesc = NULL;
mBDesc = NULL;
mCDesc = NULL;
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, (*heuristic).algo,
/* hasAlgo */ (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{
if (heuristic)
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, (*heuristic).algo,
/* hasAlgo */ (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
/* usingCublasLt */ true);
}
else
{
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ true);
}
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
{
bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
/* usingCublasLt */ usingCublasLt);
}
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
{
half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta);
// TODO: default cublas libs
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
int batch_count = 1;
// fp32 uses cuBLAS by default; fp16 uses cuBLASLt by default
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (usingCublasLt)
{
if (hasAlgo)
{
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
}
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
sync_check_cuda_error();
}
else
{
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
// Use the default heuristic to choose the tactic, as cuBLAS does not allow choosing tactics on Ampere+
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
sync_check_cuda_error();
}
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
float const f_beta)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
{
half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta;
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMWrapper::setWorkspace(void* workspace)
{
mCublasWorkspace = workspace;
}
void CublasMMWrapper::setFP32GemmConfig()
{
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
}
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
}
#ifdef ENABLE_BF16
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
}
#endif
#ifdef ENABLE_FP8
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
{
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
}
#endif
void CublasMMWrapper::setGemmConfig(
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
mAType = aType;
mBType = bType;
mCType = cType;
bool isFp16ComputeType = computeType == CUDA_R_16F;
if (isFp16ComputeType)
{
mComputeType = CUBLAS_COMPUTE_16F;
mScaleType = CUDA_R_16F;
}
else
{
mComputeType = CUBLAS_COMPUTE_32F;
mScaleType = CUDA_R_32F;
}
}
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
{
if (data_type == CUDA_R_16F)
{
return HALF_DATATYPE;
}
else if (data_type == CUDA_R_32F)
{
return FLOAT_DATATYPE;
}
else if (data_type == CUDA_R_8I)
{
return INT8_DATATYPE;
}
#ifdef ENABLE_BF16
else if (data_type == CUDA_R_16BF)
{
return BFLOAT16_DATATYPE;
}
#endif
return FLOAT_DATATYPE;
}
void CublasMMWrapper::setStream(cudaStream_t stream)
{
mStream = stream;
}
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
{
return false;
}
sync_check_cuda_error();
return true;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
{
TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
sync_check_cuda_error();
return heuristics;
}
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
{
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
return {};
#else
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
cublasLtMatmulPreference_t preference;
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
// Restrict reduction algorithms for numerical stability and better determinism
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
uint32_t pointer_mode_mask = 0;
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
#endif
int return_count = 0;
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
heuristics.size(), heuristics.data(), &return_count));
heuristics.resize(return_count);
return heuristics;
#endif
}
} // namespace common
} // namespace tensorrt_llm
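Note: the tactic helpers above support a query-once, reuse-many pattern: getTactics returns cuBLASLt heuristics for the currently created descriptors, and the best result can be fed back to the optional-heuristic Gemm overload. A minimal sketch, assuming wrapper is already configured (types set via setFP16GemmConfig or similar) and A/B/C are valid device buffers:

// Sketch only: query heuristics once for a fixed shape, then reuse the best one.
void gemmWithCachedTactic(tensorrt_llm::common::CublasMMWrapper& wrapper, void const* A, void const* B, void* C,
    int m, int n, int k, int lda, int ldb, int ldc)
{
    wrapper.createDescriptors(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);
    auto const tactics = wrapper.getTactics(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, lda, ldb, ldc);
    std::optional<cublasLtMatmulHeuristicResult_t> best;
    if (!tactics.empty())
    {
        best = tactics.front();
    }
    wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, A, lda, B, ldb, C, ldc, best);
    wrapper.destroyDescriptors();
}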

sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h Normal file

@@ -0,0 +1,148 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <optional>
#include <string>
namespace tensorrt_llm
{
namespace common
{
class CublasMMWrapper
{
protected:
std::shared_ptr<cublasHandle_t> mCublasHandle;
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
cudaDataType_t mAType{};
cudaDataType_t mBType{};
cudaDataType_t mCType{};
cublasComputeType_t mComputeType{};
cudaDataType_t mScaleType{};
cublasLtMatmulDesc_t mOperationDesc{NULL};
cublasLtMatrixLayout_t mADesc{NULL};
cublasLtMatrixLayout_t mBDesc{NULL};
cublasLtMatrixLayout_t mCDesc{NULL};
cudaStream_t mStream;
void* mCublasWorkspace = nullptr;
private:
bool descriptorsCreated() const
{
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
}
public:
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
cudaStream_t stream, void* workspace);
~CublasMMWrapper();
CublasMMWrapper(CublasMMWrapper const& wrapper);
/********************** GEMMs **********************/
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
float const f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
/********************** Tactic selection helpers **********************/
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
/********************** Utils **********************/
void setWorkspace(void* workspace);
void setFP32GemmConfig();
void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#ifdef ENABLE_BF16
void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
#endif
#ifdef ENABLE_FP8
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#endif
void setStream(cudaStream_t stream);
void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);
CublasDataType getCublasDataType(cudaDataType_t data_type);
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
void setScaleDescriptors(void* scale_a, void* scale_b);
void destroyDescriptors();
cublasHandle_t getCublasHandle()
{
return *(this->mCublasHandle);
}
cublasLtHandle_t getCublasLtHandle() const
{
return *(this->mCublasLtHandle);
}
};
} // namespace common
} // namespace tensorrt_llm
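Note: construction takes caller-owned handles and an optional workspace. A minimal setup sketch (the workspace size and the omitted error handling are simplifying assumptions, not requirements of this header):

// Sketch: create shared cuBLAS/cuBLASLt handles plus a workspace, then build the wrapper.
auto cublas = std::make_shared<cublasHandle_t>();
auto cublaslt = std::make_shared<cublasLtHandle_t>();
cublasCreate(cublas.get());
cublasLtCreate(cublaslt.get());
void* workspace = nullptr;
cudaMalloc(&workspace, 32 * 1024 * 1024); // 32 MiB is a common cuBLASLt workspace size; the exact size is a choice
tensorrt_llm::common::CublasMMWrapper wrapper(cublas, cublaslt, /*stream=*/nullptr, workspace);
wrapper.setFP16GemmConfig(); // FP16 A/B, FP16 output, FP32 accumulation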

sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h Normal file

@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// We don't want to include cublas_api.h here. It defines the CUBLAS_VER_* macros,
// but that alone is not sufficient to determine whether cublas.h, cublas_v2.h, or
// cublasLt.h should be included.
#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) ((MAJOR) * 10000 + (MINOR) * 100 + (PATCH))
#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
<= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
< TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
>= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
> TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
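Note: these macros compare the cuBLAS headers' version triple against a threshold at preprocessing time, e.g.:

// Sketch: compile a cuBLAS 12+ code path only when the linked headers are new enough.
#if TLLM_CUBLAS_VER_GE(12, 0, 0)
// use cuBLAS 12 APIs here
#else
// fall back to cuBLAS 11 behavior
#endif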

sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh Normal file

@@ -0,0 +1,313 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t;
t.x = x;
t.y = y;
return t;
}
#endif
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace common
} // namespace tensorrt_llm
// Operator definitions intentionally in global namespace
namespace
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hmul2(x, y);
}
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hadd2(x, y);
}
#endif
#endif
} // namespace
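Note: each helper above compiles to native bf16 instructions on sm_80+ and to float round-trips on older architectures, so kernels can keep a single code path. A minimal kernel sketch (hypothetical, not part of this commit):

#ifdef ENABLE_BF16
// Sketch: y[i] = a * x[i] + y[i] on packed bf16 pairs, portable across architectures.
__global__ void bf16Axpy(__nv_bfloat162 const* x, __nv_bfloat162* y, __nv_bfloat162 a2, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        y[i] = tensorrt_llm::common::bf16hfma2(a2, x[i], y[i]); // arch-dispatching fused multiply-add
    }
}
#endif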

sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp Normal file

@@ -0,0 +1,187 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
static std::mutex mutex;
static std::weak_ptr<CUDADriverWrapper> instance;
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
if (result)
{
return result;
}
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
return ret;
};
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
{
dllClose(handle);
}
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
return (*_cuModuleUnload)(hmod);
}
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
return (*_cuLinkDestroy)(state);
}
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
return (*_cuModuleLoadData)(module, image);
}
CUresult CUDADriverWrapper::cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
return (*_cuLaunchCooperativeKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
return (*_cuLaunchKernel)(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace tensorrt_llm::common

sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h Normal file

@@ -0,0 +1,138 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper& operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper& operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra) const;
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
CUstream hStream, void** kernelParams, void** extra);
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (result)
{
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
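Note: combined with TLLM_CU_CHECK, the wrapper gives checked access to the lazily resolved driver entry points. A minimal sketch of loading a kernel from an in-memory cubin (the image bytes and kernel name are assumptions):

// Sketch: load a caller-supplied cubin image and resolve a kernel by name.
CUfunction loadKernel(void const* image)
{
    auto const driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module;
    CUfunction kernel;
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, image));
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&kernel, module, "my_kernel")); // kernel name is hypothetical
    return kernel;
}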

sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu Normal file

@@ -0,0 +1,436 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include <algorithm>
#include <cstdio>
#include <cuda_fp16.h>
#include <limits>
#include <type_traits>
namespace tensorrt_llm
{
namespace common
{
#ifdef ENABLE_FP8
constexpr int CTA_SIZE = 256;
template <bool QUANTIZE>
__inline__ __device__ float scale(float a, float b)
{
return QUANTIZE ? a / b : a * b;
}
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
{
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
{
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
dim3 grid(1024);
dim3 block(CTA_SIZE);
if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TOKEN)
{
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
scaleMatrix<QuantizeMode::PER_TENSOR, false>
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
}
sync_check_cuda_error();
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
{
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
{
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
dst[tid] = (T_OUT) (static_cast<float>(tmp));
}
}
template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
{
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
sync_check_cuda_error();
}
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, half, float>(
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
__device__ float atomicMaxExtd(float* address, float val)
{
assert(val >= 0);
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
return __uint_as_float(old);
}
template <typename T>
inline __device__ T atomicMaxExtdV2(T* address, T val)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
static_assert(std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
// The address in 64 bits.
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
// Pack the input value into 32 bits.
union
{
T v[2];
uint16_t u[2];
} old, tmp = {};
int const loc = (address_u64 & 0x2) >> 1;
tmp.v[loc] = val;
// 4B aligned pointer.
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
if constexpr (std::is_same_v<T, half>)
{
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
: "=h"(old.u[0]), "=h"(old.u[1])
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
}
// Return the correct half.
return old.v[loc];
#else
// This path should be unreachable: callers dispatch here only for __CUDA_ARCH__ >= 900.
assert(false);
return val;
#endif
}
__device__ half atomicMaxExtd(half* address, half val)
{
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_half(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
}
return __ushort_as_half(old);
}
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
unsigned short int old = *address_as_u, assumed;
while (val > __ushort_as_bfloat16(old))
{
assumed = old;
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
}
return __ushort_as_bfloat16(old);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
{
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
{
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
{
float max = 0.f;
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (std::is_same_v<T_S, float>)
{
atomicMaxExtd(quant_ptr + col, scale);
}
else
{
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
atomicMaxExtd(quant_ptr + col, scale);
else
atomicMaxExtdV2(quant_ptr + col, scale);
}
#else // Vector atomics require __CUDA_ARCH__ >= 900
atomicMaxExtd(quant_ptr + col, scale);
#endif
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{
auto const nrows = size / n;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
{
auto val = fabs(static_cast<float>(weights[row * n + i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
quant_ptr[row] = scale;
}
}
}
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
{
float max = 0.f;
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
{
auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val;
}
max = blockReduceMax<float>(max);
if (threadIdx.x == 0)
{
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
atomicMaxExtd(quant_ptr, scale);
}
}
}
template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 block(CTA_SIZE);
dim3 grid(numel / lda);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
}
sync_check_cuda_error();
}
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
#ifdef ENABLE_BF16
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
#endif
template <typename T_OUT, typename T_S, typename T_IN>
__global__ void dynamicQuantizeMatrixPerToken(
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
{
extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
auto const nrows = numel / lda;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{
float max = 0.f;
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
auto const in = input[row * lda + i];
shmem[i] = in;
auto val = fabs(static_cast<float>(in));
max = max > val ? max : val;
}
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{
// true means we are quantizing
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
}
if (threadIdx.x == 0)
{
quant_ptr[row] = s;
}
}
}
template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
{
if (quantize_mode == QuantizeMode::PER_TOKEN)
{
dim3 grid(numel / lda);
bool use_shmem = true;
auto const shmem_size = lda * sizeof(T_IN);
if (shmem_size >= (48 << 10))
{
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
use_shmem = ret == cudaSuccess;
}
if (use_shmem)
{
// ensure the threadblock is as large as possible to increase occupancy
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
}
else
{
dim3 block(CTA_SIZE);
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
}
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
{
dim3 block(CTA_SIZE);
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
else if (quantize_mode == QuantizeMode::PER_TENSOR)
{
dim3 block(1024);
dim3 grid(1024);
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
sync_check_cuda_error();
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
sync_check_cuda_error();
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
}
sync_check_cuda_error();
}
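// Usage sketch (illustrative; names are hypothetical): one-call per-token
// quantization of activations, producing FP8 values plus one half scale per row.
//
//   __nv_fp8_e4m3* d_q;   // [rows * cols], device
//   half* d_scales;       // [rows], device
//   half const* d_act;    // [rows * cols], device
//   invokeComputeScalesAndQuantizeMatrix(d_q, d_scales, d_act,
//       int64_t(rows) * cols, /*lda=*/cols, QuantizeMode::PER_TOKEN, stream);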
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream); \
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
cudaStream_t stream);
#ifdef ENABLE_FP8
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,84 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdint>
#include <optional>
namespace
{
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
std::string const& envVarName)
{
auto envVarVal = std::getenv(envVarName.c_str());
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
std::unordered_set<int32_t> startSet;
std::unordered_set<int32_t> endSet;
for (std::string const& value : values)
{
size_t dashIdx = value.find("-");
if (dashIdx != std::string::npos)
{
int32_t start = std::stoi(value.substr(0, dashIdx));
startSet.insert(start);
int32_t end = std::stoi(value.substr(dashIdx + 1));
endSet.insert(end);
}
else
{
int32_t start_end = std::stoi(value);
startSet.insert(start_end);
endSet.insert(start_end);
}
}
return std::make_pair(startSet, endSet);
}
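// Example (matches the parsing above): an env var value of "10,30-50" yields a
// start set of {10, 30} and a stop set of {10, 50}; a bare number starts and
// stops profiling on that same iteration, while "a-b" spans a range.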
} // namespace
namespace tensorrt_llm::common
{
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
{
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
// If empty, try to use legacy env var name
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
{
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
{
TLLM_LOG_WARNING(
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
"Please use %s instead.",
legacyEnvVarName.value().c_str(), envVarName.c_str());
}
}
return std::make_pair(profileIterIdxs, stopIterIdxs);
}
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,752 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <assert.h>
#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>
#endif
namespace tensorrt_llm
{
namespace common
{
template <typename T>
inline __device__ T ldg(T const* val)
{
return __ldg(val);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter
{
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2>
{
using Type = half;
};
template <>
struct TypeConverter<half>
{
using Type = half2;
};
#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162>
{
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16>
{
using Type = __nv_bfloat162;
};
#endif // ENABLE_BF16
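// Example: TypeConverter<half>::Type is half2 and TypeConverter<half2>::Type is
// half (likewise for bfloat16), which lets kernels hop between scalar and
// packed forms generically:
//   using T2 = typename TypeConverter<T>::Type; // half -> half2, half2 -> half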
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
template <typename T>
inline __device__ T hadd2(T a, T b)
{
return __hadd2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T add(T a, T b)
{
return a + b;
}
template <>
inline __device__ half2 add(half2 a, half2 b)
{
return __hadd2(a, b);
}
template <>
inline __device__ half add(half a, half b)
{
return __hadd(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return bf16hadd(a, b);
}
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
{
return bf16hadd(a, __float2bfloat16(b));
}
#endif // ENABLE_BF16
// applies to addition of three values
template <typename T>
inline __device__ T add(T a, T b, T c)
{
return a + b + c;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hadd(a, b, c);
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hadd2(a, b, c);
}
#endif // ENABLE_BF16
// applies to addition of all four values
template <typename T>
inline __device__ T add(T a, T b, T c, T d)
{
return (T) ((float) a + (float) b + (float) c + (float) d);
}
#if ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
return bf16hadd(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hsub2(T a, T b)
{
return __hsub2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hsub2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b)
{
return __hmul2(a, b);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hmul2(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T mul(T a, T b, T c)
{
return a * b * c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hmul(a, b, c);
}
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hmul2(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c, T d)
{
return a * b * c + d;
}
#if ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
return bf16hfma2(a, b, c, d);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T fma(T a, T b, T c)
{
return a * b + c;
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
template <>
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
return bf16hfma(a, b, c);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ T hexp2(T a)
{
return h2exp(a);
}
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
{
return bf16exp2(a);
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return bf162bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
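// Minimal usage sketch of the conversions above: cuda_cast is the single
// device-side entry point, so e.g. a packed half2 widens to float2 via the
// half2 specialization:
//   half2 h = __floats2half2_rn(1.f, 2.f);
//   float2 f = cuda_cast<float2>(h);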
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
#endif
// Binary maximum: compute the max of two values.
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_max(float2 val1, float2 val2)
{
float2 out;
out.x = fmaxf(val1.x, val2.x);
out.y = fmaxf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_max(half2 val1, half2 val2)
{
return __hmax2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmax2(val1, val2);
}
#endif // ENABLE_BF16
// Binary minimum: compute the min of two values.
template <typename T>
__device__ inline T cuda_min(T val1, T val2)
{
return (val1 < val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_min(float2 val1, float2 val2)
{
float2 out;
out.x = fminf(val1.x, val2.x);
out.y = fminf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_min(half2 val1, half2 val2)
{
return __hmin2(val1, val2);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2)
{
return __hmin2(val1, val2);
}
#endif // ENABLE_BF16
// Helper function of clamping the val into the given range.
template <typename T>
inline __device__ T cuda_clamp(T val, T minVal, T maxVal)
{
return cuda_min(cuda_max(val, minVal), maxVal);
}
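// Example: cuda_clamp(v, lo, hi) saturates v into [lo, hi]; because it is
// built on the element-wise cuda_min/cuda_max overloads above, the same call
// works for scalars as well as packed half2/__nv_bfloat162 values.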
#ifdef ENABLE_FP8
template <>
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
}
template <>
__device__ inline half2 cuda_cast<half2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_half2(&val);
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
{
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val)
{
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
{
return __nv_fp8_e4m3(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
return (float) val;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
{
return fp8x2_e4m3_to_bfloat2(&val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
{
// not implemented: fp8 -> int8 conversion is unsupported here and returns 0
return 0;
}
template <>
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
{
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
}
#endif // ENABLE_FP8
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,36 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
namespace tensorrt_llm::utils::customAllReduceUtils
{
constexpr size_t NUM_POINTERS_PER_RANK = 7;
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
if (worldSize <= 2)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
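// Example: getMaxRequiredWorkspaceSize(8) == 8 * 1000 * 1000 bytes, while a
// world size of 1 or 2 is allowed the larger 16 * 1000 * 1000-byte workspace.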
} // namespace tensorrt_llm::utils::customAllReduceUtils

View File

@@ -0,0 +1,214 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "envUtils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>
namespace tensorrt_llm::common
{
std::optional<int32_t> getIntEnv(char const* name)
{
char const* const env = std::getenv(name);
if (env == nullptr)
{
return std::nullopt;
}
int32_t const val = std::stoi(env);
if (val <= 0)
{
return std::nullopt;
}
return {val};
};
// Returns true if the env variable exists and is set to "1"
static bool getBoolEnv(char const* name)
{
char const* env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
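// Note: getBoolEnv("X") is true only for the exact string "1"; "0", "true",
// "ON", and an unset variable all read as false.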
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
return forceXQA;
}
std::optional<bool> getEnvEnableXQAJIT()
{
static bool init = false;
static bool exists = false;
static bool enableXQAJIT = false;
if (!init)
{
init = true;
char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
if (enable_xqa_jit_var)
{
exists = true;
if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
{
enableXQAJIT = true;
}
}
}
if (exists)
{
return enableXQAJIT;
}
else
{
return std::nullopt;
}
}
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug()
{
static bool init = false;
static bool forceMmhaMaxSeqLenTile = false;
if (!init)
{
init = true;
char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var)
{
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
{
forceMmhaMaxSeqLenTile = true;
}
}
}
return forceMmhaMaxSeqLenTile;
}
int getEnvMmhaBlocksPerSequence()
{
static bool init = false;
static int mmhaBlocksPerSequence = 0;
if (!init)
{
init = true;
char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv)
{
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
if (mmhaBlocksPerSequence <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
}
}
}
return mmhaBlocksPerSequence;
}
int getEnvMmhaKernelBlockSize()
{
static bool init = false;
static int mmhaKernelBlockSize = 0;
if (!init)
{
init = true;
char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
if (mmhaKernelBlockSizeEnv)
{
mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
if (mmhaKernelBlockSize <= 0)
{
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
}
}
}
return mmhaKernelBlockSize;
}
bool getEnvEnablePDL()
{
static bool init = false;
static bool enablePDL = false;
if (!init)
{
init = true;
// PDL only available when arch >= 90
if (getSMVersion() >= 90)
{
// PDL is enabled by setting the env variable TRTLLM_ENABLE_PDL to 1
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
}
}
return enablePDL;
}
bool getEnvUseUCXKvCache()
{
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
return useUCXKVCache;
}
std::string getEnvUCXInterface()
{
static bool init = false;
static std::string ucxInterface;
if (!init)
{
init = true;
{
char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE");
if (ucx_interface)
{
ucxInterface = ucx_interface;
}
}
}
return ucxInterface;
}
bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
return disaggLayerwise;
}
bool getEnvParallelCacheSend()
{
static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
return parallelCacheSend;
}
bool getEnvRequestKVCacheSerial()
{
static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL");
return requestKVCacheSerial;
}
bool getEnvDisableKVCacheTransferOverlap()
{
static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP");
return disableKVCacheTransferOverlap;
}
bool getEnvDisableReceiveKVCacheParallel()
{
static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL");
return disableReceiveParallel;
}
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,60 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm::common
{
// Useful when you want to inject debug code controlled by an env var.
std::optional<int32_t> getIntEnv(char const* name);
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Whether XQA JIT is enabled.
//
// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned.
std::optional<bool> getEnvEnableXQAJIT();
// Tune the number of blocks per sequence for accuracy/performance purposes.
bool getEnvMmhaMultiblockDebug();
int getEnvMmhaBlocksPerSequence();
int getEnvMmhaKernelBlockSize();
// Whether PDL is enabled.
bool getEnvEnablePDL();
bool getEnvUseUCXKvCache();
std::string getEnvUCXInterface();
bool getEnvDisaggLayerwise();
bool getEnvParallelCacheSend();
bool getEnvRequestKVCacheSerial();
bool getEnvDisableKVCacheTransferOverlap();
bool getEnvDisableReceiveKVCacheParallel();
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,70 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <cuda_runtime.h>
namespace tensorrt_llm::common
{
Logger::Logger()
{
char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY");
bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON");
auto const* levelName = std::getenv("TLLM_LOG_LEVEL");
if (levelName != nullptr)
{
auto level = [levelName = std::string(levelName)]()
{
if (levelName == "TRACE")
return TRACE;
if (levelName == "DEBUG")
return DEBUG;
if (levelName == "INFO")
return INFO;
if (levelName == "WARNING")
return WARNING;
if (levelName == "ERROR")
return ERROR;
TLLM_THROW("Invalid log level: %s", levelName.c_str());
}();
// If TLLM_LOG_FIRST_RANK_ONLY=ON, set the log level of every rank other than rank 0 to ERROR
if (isFirstRankOnly)
{
auto const deviceId = getDevice();
if (deviceId != 0)
{
level = ERROR;
}
}
setLevel(level);
}
}
void Logger::log(std::exception const& ex, Logger::Level level)
{
log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what());
}
Logger* Logger::getLogger()
{
thread_local Logger instance;
return &instance;
}
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace common
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
return (m + n - 1) / n;
}
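// Example: divUp(10, 4) == 3; a typical use is sizing a launch grid so every
// element is covered:
//   dim3 grid(divUp(n, block.x));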
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,906 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include <curand_kernel.h>
#include <fstream>
#include <sys/stat.h>
#include <unordered_map>
#include <vector>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize)
{
check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size));
if (is_random_initialize)
{
cudaRandomUniform(*ptr, size);
}
}
template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize);
template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize);
#ifdef ENABLE_BF16
template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize);
#endif
template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize);
template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize);
template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize);
template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize);
template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize);
#ifdef ENABLE_FP8
template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize);
#endif
template <typename T>
void deviceMemSetZero(T* ptr, size_t size)
{
check_cuda_error(cudaMemset(static_cast<void*>(ptr), 0, sizeof(T) * size));
}
template void deviceMemSetZero(float* ptr, size_t size);
template void deviceMemSetZero(half* ptr, size_t size);
template void deviceMemSetZero(int* ptr, size_t size);
template void deviceMemSetZero(uint32_t* ptr, size_t size);
template void deviceMemSetZero(bool* ptr, size_t size);
#ifdef ENABLE_FP8
template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size);
#endif
#ifdef ENABLE_BF16
template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size);
#endif
template <typename T>
void deviceFree(T*& ptr)
{
if (ptr != NULL)
{
check_cuda_error(cudaFree(ptr));
ptr = NULL;
}
}
template void deviceFree(float*& ptr);
template void deviceFree(half*& ptr);
#ifdef ENABLE_BF16
template void deviceFree(__nv_bfloat16*& ptr);
#endif
template void deviceFree(unsigned short*& ptr);
template void deviceFree(int*& ptr);
template void deviceFree(bool*& ptr);
template void deviceFree(char*& ptr);
template void deviceFree(int8_t*& ptr);
#ifdef ENABLE_FP8
template void deviceFree(__nv_fp8_e4m3*& ptr);
#endif
template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream)
{
T* arr = new T[size];
std::fill(arr, arr + size, value);
check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream));
delete[] arr;
}
template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream);
template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream);
#ifdef ENABLE_BF16
template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream);
#endif
template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream);
template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream);
template <typename T>
void cudaD2Hcpy(T* tgt, T const* src, const size_t size)
{
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost));
}
template void cudaD2Hcpy(float* tgt, float const* src, size_t size);
template void cudaD2Hcpy(half* tgt, half const* src, size_t size);
#ifdef ENABLE_BF16
template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
#endif
template void cudaD2Hcpy(int* tgt, int const* src, size_t size);
template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size);
#ifdef ENABLE_FP8
template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
#endif
template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size);
template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size);
template <typename T>
void cudaH2Dcpy(T* tgt, T const* src, const size_t size)
{
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice));
}
template void cudaH2Dcpy(float* tgt, float const* src, size_t size);
template void cudaH2Dcpy(half* tgt, half const* src, size_t size);
#ifdef ENABLE_BF16
template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
#endif
template void cudaH2Dcpy(int* tgt, int const* src, size_t size);
template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size);
#ifdef ENABLE_FP8
template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
#endif
template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size);
template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size);
template <typename T>
void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
{
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream));
}
template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
#endif
template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_FP8
template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream);
#endif
template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
template <typename T_OUT, typename T_IN>
__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (T_OUT) ((float) (src[tid]));
}
}
template <typename T_OUT, typename T_IN>
void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream)
{
cudaCast<<<256, 256, 0, stream>>>(dst, src, size);
}
template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(half* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(
__nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(
__nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream);
#endif
template <typename T>
void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
{
if (stream != NULL)
{
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream));
}
else
{
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault));
}
}
template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
#endif
template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream);
#endif
template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(
unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream);
template <typename T>
__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset)
{
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state;
curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state);
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
{
buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f);
}
}
template <>
__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, int const seq_offset)
{
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state;
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
{
buffer[index] = curand(&local_state);
}
}
template <>
__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, int const seq_offset)
{
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state;
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
{
buffer[index] = (curand(&local_state) % 2 == 0);
}
}
template <>
__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, int const seq_offset)
{
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state;
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
{
buffer[index] = curand(&local_state) % 0xFF;
}
}
template <typename T>
void cudaRandomUniform(T* buffer, const size_t size)
{
static int seq_offset = 0;
cuda_random_uniform_kernel<T><<<256, 256>>>(buffer, size, seq_offset);
seq_offset += 256 * 256;
}
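// Note: the fixed seed (1337) combined with the advancing seq_offset makes
// successive calls produce different but run-to-run reproducible data, which
// is convenient for the randomly initialized test buffers deviceMalloc creates.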
template void cudaRandomUniform(float* buffer, const size_t size);
template void cudaRandomUniform(half* buffer, const size_t size);
#ifdef ENABLE_BF16
template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size);
#endif
template void cudaRandomUniform(int* buffer, const size_t size);
template void cudaRandomUniform(bool* buffer, const size_t size);
template void cudaRandomUniform(char* buffer, const size_t size);
#ifdef ENABLE_FP8
template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size);
#endif
// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or
// the product of the elements in shape is 0, this function will return an empty vector.
template <typename T>
std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename)
{
if (shape.size() < 1 || shape.size() > 2)
{
printf("[ERROR] shape should have one or two dims \n");
return std::vector<T>();
}
size_t dim0 = shape[0], dim1 = 1;
if (shape.size() == 2)
{
dim1 = shape[1];
}
size_t size = dim0 * dim1;
if (size == 0)
{
TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
return std::vector<T>();
}
std::vector<T> host_array(size);
std::ifstream in(filename, std::ios::in | std::ios::binary);
if (!in.is_open())
{
TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
return std::vector<T>();
}
size_t loaded_data_size = sizeof(T) * size;
in.seekg(0, in.end);
in.seekg(0, in.beg);
TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
in.read((char*) host_array.data(), loaded_data_size);
size_t in_get_size = in.gcount();
if (in_get_size != loaded_data_size)
{
TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(),
in_get_size, loaded_data_size);
return std::vector<T>();
}
in.close();
// If we succeed, return an array with values.
return host_array;
}
template <typename T, typename T_IN>
int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
{
std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename);
if (host_array.empty())
{
return 0;
}
if (std::is_same<T, T_IN>::value == true)
{
cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size());
}
else
{
T_IN* ptr_2 = nullptr;
deviceMalloc(&ptr_2, host_array.size(), false);
cudaH2Dcpy(ptr_2, host_array.data(), host_array.size());
invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size());
deviceFree(ptr_2);
}
return 0;
}
template int loadWeightFromBinFunc<float, float>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, float>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<float, half>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, half>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr, std::vector<size_t> shape, std::string filename);
#ifdef ENABLE_BF16
template int loadWeightFromBinFunc<__nv_bfloat16, float>(
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<__nv_bfloat16, half>(
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
#endif // ENABLE_BF16
template int loadWeightFromBinFunc<int, int>(int* ptr, std::vector<size_t> shape, std::string filename);
#ifdef ENABLE_FP8
template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(
__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename);
#endif // ENABLE_FP8
template <typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type)
{
switch (model_file_type)
{
case TRTLLMCudaDataType::FP32: loadWeightFromBinFunc<T, float>(ptr, shape, filename); break;
case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc<T, half>(ptr, shape, filename); break;
case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename); break;
#ifdef ENABLE_BF16
case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename); break;
#endif
#ifdef ENABLE_FP8
case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc<T, float>(ptr, shape, filename); break;
#endif
default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false);
}
return 0;
}
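// Usage sketch (illustrative; names are hypothetical): load a [rows x cols]
// fp32 checkpoint file into an fp16 device buffer, converting on the device:
//
//   half* d_w = nullptr;
//   deviceMalloc(&d_w, rows * cols, /*is_random_initialize=*/false);
//   loadWeightFromBin(d_w, {rows, cols}, "w.bin", TRTLLMCudaDataType::FP32);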
template <>
int loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type)
{
loadWeightFromBinFunc<int, int>(ptr, shape, filename);
return 0;
}
template int loadWeightFromBin(
float* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
template int loadWeightFromBin(
half* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
template int loadWeightFromBin(
int8_t* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
#ifdef ENABLE_BF16
template int loadWeightFromBin(
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
#endif
#ifdef ENABLE_FP8
template int loadWeightFromBin(
__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
#endif
template int loadWeightFromBin(
int* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
template <typename T_IN, typename T_OUT>
__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = cuda_cast<T_OUT>(src[tid]);
}
}
template <typename T_IN, typename T_OUT>
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream)
{
cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size);
}
template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
#endif // ENABLE_BF16
template <typename T_IN, typename T_OUT>
__global__ void cudaD2DScaleCpyConvert(
T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size)
{
float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value);
}
}
template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream)
{
cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size);
}
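// Note: with invert_scale == true the kernel multiplies by 1.0f / scale[0], so
// the same helper covers both directions of a scaled round trip, e.g.
// int32 -> float dequantization and float -> int32 quantization with one
// device-resident scale.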
// clang-format off
template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
#endif // ENABLE_FP8
// clang-format on
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream)
{
invokeCudaD2DcpyConvert(dst, src, size, stream);
}
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream)
{
invokeCudaD2DcpyConvert(dst, src, size, stream);
}
template <typename T>
void saveToBinary(T const* ptr, const size_t size, std::string filename)
{
std::vector<T> h_ptr(size);
cudaD2Hcpy(h_ptr.data(), ptr, size);
std::vector<float> float_ptr(size);
for (size_t i = 0; i < size; i++)
{
float_ptr[i] = (float) h_ptr[i];
}
std::ofstream out(filename, std::ios::out | std::ios::binary);
TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename);
out.write((char*) float_ptr.data(), size * sizeof(float));
}
template void saveToBinary(float const* ptr, const size_t size, std::string filename);
template void saveToBinary(half const* ptr, const size_t size, std::string filename);
#ifdef ENABLE_BF16
template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename);
#endif // ENABLE_BF16
template <>
void saveToBinary(int const* ptr, const size_t size, std::string filename)
{
std::vector<int> h_ptr(size);
cudaD2Hcpy(h_ptr.data(), ptr, size);
std::ofstream out(filename, std::ios::out | std::ios::binary);
TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename);
out.write((char*) h_ptr.data(), size * sizeof(int));
}
template <typename T_IN, typename T_fake_type>
__global__ void fakeCast(T_IN* input_ptr, const size_t size)
{
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
{
T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]);
input_ptr[i] = (T_IN) ((float) tmp_val);
}
}
template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream)
{
dim3 block(256);
dim3 grid((size + 255) / 256);
fakeCast<T_IN, T_fake_type><<<grid, block, 0, stream>>>(input_ptr, size);
}
#ifdef ENABLE_FP8
__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (float) (src[tid]);
}
}
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
{
cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size);
}
__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (half) ((float) (src[tid]));
}
}
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
{
cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size);
}
__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (__nv_fp8_e4m3) src[tid];
}
}
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream)
{
cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
}
__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (__nv_fp8_e4m3) src[tid];
}
}
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream)
{
cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size);
}
__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{
dst[tid] = (__nv_fp8_e4m3) src[tid];
}
}
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream)
{
cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
}
#endif // ENABLE_FP8
template <typename T_OUT, typename T_IN>
__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x)
{
const size_t src_col_id = tid % dim1;
const size_t src_row_id = tid / dim1;
dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]);
}
}
template <typename T>
void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1)
{
// copy data to workspace, and then transpose from workspace to data
cudaD2Dcpy(workspace, data, dim0 * dim1);
transpose<<<256, 256>>>(data, workspace, dim0, dim1);
}
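// Usage sketch: workspace must hold dim0 * dim1 elements; after the call, data
// holds the [dim1 x dim0] transpose of the original [dim0 x dim1] matrix:
//   invokeInPlaceTranspose(d_data, d_scratch, rows, cols);
// Note the kernel launches on the default stream with a fixed <<<256, 256>>>
// configuration.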
#ifdef ENABLE_FP8
template void invokeInPlaceTranspose(
__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
template void invokeInPlaceTranspose(
__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1);
#endif // ENABLE_BF16
template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1);
template <typename T_OUT, typename T_IN>
__global__ void transpose0213(
T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3)
{
// src permutation: [0, 1, 2, 3]
// dst permutation: [0, 2, 1, 3]
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3;
tid += blockDim.x * gridDim.x)
{
size_t tmp_idx = tid;
const size_t dim_3_idx = tmp_idx % dim3;
tmp_idx = (tmp_idx - dim_3_idx) / dim3;
const size_t dim_2_idx = tmp_idx % dim2;
tmp_idx = (tmp_idx - dim_2_idx) / dim2;
const size_t dim_1_idx = tmp_idx % dim1;
tmp_idx = (tmp_idx - dim_1_idx) / dim1;
const size_t dim_0_idx = tmp_idx % dim0;
dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid];
}
}
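// Worked example (illustrative): with (dim0, dim1, dim2, dim3) = (2, 3, 4, 5), the
// source element [0][1][2][3] has tid = 1*4*5 + 2*5 + 3 = 33 and is written to dst
// offset 0*60 + 2*3*5 + 1*5 + 3 = 38, i.e. position [0][2][1][3] of the permuted tensor.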
template <typename T>
void invokeInPlaceTranspose0213(
T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3)
{
// copy data to workspace, and then transpose from workspace to data
    // Note that this kernel is used for pre-processing and is not very efficient.
cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3);
transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3);
}
#ifdef ENABLE_FP8
template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0,
const size_t dim1, const size_t dim2, const size_t dim3);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0,
const size_t dim1, const size_t dim2, const size_t dim3);
#endif // ENABLE_BF16
template void invokeInPlaceTranspose0213(
float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3);
template <typename T_OUT, typename T_IN>
__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2)
{
// src permutation: [0, 1, 2]
// dst permutation: [1, 0, 2]
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x)
{
size_t tmp_idx = tid;
const size_t dim_2_idx = tmp_idx % dim2;
tmp_idx = (tmp_idx - dim_2_idx) / dim2;
const size_t dim_1_idx = tmp_idx % dim1;
tmp_idx = (tmp_idx - dim_1_idx) / dim1;
const size_t dim_0_idx = tmp_idx % dim0;
dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid];
}
}
template <typename T>
void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2)
{
// copy data to workspace, and then transpose from workspace to data
    // Note that this kernel is used for pre-processing and is not very efficient.
cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2);
transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2);
}
#ifdef ENABLE_FP8
template void invokeInPlaceTranspose102(
__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
template void invokeInPlaceTranspose102(
__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
#endif // ENABLE_BF16
template void invokeInPlaceTranspose102(
float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
template <typename T>
void __global__ multiplyScale(T* tensor, float scale, const size_t size)
{
for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x)
{
tensor[index] = (T) (((float) tensor[index]) * scale);
}
}
template <typename T>
void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream)
{
int block = 256;
int grid = (size + 255) / 256;
multiplyScale<<<grid, block, 0, stream>>>(tensor, scale, size);
}
template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream);
template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream);
#endif
template <typename T>
void __global__ divideScale(T* tensor, float scale, const size_t size)
{
for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x)
{
tensor[index] = (T) (((float) tensor[index]) / scale);
}
}
template <typename T>
void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream)
{
int block = 256;
int grid = (size + 255) / 256;
divideScale<<<grid, block, 0, stream>>>(tensor, scale, size);
}
template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream);
template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream);
#endif
#ifdef ENABLE_BF16
template void invokeFakeCast<float, __nv_bfloat16>(float* input_ptr, const size_t size, cudaStream_t stream);
template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream);
template void invokeFakeCast<half, __nv_bfloat16>(half* input_ptr, const size_t size, cudaStream_t stream);
#endif
template void invokeFakeCast<float, half>(float* input_ptr, const size_t size, cudaStream_t stream);
template void invokeFakeCast<float, float>(float* input_ptr, const size_t size, cudaStream_t stream);
#ifdef ENABLE_FP8
template void invokeFakeCast<float, __nv_fp8_e4m3>(float* input_ptr, const size_t size, cudaStream_t stream);
template void invokeFakeCast<half, __nv_fp8_e4m3>(half* input_ptr, const size_t size, cudaStream_t stream);
template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>(
__nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream);
#endif
size_t cuda_datatype_size(TRTLLMCudaDataType dt)
{
static const std::unordered_map<TRTLLMCudaDataType, size_t> sizes{
{TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, sizeof(half)}
#ifdef ENABLE_BF16
,
{TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)}
#endif
};
return sizes.at(dt);
}
template <typename T>
__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range)
{
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
{
const T val = buffer[i];
if (val < min || val > max)
{
*d_within_range = false;
}
}
}
template <typename T>
bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
{
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
dim3 block(256);
dim3 grid((size + 255) / 256);
check_range<T><<<grid, block, 0, stream>>>(buffer, size, min, max, d_within_range);
bool result;
cudaD2Hcpy(&result, d_within_range, 1);
return result;
}
template bool invokeCheckRange<int>(
int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
/*
* Determine the total workspace size based on a vector containing multiple variable sizes.
*/
size_t calcAlignedSize(std::vector<size_t> const& sizes, const size_t ALIGN_BYTES)
{
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
size_t total = 0;
for (auto sz : sizes)
{
total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
}
// We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is
// not aligned.
return total + ALIGN_BYTES - 1;
}
/*
* Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses
* of each variable.
*/
void calcAlignedPointers(
std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES)
{
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
// In case the start address is not aligned
char* ptr = reinterpret_cast<char*>((reinterpret_cast<size_t>(p) + ALIGN_BYTES - 1) & ALIGN_MASK);
outPtrs.reserve(sizes.size());
for (auto sz : sizes)
{
outPtrs.push_back(ptr);
ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
}
}
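// Worked example (illustrative): with ALIGN_BYTES = 256 and sizes = {100, 300},
// calcAlignedSize returns 256 + 512 + 255 = 1023 bytes, and calcAlignedPointers
// yields the aligned base address for the first buffer and base + 256 for the second.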
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,292 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cassert>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true);
template <typename T>
void deviceMemSetZero(T* ptr, size_t size);
template <typename T>
void deviceFree(T*& ptr);
template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
template <typename T>
void cudaD2Hcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaH2Dcpy(T* tgt, T const* src, size_t const size);
template <typename T>
void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
template <typename T>
void cudaRandomUniform(T* buffer, size_t const size);
template <typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename,
TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
// T* scale_ptr,
// std::vector<size_t> shape,
// std::string filename,
// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream);
#ifdef ENABLE_FP8
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functions implement conversion of multi-dimensional indices to an index in a flat array.
// The tensor shape is passed as one array (`dims`); the indices are given as individual arguments.
// For examples on how to use these functions, see their tests `test_memory_utils.cu`.
// All of these functions can be evaluated at compile time by recursive template expansion.
template <typename TDim, typename T, typename TIndex>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index)
{
assert(index < dims[0]);
return acc * dims[0] + index;
}
template <typename TDim, typename T, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
T const& acc, TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(acc * dims[0] + index, dims + 1, indices...);
}
template <typename TDim, typename T>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
[[maybe_unused]] TDim dims, T const& index)
{
assert(index < dims[0]);
return index;
}
template <typename TDim, typename TIndex, typename... TIndices>
__inline__ __host__ __device__
std::enable_if_t<std::is_pointer<TDim>::value, typename std::remove_pointer<TDim>::type> constexpr flat_index(
TDim dims, TIndex const& index, TIndices... indices)
{
assert(index < dims[0]);
return flat_index(static_cast<typename std::remove_pointer<TDim>::type>(index), dims + 1, indices...);
}
template <unsigned skip = 0, typename T, std::size_t N, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(&dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, &dims[skip], index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(static_cast<T const*>(dims) + skip, index, indices...);
}
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices)
{
static_assert(skip < N);
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
return flat_index(acc, static_cast<T const*>(dims) + skip, index, indices...);
}
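// Compile-time usage sketch (illustrative, not part of the original header): for a
// tensor of shape [2][3][4], flat_index composes the row-major offset.
namespace flat_index_example
{
constexpr int kDims[3] = {2, 3, 4};
static_assert(flat_index(kDims, 1, 2, 3) == 1 * 3 * 4 + 2 * 4 + 3, "row-major offset");
} // namespace flat_index_example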
////////////////////////////////////////////////////////////////////////////////////////////////////
// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual
// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions
// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be
// evaluated at compile time.
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1)
{
assert(index_1 < dim_1);
return index_0 * dim_1 + index_1;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2)
{
assert(index_2 < dim_2);
return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3)
{
assert(index_3 < dim_3);
return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3,
T const& dim_4)
{
assert(index_4 < dim_4);
return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided3(
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2);
return index_0 * stride_1 + index_1 * stride_2 + index_2;
}
template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1,
TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3)
{
assert(index_1 < stride_1 / stride_2);
assert(index_2 < stride_2 / stride_3);
assert(index_3 < stride_3);
return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3;
}
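// Compile-time usage sketch (illustrative): flat_index3 computes the row-major offset
// (index_0 * dim_1 + index_1) * dim_2 + index_2.
static_assert(flat_index3(1, 2, 3, 4, 5) == (1 * 4 + 2) * 5 + 3, "row-major 3-D offset");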
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1);
template <typename T>
void invokeInPlaceTranspose0213(
T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3);
template <typename T>
void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2);
template <typename T>
void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T>
void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0);
template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0);
inline bool checkIfFileExist(std::string const& file_path)
{
std::ifstream in(file_path, std::ios::in | std::ios::binary);
if (in.is_open())
{
in.close();
return true;
}
return false;
}
template <typename T>
void saveToBinary(T const* ptr, size_t const size, std::string filename);
template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream);
size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T>
bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream);
constexpr size_t DEFAULT_ALIGN_BYTES = 256;
size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
void calcAlignedPointers(std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes,
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
struct AlignedPointersUnpacker
{
template <typename... T>
void operator()(T*&... outPtrs)
{
assert(sizeof...(T) == alignedPointers.size());
auto it = alignedPointers.begin();
((outPtrs = static_cast<T*>(*it++)), ...);
}
std::vector<void*> alignedPointers;
};
AlignedPointersUnpacker inline calcAlignedPointers(
void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES)
{
AlignedPointersUnpacker unpacker{};
calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES);
return unpacker;
}
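// Hedged usage sketch (illustrative; buffer names are not from the original header):
// carve one workspace allocation into typed, aligned sub-buffers.
inline void alignedPointersUsageExample(void* workspace, size_t numFloats, size_t numInts)
{
    // The workspace is assumed to hold at least calcAlignedSize(sizes) bytes.
    float* floatBuf = nullptr;
    int* intBuf = nullptr;
    calcAlignedPointers(workspace, {numFloats * sizeof(float), numInts * sizeof(int)})(floatBuf, intBuf);
}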
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,588 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <unordered_set>
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <csignal>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <type_traits>
#ifndef _WIN32
#include <unistd.h>
#endif
// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we're passing void ptr to some function. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);
namespace tensorrt_llm::mpi
{
MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
{MpiType::kBYTE, MPI_BYTE},
{MpiType::kHALF, MPI_UINT16_T},
{MpiType::kFLOAT, MPI_FLOAT},
{MpiType::kDOUBLE, MPI_DOUBLE},
{MpiType::kBOOL, MPI_C_BOOL},
{MpiType::kINT8, MPI_INT8_T},
{MpiType::kUINT8, MPI_UINT8_T},
{MpiType::kINT32, MPI_INT32_T},
{MpiType::kUINT32, MPI_UINT32_T},
{MpiType::kINT64, MPI_INT64_T},
{MpiType::kUINT64, MPI_UINT64_T},
{MpiType::kFP8, MPI_UINT8_T},
{MpiType::kBF16, MPI_UINT16_T},
{MpiType::kCHAR, MPI_CHAR},
};
return dtype_map.at(dtype);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
}
MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
static std::unordered_map<MpiOp, MPI_Op> const op_map{
{MpiOp::NULLOP, MPI_OP_NULL},
{MpiOp::MAX, MPI_MAX},
{MpiOp::MIN, MPI_MIN},
{MpiOp::SUM, MPI_SUM},
{MpiOp::PROD, MPI_PROD},
{MpiOp::LAND, MPI_LAND},
{MpiOp::BAND, MPI_BAND},
{MpiOp::LOR, MPI_LOR},
{MpiOp::BOR, MPI_BOR},
{MpiOp::LXOR, MPI_LXOR},
{MpiOp::BXOR, MPI_BXOR},
{MpiOp::MINLOC, MPI_MINLOC},
{MpiOp::MAXLOC, MPI_MAXLOC},
{MpiOp::REPLACE, MPI_REPLACE},
};
return op_map.at(op);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
bool mpiInitialized = false;
std::recursive_mutex mpiMutex;
MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
MPI_Comm localComm = nullptr;
MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
MpiComm localSession{localComm, false};
#else
MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
return localSession;
}
} // namespace
std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
MPI_Group group = nullptr;
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPICHECK(MPI_Comm_group(comm, &group));
int groupSize = 0;
MPICHECK(MPI_Group_size(group, &groupSize));
std::vector<int> ranks(groupSize);
std::vector<int> worldRanks(groupSize);
std::iota(ranks.begin(), ranks.end(), 0);
MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
MPICHECK(MPI_Group_free(&group));
MPICHECK(MPI_Group_free(&worldGroup));
#else
std::vector<int> worldRanks{0};
#endif
return worldRanks;
}
void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
// double-checked locking
if (mpiInitialized)
{
return;
}
std::lock_guard<std::recursive_mutex> lk(mpiMutex);
if (mpiInitialized)
{
return;
}
#if ENABLE_MULTI_DEVICE
int initialized = 0;
TLLM_MPI_CHECK(MPI_Initialized(&initialized));
if (!initialized)
{
TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
int providedMode = 0;
auto requiredMode = static_cast<int>(threadMode);
MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
std::atexit([]() { MPI_Finalize(); });
/*
         * We only catch SIGABRT and SIGSEGV because most, if not all, errors in the worker will cause one of these 2
* signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
* correctly.
*/
for (int sig : {SIGABRT, SIGSEGV})
{
__sighandler_t previousHandler = nullptr;
if (forwardAbortToParent)
{
previousHandler = std::signal(sig,
[](int signal)
{
#ifndef _WIN32
pid_t parentProcessId = getppid();
kill(parentProcessId, SIGKILL);
#endif
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
});
}
else
{
previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
}
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
}
// ensure local MPI communicator is initialized
MpiComm::localSession();
TLLM_LOG_INFO("Initialized MPI");
}
#endif // ENABLE_MULTI_DEVICE
mpiInitialized = true;
}
void MpiComm::barrier() const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Barrier(mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
#if ENABLE_MULTI_DEVICE
template <typename TMpiFunc, typename TBase, typename... TArgs,
typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
if (TLLM_LIKELY(size < maxP1))
{
MPICHECK(func(buffer, size, dtype, args...));
return 1;
}
constexpr size_t alignment = 256;
int elementSize = 1;
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);
    // Cap each call at slightly under INT_MAX elements so every chunk's byte count remains alignment-friendly.
auto const step = maxP1 - (alignment / elementSize);
using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
size_t count = 0;
while (size != 0)
{
auto currentStep = static_cast<int>(std::min(size, step));
MPICHECK(func(buffer, currentStep, dtype, args...));
size -= currentStep;
size_t diff = static_cast<size_t>(currentStep) * elementSize;
buffer = static_cast<TCast*>(buffer) + diff;
++count;
}
return count;
}
#endif // ENABLE_MULTI_DEVICE
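// Worked example (illustrative): broadcasting a 5 GiB byte buffer (elementSize = 1)
// uses step = 2^31 - 256 elements per call, so invokeChunked issues
// ceil(5 * 2^30 / (2^31 - 256)) = 3 MPI calls instead of overflowing the int count
// of a single MPI_Bcast.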
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return r;
}
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Isend with size %d", size);
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
#else
TLLM_THROW("Multi device support is disabled.");
#endif
TLLM_LOG_DEBUG("end MPI_Isend with size %d", size);
return r;
}
std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
{
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Send with size %d", size);
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Send with size %d", size);
}
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
TLLM_LOG_DEBUG("start MPI_Recv with size %d", size);
MPI_Status status{};
#if ENABLE_MULTI_DEVICE
invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
TLLM_LOG_DEBUG("end MPI_Recv with size %d", size);
return status;
}
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}
MpiComm MpiComm::split(int color, int key) const
{
MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
return MpiComm{splitComm, true};
}
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
getMpiDtype(recvtype), mComm));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
void MpiComm::recvPoll(int source, int tag, int periodMs) const
{
MPI_Status status;
while (!iprobe(source, tag, &status))
{
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
}
}
int MpiComm::getRank() const
{
int rank = 0;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_rank(mComm, &rank));
#endif
return rank;
}
int MpiComm::getSize() const
{
int world_size = 1;
#if ENABLE_MULTI_DEVICE
MPICHECK(MPI_Comm_size(mComm, &world_size));
#endif
return world_size;
}
MpiComm const& MpiComm::world()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commWorld{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commWorld;
}
MpiComm& MpiComm::mutableSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm commSession{MPI_COMM_WORLD, false};
initialize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return commSession;
}
MpiComm& MpiComm::mutableLocalSession()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
static MpiComm localSession = initLocalSession();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return localSession;
}
void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
static std::mutex mutex;
std::unique_lock lock(mutex);
auto initSessionRanks = getWorldRanks(MpiComm::session());
auto localSessionRanks = getWorldRanks(MpiComm::localSession());
// Add to intersectionRanks in order of initSessionRanks
std::vector<int> intersectionRanks;
std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
for (auto rank : initSessionRanks)
{
if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
{
intersectionRanks.push_back(rank);
}
}
MPI_Group worldGroup = nullptr;
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
MPI_Group localGroup = nullptr;
MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
MPI_Comm localComm = nullptr;
MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
MpiComm::mutableLocalSession().mFreeComm = true;
MpiComm::mutableLocalSession() = MpiComm{localComm, false};
TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
: mComm{g}
, mFreeComm{freeComm}
{
TLLM_CHECK(mComm != MPI_COMM_NULL);
}
MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
if (mFreeComm && mComm)
{
if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
{
TLLM_LOG_ERROR("MPI_Comm_free failed");
}
}
#endif // ENABLE_MULTI_DEVICE
}
MpiComm::MpiComm(MpiComm&& comm) noexcept
: mComm{comm.mComm}
, mFreeComm{comm.mFreeComm}
{
comm.mFreeComm = false;
}
MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
this->~MpiComm();
mComm = comm.mComm;
mFreeComm = comm.mFreeComm;
comm.mFreeComm = false;
return *this;
}
MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
: mName{name.c_str()}
, mFuncWait{funcWait}
, mFuncSetup{funcSetup}
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
MpiWaitThread::~MpiWaitThread()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
waitStop();
mShouldExit.store(true);
notifyStart();
mThread->join();
mThread.reset(nullptr);
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::sideThread()
{
if (mFuncSetup)
{
mFuncSetup();
}
while (!mShouldExit.load())
{
notifyStop();
waitStart();
mFuncWait();
}
}
void MpiWaitThread::waitStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::waitStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::unique_lock<std::mutex> lock(mMutex);
mCondVar.wait(lock, [this] { return !mRunning; });
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStart()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = true;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
void MpiWaitThread::notifyStop()
{
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
std::lock_guard<std::mutex> lock(mMutex);
mRunning = false;
mCondVar.notify_one();
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::mpi

View File

@@ -0,0 +1,46 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvtx3/nvtx3.hpp>
#include <array>
namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#ifndef NVTX_DISABLE
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();
static thread_local std::size_t colorId = 0;
auto const color = kColors[colorId];
colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
return color;
#else
return nvtx3::color{0};
#endif
}
} // namespace tensorrt_llm::common::nvtx
#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \
::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name)
#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range)
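// Usage sketch (illustrative, not part of the original header): each call below opens
// a scoped NVTX range named "demoRegion" with a rotating color, visible as one bar
// per invocation in an Nsight Systems timeline.
inline void nvtxUsageExample()
{
    NVTX3_SCOPED_RANGE(demoRegion);
    // ... the work to be measured runs inside the range's scope ...
}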

View File

@@ -0,0 +1,323 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "cuda.h"
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <functional>
#include <mutex>
#include <thread>
#ifdef _MSC_VER
#define FN_NAME __FUNCTION__
#else
#define FN_NAME __func__
#endif
#if ENABLE_MULTI_DEVICE
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32},
{nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}};
return &dtypeMap;
}
namespace
{
// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin())
{
NCCLCHECK(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
COMM_SESSION.sendValue(id, *it, 0);
}
}
else
{
COMM_SESSION.recvValue(id, *group.begin(), 0);
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return id;
}
} // namespace
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group)
{
if (index != 0)
{
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end())
{
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}
TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group)
{
if (rank == currentRank)
break;
++groupRank;
}
TLLM_CHECK(groupRank < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
[](ncclComm_t* comm)
{
ncclCommDestroy(*comm);
delete comm;
});
NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
void const* tensorrt_llm::common::getCommSessionHandle()
{
#if ENABLE_MULTI_DEVICE
return &COMM_SESSION;
#else
return nullptr;
#endif // ENABLE_MULTI_DEVICE
}
namespace
{
// Get the current CUDA context; a default context is created if there is none.
inline CUcontext getCurrentCudaCtx()
{
CUcontext ctx{};
CUresult err = cuCtxGetCurrent(&ctx);
if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
{
TLLM_CUDA_CHECK(cudaFree(nullptr));
err = cuCtxGetCurrent(&ctx);
}
TLLM_CHECK(err == CUDA_SUCCESS);
return ctx;
}
// Helper to create per-cuda-context singleton managed by std::shared_ptr.
// Unlike conventional singletons, singleton created with this will be released
// when not needed, instead of on process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
    // Having the creator return std::unique_ptr is by design: it forces the memory for T and
    // the control block to be allocated separately, so the memory for T can be released even
    // while an observer weak_ptr still sits in mObservers.
    // The creator itself must not own CUDA resources. Only the object it creates can.
PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
CUcontext ctx{getCurrentCudaCtx()};
std::shared_ptr<T> result = mObservers[ctx].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, ctx](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
                    // Clear the observer to avoid growth of mObservers in case users create/destroy
                    // CUDA contexts frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destruction to avoid deadlock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check the observer again because another thread may have created a new
                    // instance for this ctx just before we locked mMutex. We can't infer that the
                    // observer is stale just because obj was destroyed: shared_ptr ref-count
                    // checking and observer removal are not one atomic operation, and the observer
                    // may have been changed to observe another instance.
observedObjHolder = mObservers.at(ctx).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(ctx);
}
}};
mObservers.at(ctx) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-context.
std::unordered_map<CUcontext, std::weak_ptr<T>> mObservers;
};
template <typename T>
class PerThreadSingletonCreator
{
public:
using CreatorFunc = std::function<std::unique_ptr<T>()>;
using DeleterFunc = std::function<void(T*)>;
    // Having the creator return std::unique_ptr is by design: it forces the memory for T and
    // the control block to be allocated separately, so the memory for T can be released even
    // while an observer weak_ptr still sits in mObservers.
    // The creator itself must not own CUDA resources. Only the object it creates can.
PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
{
}
std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
std::thread::id thread = std::this_thread::get_id();
std::shared_ptr<T> result = mObservers[thread].lock();
if (result == nullptr)
{
// Create the resource and register with an observer.
result = std::shared_ptr<T>{mCreator().release(),
[this, thread](T* obj)
{
if (obj == nullptr)
{
return;
}
mDeleter(obj);
                    // Clear the observer to avoid growth of mObservers in case users create/destroy
                    // threads frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destruction to avoid deadlock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check the observer again because another thread may have created a new
                    // instance for this thread id just before we locked mMutex. We can't infer that
                    // the observer is stale just because obj was destroyed: shared_ptr ref-count
                    // checking and observer removal are not one atomic operation, and the observer
                    // may have been changed to observe another instance.
observedObjHolder = mObservers.at(thread).lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(thread);
}
}};
mObservers.at(thread) = result;
}
return result;
}
private:
CreatorFunc mCreator;
DeleterFunc mDeleter;
mutable std::mutex mMutex;
// CUDA resources are per-thread.
std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
};
} // namespace
std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
TLLM_CUDA_CHECK(cublasCreate(handle.get()));
return handle;
},
[](cublasHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
static PerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
return handle;
},
[](cublasLtHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
delete handle;
});
return creator();
}
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
{
static PerThreadSingletonCreator<tensorrt_llm::common::CublasMMWrapper> creator(
[cublasHandle, cublasltHandle, stream, workspace]() -> auto
{
auto wrapper = std::unique_ptr<tensorrt_llm::common::CublasMMWrapper>(
new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace));
return wrapper;
},
[](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; });
return creator();
}
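// Hedged usage sketch (stream and workspace are caller-provided; names illustrative):
//   auto wrapper = getCublasMMWrapper(getCublasHandle(), getCublasLtHandle(), stream, workspace);
// Note that the wrapper is a per-thread singleton, so the arguments captured by the
// first call are the ones used when a thread's wrapper is first created.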

View File

@@ -0,0 +1,215 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/workspace.h"
#include <NvInferRuntime.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <cstring>
#include <map>
#include <memory>
#include <nvml.h>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
namespace tensorrt_llm::common
{
// Write values into buffer
template <typename T>
void write(char*& buffer, T const& val)
{
std::memcpy(buffer, &val, sizeof(T));
buffer += sizeof(T);
}
// Read values from buffer
template <typename T>
void read(char const*& buffer, T& val)
{
std::memcpy(&val, buffer, sizeof(T));
buffer += sizeof(T);
}
// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
// your clone() implementation is responsible for initializing such data members.
// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
using std::unique_ptr<T, Del>::unique_ptr;
// for compatibility with std::make_unique
explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
: std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
{
}
// copy constructor produces nullptr
UniqPtrWNullCopy(UniqPtrWNullCopy const&)
: std::unique_ptr<T, Del>::unique_ptr{}
{
}
};
// for testing only
void const* getCommSessionHandle();
} // namespace tensorrt_llm::common
inline bool isBuilding()
{
auto constexpr key = "IS_BUILDING";
auto const val = getenv(key);
return val != nullptr && std::string(val) == "1";
}
#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) \
{ \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
#endif // ENABLE_MULTI_DEVICE
//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally.
//! Get cublas and cublasLt handle for current cuda context
std::shared_ptr<cublasHandle_t> getCublasHandle();
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace);
#ifndef DEBUG
#define PLUGIN_CHECK(status) \
do \
{ \
if (status != 0) \
abort(); \
} while (0)
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
return STATUS_BAD_PARAM; \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
return STATUS_FAILURE; \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
return err; \
} \
} while (0)
#define DEBUG_PRINTF(...) \
do \
{ \
} while (0)
#else
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_BAD_PARAM; \
} \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_FAILURE; \
} \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \
return err; \
} \
} while (0)
#define PLUGIN_CHECK(status) \
{ \
if (status != 0) \
{ \
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \
abort(); \
} \
}
#define DEBUG_PRINTF(...) \
do \
{ \
printf(__VA_ARGS__); \
} while (0)
#endif // DEBUG
#define NVML_CHECK(cmd) \
do \
{ \
nvmlReturn_t r = cmd; \
if (r != NVML_SUCCESS) \
{ \
printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)

View File

@@ -0,0 +1,55 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <float.h>
namespace tensorrt_llm
{
namespace common
{
template <typename T>
struct QuantTypeStaticVals;
template <>
struct QuantTypeStaticVals<int8_t>
{
static constexpr float MAX_VAL = 127.f;
static constexpr float MIN_SCALING_FACTOR = 0.f;
static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX;
};
#ifdef ENABLE_FP8
template <>
struct QuantTypeStaticVals<__nv_fp8_e4m3>
{
static constexpr float MAX_VAL = 448.f;
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720
static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f);
static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f);
};
#endif // ENABLE_FP8
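// Worked example (illustrative, assuming these constants clamp a per-tensor scale
// s = amax / MAX_VAL): for FP8 with amax = 1e-7, s ~= 2.2e-10 falls below
// MIN_SCALING_FACTOR = 1 / (448 * 512) ~= 4.36e-6, so s is clamped upward and 1 / s
// is capped at MIN_SCALING_FACTOR_RCP to keep the reciprocal finite.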
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,399 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <float.h>
#include <type_traits>
namespace cg = cooperative_groups;
namespace tensorrt_llm
{
namespace common
{
template <int VPT>
struct BytesToType;
template <>
struct BytesToType<1>
{
using type = uint8_t;
};
template <>
struct BytesToType<2>
{
using type = uint16_t;
};
template <>
struct BytesToType<4>
{
using type = uint32_t;
};
template <>
struct BytesToType<8>
{
using type = uint64_t;
};
template <>
struct BytesToType<16>
{
using type = float4;
};
template <int Bytes>
__device__ inline void copy(void const* local, void* data)
{
using T = typename BytesToType<Bytes>::type;
T const* in = static_cast<T const*>(local);
T* out = static_cast<T*>(data);
*out = *in;
}
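// Usage sketch (illustrative): copy<sizeof(float4)> moves 16 bytes with one vectorized
// load/store, e.g. copy<16>(&srcVec, &dstVec) for two float4 variables.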
static float constexpr HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
        val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); // __shfl_sync returns float for bf16 when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
    // Use blockDim.x / 32.f rather than blockDim.x >> 5 so the guard also works
    // when blockDim.x is not divisible by 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f);
val = warpReduceSum<T>(val);
return val;
}
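// Hedged usage sketch (illustrative kernel, not part of the original header): each
// block reduces its slice of `in` and thread 0 writes the per-block total.
template <typename T>
__global__ void blockSumExampleKernel(T const* in, T* out, int n)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    T val = idx < n ? in[idx] : (T) (0.0f);
    val = blockReduceSum<T>(val);
    if (threadIdx.x == 0)
    {
        out[blockIdx.x] = val;
    }
}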
template <typename T>
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
    val = warpReduceMax(val);      // get max in each warp
    if (lane == 0)                 // record in-warp max by warp idx
shared[wid] = val;
__syncthreads();
    // Use blockDim.x / 32.f rather than blockDim.x >> 5 so the guard also works
    // when blockDim.x is not divisible by 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockAllReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get max in each warp
if (lane == 0) // lane 0 records its warp's max by warp idx
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f instead of blockDim.x >> 5 so that a trailing
// partial warp is counted when blockDim.x is not a multiple of 32.
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T) (0.0f); // dummy return; the reduced values are stored in val[]
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val)
{
static __shared__ T shared[NUM][33]; // 33 columns to avoid shared memory bank conflicts
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T) 0.0f; // dummy return; the reduced values are stored in val[]
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T) (0.0f); // dummy return; the reduced values are stored in val[]
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val)
{
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get max in each warp
if (lane == 0) // lane 0 records its warp's max by warp idx
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[wid][i] = val[i];
}
}
__syncthreads();
// Use blockDim.x / 32.f instead of blockDim.x >> 5 so that a trailing
// partial warp is counted when blockDim.x is not a multiple of 32.
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[lane][i] : (T) -1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T) 0.0f; // dummy return; the reduced values are stored in val[]
}
template <int NUM>
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
{
cg::thread_block cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
int const tid = cta.thread_rank();
int const blockz = blockDim.x;
for (int i = 0; i < NUM; i++)
{
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
#else
// TODO Add implementation here
if (threadIdx.x == 0 && blockIdx.x == 0)
{
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
assert(false);
}
#endif
}
cg::sync(cta);
if (tid == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
float beta = 0.0f;
for (int j = 0; j < blockz; j += 32)
{
beta += cgBlockReduceSumElements_shm[i * blockz + j];
}
element_list[i] = beta;
}
}
}
template <typename T, int MAX_K>
struct TopK
{
int p[MAX_K]; // indices; -1 at the tail if the array is not full
T u[MAX_K];   // values in descending order; -MAX_T_VAL if the element is invalid
__device__ __forceinline__ void insert(T const elem, int const elem_id)
{
if (elem_id < 0)
{
return;
}
// Condition of updating the array
// 1. array is not full
// 2. elem is greater than the smallest (last) element in the array
// 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller
bool const need_update
= (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]);
if (!need_update)
{
return;
}
// Find suitable index for the new element
int i;
for (i = MAX_K - 2; i >= 0; --i)
{
bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]);
if (!need_decrease)
break;
}
// Move elements to correct positions
for (int k = MAX_K - 2; k >= i; --k)
{
p[k + 1] = p[k];
u[k + 1] = u[k];
}
p[i] = elem_id;
u[i] = elem;
}
__device__ __forceinline__ void init()
{
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
for (int i = 0; i < MAX_K; i++)
{
p[i] = -1;
u[i] = -MAX_T_VAL;
}
}
};
template <typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
{
TopK<T, MAX_K> res = a;
for (int i = 0; i < MAX_K; ++i)
res.insert(b.u[i], b.p[i]);
return res;
}
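// Illustrative sketch, not part of the original header: each thread builds a
// local TopK over a strided slice of the scores; per-thread results can then
// be combined pairwise with reduce_topk_op (e.g. in a tree reduction). The
// function name is hypothetical.
template <typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> exampleThreadLocalTopK(T const* scores, int n)
{
    TopK<T, MAX_K> topk;
    topk.init();
    for (int i = threadIdx.x; i < n; i += blockDim.x)
    {
        topk.insert(scores[i], i);
    }
    return topk; // combine across threads with reduce_topk_op
}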
template <typename T>
struct TopK_2
{
int p = -1;
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
__device__ __forceinline__ void insert(T elem, int elem_id)
{
if (elem > u)
{
u = elem;
p = elem_id;
}
}
__device__ __forceinline__ void init()
{
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
p = -1;
}
};
template <typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
{
return a.u > b.u ? a : b;
}
template <typename T>
__device__ __forceinline__ T clamp_inf_for_half(float const input)
{
return input;
}
template <>
__device__ __forceinline__ half clamp_inf_for_half(float const input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
}
} // namespace common
} // namespace tensorrt_llm

View File

@@ -0,0 +1,123 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <functional>
#include <numeric>
#include <optional>
#include <sstream>
namespace tensorrt_llm::common::stl_utils
{
template <typename TInputIt, typename TOutputIt, typename TBinOp>
constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op)
{
if (first != last)
{
auto val = *first;
while (true)
{
*dFirst = val;
++dFirst;
++first;
if (first == last)
{
break;
}
val = op(std::move(val), *first);
}
}
return dFirst;
}
template <typename TInputIt, typename TOutputIt>
constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst)
{
#if defined(__GNUC__) && __GNUC__ <= 8
return basicInclusiveScan(first, last, dFirst, std::plus<>{});
#else
return std::inclusive_scan(first, last, dFirst);
#endif
}
template <typename TInputIt, typename TOutputIt, typename T, typename TBinOp>
constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op)
{
if (first != last)
{
while (true)
{
T tmp{op(init, *first)};
*dFirst = init;
++dFirst;
++first;
if (first == last)
{
break;
}
init = std::move(tmp);
}
}
return dFirst;
}
template <typename TInputIt, typename TOutputIt, typename T>
constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init)
{
#if defined(__GNUC__) && __GNUC__ <= 8
return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{});
#else
return std::exclusive_scan(first, last, dFirst, std::move(init));
#endif
}
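// Illustrative usage, not part of the original header; the function name is
// hypothetical and the commented results assume this exact input.
inline int exampleScanUsage()
{
    int in[4] = {1, 2, 3, 4};
    int incl[4] = {};
    int excl[4] = {};
    inclusiveScan(in, in + 4, incl);    // incl = {1, 3, 6, 10}
    exclusiveScan(in, in + 4, excl, 0); // excl = {0, 1, 3, 6}
    return incl[3] + excl[3];           // 10 + 6 == 16
}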
template <typename T, typename = void>
struct HasOperatorOutput : std::false_type
{
};
template <typename T>
struct HasOperatorOutput<T, std::void_t<decltype((std::declval<std::ostream&>() << std::declval<T>()))>>
: std::true_type
{
};
template <typename T>
std::string toString(T const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
{
std::ostringstream oss;
oss << t;
return oss.str();
}
template <typename T>
std::string toString(std::optional<T> const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
{
std::ostringstream oss;
if (t)
{
oss << t.value();
}
else
{
oss << "None";
}
return oss.str();
}
} // namespace tensorrt_llm::common::stl_utils

View File

@@ -0,0 +1,76 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/assert.h"
#include <cerrno>
#include <cstdarg>
#include <cstring>
#include <iostream>
#include <string>
namespace tensorrt_llm::common
{
namespace
{
std::string vformat(char const* fmt, va_list args)
{
va_list args0;
va_copy(args0, args);
auto const size = vsnprintf(nullptr, 0, fmt, args0);
if (size <= 0)
return "";
std::string stringBuf(size, char{});
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
return stringBuf;
}
} // namespace
std::string fmtstr(char const* format, ...)
{
va_list args;
va_start(args, format);
std::string result = vformat(format, args);
va_end(args);
return result;
};
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
{
std::unordered_set<std::string> values;
if (!input.empty())
{
std::stringstream valStream(input);
std::string val;
while (std::getline(valStream, val, delimiter))
{
if (!val.empty())
{
values.insert(val);
}
}
}
return values;
};
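// Illustrative examples, not part of the original file:
//   fmtstr("rank %d of %d", 3, 8) -> "rank 3 of 8"
//   str2set("a,b,,c", ',')        -> {"a", "b", "c"} (empty tokens are dropped)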
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,42 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <iomanip>
#include <sstream>
#include "tensorrt_llm/common/timestampUtils.h"
namespace tensorrt_llm::common
{
std::string getCurrentTimestamp()
{
auto now = std::chrono::system_clock::now();
auto now_t = std::chrono::system_clock::to_time_t(now);
auto tm = *std::localtime(&now_t);
auto epoch_to_now = now.time_since_epoch();
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
std::ostringstream stream;
stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
stream << "." << std::setfill('0') << std::setw(6) << us.count();
return stream.str();
}
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,25 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
namespace tensorrt_llm::common
{
/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS.uuuuuu"
std::string getCurrentTimestamp();
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,105 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdlib>
#if !defined(_MSC_VER)
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#endif
#include <sstream>
namespace tensorrt_llm::common
{
namespace
{
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; // "0x" plus two hex digits per byte
}
#if !defined(_MSC_VER)
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{""}
{
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
auto const trace = getTrace();
std::runtime_error::operator=(
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
}
#else
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: mNbFrames{}
, std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
{
}
#endif
TllmException::~TllmException() noexcept = default;
std::string TllmException::getTrace() const
{
#if defined(_MSC_VER)
return "";
#else
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
std::ostringstream buf;
for (auto i = 1; i < mNbFrames; ++i)
{
Dl_info info;
if (dladdr(mCallstack[i], &info) && info.dli_sname)
{
auto const clearName = demangle(info.dli_sname);
buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
}
else
{
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
}
if (i < mNbFrames - 1)
buf << std::endl;
}
if (mNbFrames == MAX_FRAMES)
buf << std::endl << "[truncated]";
std::free(trace);
return buf.str();
#endif
}
std::string TllmException::demangle(char const* name)
{
#if defined(_MSC_VER)
return name;
#else
std::string clearName{name};
auto status = -1;
auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
if (status == 0)
{
clearName = demangled;
std::free(demangled);
}
return clearName;
#endif
}
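// Illustrative example, not part of the original file: on Itanium-ABI
// platforms demangle("_Z3foov") returns "foo()"; if demangling fails
// (status != 0), the mangled input is returned unchanged.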
} // namespace tensorrt_llm::common

View File

@@ -0,0 +1,87 @@
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace tensorrt_llm::common
{
std::uintptr_t constexpr kCudaMemAlign = 128;
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
uintptr_t addr = (uintptr_t) ptr;
if (addr % to)
{
addr += to - addr % to;
}
return (int8_t*) addr;
}
constexpr size_t alignSize(size_t size, size_t to)
{
if ((size % to) != 0U)
{
size += to - size % to;
}
return size;
}
inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
{
uintptr_t addr = (uintptr_t) ptr;
addr += previousWorkspaceSize;
return alignPtr((int8_t*) addr, alignment);
}
inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}
inline int8_t* nextWorkspacePtr(
int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
{
uintptr_t curr_offset = offset;
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
offset = next_offset;
return newptr;
}
inline int8_t* nextWorkspacePtrWithAlignment(
int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
{
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
inline size_t calculateTotalWorkspaceSize(
size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
{
size_t total = 0;
for (int i = 0; i < count; i++)
{
total += workspaces[i];
if (workspaces[i] % alignment)
{
total += alignment - (workspaces[i] % alignment);
}
}
return total;
}
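// Illustrative sketch, not part of the original header: carving one device
// allocation into two aligned sub-buffers. The sizes and names are hypothetical.
inline void exampleCarveWorkspace(int8_t* base)
{
    uintptr_t offset = 0;
    int8_t* bufA = nextWorkspacePtr(base, offset, 1000); // offset becomes 1024 (128-byte aligned)
    int8_t* bufB = nextWorkspacePtr(base, offset, 512);  // offset becomes 1536
    (void) bufA;
    (void) bufB;
}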
}; // namespace tensorrt_llm::common

View File

@@ -0,0 +1,352 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/numeric/numeric_types.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif
namespace cute
{
//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////
template <class T>
struct TypedAtomicAdd
{
using SRegisters = T[1];
using DRegisters = T[1];
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
{
atomicAdd(&dst, src);
}
};
template <class T>
struct Copy_Traits<TypedAtomicAdd<T>>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
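// Illustrative usage, not part of the original header: wrapped in a Copy_Atom,
// this accumulates register-held values into global memory one atomicAdd at a
// time. The tensor names below are hypothetical.
//   auto red_atom = Copy_Atom<TypedAtomicAdd<float>, float>{};
//   copy(red_atom, tRS_rSrc, tRS_gDst);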
//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////
struct SM70_RED_ADD_NOFTZ_F16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM70_RED_ADD_NOFTZ_F16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
} // end namespace cute

View File

@@ -0,0 +1,120 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace arch
{
// Tag which triggers the dequantizing MMA (interleaved B dequantized to A's type)
struct OpMultiplyAddDequantizeInterleavedBToA;
/*
Below we have extra tags to signal what kind of dequantization we want to do
(per-column scale only, fine-grained scale only, fine-grained scale with zeros).
This still lets us use the existing template infrastructure (incl. that in
CUTLASS). However, we split the tagged operator back into
OpMultiplyAddDequantizeInterleavedBToA plus the quantization op before
instantiating the GEMM pieces.
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
// The default just forwards the original operator
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
using TaggedOperator = MmaOp;
};
// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};
// Here we instantiate some structs to "detag" the tagged operator, splitting it back into the
// original operator plus the extra information. If no extra info was tagged, the dequant op
// defaults to per-column scaling.
template <typename TaggedMmaOp>
struct DetagOperator
{
using Operator = TaggedMmaOp;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};
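// Illustrative check, not part of the original header: tagging then detagging
// round-trips the quantization op.
static_assert(DetagOperator<TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                  WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator>::QuantOp
        == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY,
    "Tag/Detag must round-trip the quantization op");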
} // namespace arch
} // namespace cutlass

View File

@@ -0,0 +1,88 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace cutlass_extensions
{
template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10))
{
cudaFuncAttributes attr;
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
}
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
// configuration.
return 0;
}
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
}
int max_active_blocks = -1;
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
}
return max_active_blocks;
}
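// Illustrative usage, not part of the original header: a tuning heuristic can
// probe each candidate kernel and skip configurations that cannot be resident.
// SomeGemmKernel is a hypothetical CUTLASS 2.x kernel type:
//   int const occupancy = compute_occupancy_for_kernel<SomeGemmKernel>();
//   if (occupancy == 0) { /* skip this tile configuration */ }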
} // namespace cutlass_extensions
} // namespace tensorrt_llm

View File

@@ -0,0 +1,550 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cute/numeric/numeric_types.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
class CopyOpR2G>
class EpilogueMoeFusedFinalize
{
public:
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
using ThreadEpilogueOp = ThreadEpilogueOp_;
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = ElementD_;
using StrideD = StrideD_;
using InternalStrideD = cute::remove_pointer_t<StrideD>;
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
struct SharedStorage
{
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
};
struct TensorMapStorage
{
};
struct Arguments
{
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C{};
StrideC dC{};
ElementD* ptr_D{};
StrideD dD{};
ElementBias const* ptr_bias;
StrideBias dBias{};
ElementScale const* ptr_scale;
StrideScale dScale{};
int64_t const* group_offset{};
int32_t const* scatter_index{};
cutlass::FastDivmod num_rows_in_final_output;
};
using Params = Arguments;
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
{
return args;
}
template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
{
return 0;
}
template <class ProblemShape>
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
{
return cutlass::Status::kSuccess;
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
{
bool implementable = true;
if (problem_shape.is_host_problem_shape_available())
{
// Check alignment for all problem sizes
for (int i = 0; i < problem_shape.groups(); i++)
{
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
auto [M, N, K, L] = problem_shape_MNKL;
implementable = implementable
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
}
}
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
"reduction instruction.\n");
}
return implementable;
}
CUTLASS_HOST_DEVICE
EpilogueMoeFusedFinalize(Params const& params_)
: params(params_)
{
}
CUTLASS_DEVICE
bool is_source_needed()
{
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
return params.ptr_C != nullptr
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
}
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
class TiledMma, class ResidueMNK>
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
auto synchronize = [&]()
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Batches are managed by using appropriate pointers to C and D matrices
int32_t const mock_L = 1;
int32_t const mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
// If scalar alpha/beta are provided, the same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, alpha/beta can differ between batches/groups;
// in that case we fetch the values for the current batch/group using the group index.
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
Tensor sD = as_position_independent_swizzle_tensor(sD_);
// Function to scatter output rows
auto& num_rows = params.num_rows_in_final_output;
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
auto get_scatter_idx = [&](auto i)
{
auto scatter = read_scatter_map(i);
int quot, rem;
num_rows(quot, rem, scatter);
return rem;
};
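// rem = scatter % num_rows_in_final_output maps the expanded, permuted row
// index back to its destination row in the final output.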
// Represent the full output tensor
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
Tensor mD_mnl = make_gather_tensor(
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
// Use a fake shape for bias; it doesn't matter
bool const is_bias_needed = params.ptr_bias != nullptr;
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
Tensor mScale_mnl = make_tensor(
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
Tensor gC_mnl
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gBias_mnl
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gScale_mnl
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Get the smallest tiled copy we can use to retile the accumulators
TiledCopy tiled_copy_C_atom
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
// Make a tiled copy vectorized along major direction of D
auto tiled_s2r = [&]()
{
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
}
else
{
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
}
}();
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// epilogue subtile loop
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
{
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
{
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD);
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
{
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
}
copy(tiled_r2s, tRS_rD, tRS_sD);
synchronize();
copy(tiled_s2r, tSR_sD, tSR_rD);
synchronize();
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
if (epilogue_op.is_source_needed())
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
}
}
}
private:
Params params;
};
namespace detail
{
template <class Element, class MaxVec>
constexpr auto get_vectorized_atomic_add_op()
{
using namespace cute;
auto constexpr MaxVecSize = size(MaxVec{});
if constexpr (is_same_v<Element, cutlass::half_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM70_RED_ADD_NOFTZ_F16x2{};
}
else
{
return SM70_RED_ADD_NOFTZ_F16{};
}
}
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
{
if constexpr (MaxVecSize >= 8)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
}
else if constexpr (MaxVecSize >= 4)
{
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
}
else if constexpr (MaxVecSize >= 2)
{
return SM90_RED_ADD_NOFTZ_BF16x2{};
}
else
{
return SM90_RED_ADD_NOFTZ_BF16{};
}
}
else
{
// non-vectorized atomic add for all other types until supported
return TypedAtomicAdd<Element>{};
}
}
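// For example (illustrative): get_vectorized_atomic_add_op<cutlass::half_t, _8>()
// selects SM90_RED_ADD_NOFTZ_F16x2_V4, i.e. one 128-bit red.global.add.noftz.v4.f16x2
// instruction per eight half values.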
} // namespace detail
template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
struct EpilogueMoeFusedFinalizeBuilder
{
// assuming cooperative kernel schedule
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
using EpilogueTile = Shape<_128, EpiTileN>;
// The output of the linear combination is ElementCompute instead of ElementD,
// since we will do more compute on it; there is no need to cast yet.
using ThreadEpilogueOp
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
using SmemLayoutAtomD
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
using CopyAtomS2R = DefaultCopy;
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
template <class EpilogueOp>
struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
{
// We need to override this one using-declaration because otherwise we double up on the smem
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
CUTLASS_HOST_DEVICE
Sm90TmaWarpSpecializedAdapterWithSmemStorage(
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
: Base(params)
{
}
// These functions depend on the type of TensorMapStorage
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] typename EpilogueOp::Params const& params,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
{
}
template <bool IsLoad>
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
{
}
};
using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,105 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace thread
{
/////////////////////////////////////////////////////////////////////////////////////////////////
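// Combines the magnitude of `a` (assumed non-negative, as in the callers below) with the sign
// bit of `b` without a branch: 0x80000000 masks the IEEE-754 float sign bit, which is OR-ed onto `a`.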
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
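// Fast tanh: on older toolkits/architectures this falls back to the identity
//   tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|))
// evaluated with __expf; otherwise it defers to CUTLASS's fast_tanh.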
__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
float const exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
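// Float specialization of GELU_taylor that routes the inner tanh through tanh_opt above.
// It implements the standard tanh-based GELU approximation
//   gelu(z) ~= 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3))),
// with k0 = sqrt(2/pi) ~= 0.7978845608.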
template <>
struct GELU_taylor<float>
{
static bool const kIsHeavy = true;
CUTLASS_DEVICE
float operator()(float const& z) const
{
float k0 = float(0.7978845608028654);
float k1 = float(0.044715);
return float(cutlass::constants::half<float>() * z
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_DEVICE
float operator()(float const& scalar, [[maybe_unused]] Params const& params_) const
{
return this->operator()(scalar);
}
};
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,352 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
#include "tensorrt_llm/common/quantization.h"
namespace tk = tensorrt_llm::common;
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol
{
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
/// Argument structure
struct Arguments
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments()
: batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_)
, batch_stride_alpha(0)
, batch_stride_C(0)
, batch_stride_D(0)
{
}
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
int64_t batch_stride_C_, int64_t batch_stride_D_)
: elementwise(elementwise_)
, batch_stride_alpha(batch_stride_alpha_)
, batch_stride_C(batch_stride_C_)
, batch_stride_D(batch_stride_D_)
{
}
};
struct Params
{
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Arguments const& args)
: elementwise(args.elementwise)
, batch_stride_alpha(args.batch_stride_alpha)
, batch_stride_C(args.batch_stride_C)
, batch_stride_D(args.batch_stride_D)
{
}
};
/// Shared storage
struct SharedStorage
{
};
private:
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
bool const per_token_quant_;
bool const per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params)
, shared_storage_(shared_storage)
, extent_(problem_size)
, elementwise_(params.elementwise)
, per_token_quant_(quant_option.hasPerTokenScaling())
, per_channel_quant_(quant_option.hasPerChannelScaling())
, ptr_alpha_row_(ptr_alpha_row)
, ptr_alpha_col_(ptr_alpha_col)
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
, extent_real_(problem_size_real)
{
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
if (beta_ == ElementAccumulator())
{
iterator_C_.clear_mask();
}
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
{
element_alpha_col_ = *ptr_alpha_col_;
}
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
{
element_alpha_row_ = *ptr_alpha_row_;
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
    int split_k_slices)                 ///< Total number of split-K slices
{
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx)
{
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue()
{
if (per_channel_quant_)
{
iterator_alpha_col_.load(fragment_alpha_col_);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx)
{
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
{
iterator_C_.load(fragment_C_);
++iterator_C_;
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx)
{
// load alpha_row in begin_row only when per-token (row) scaling is used
if (per_token_quant_)
{
int thread_offset_row
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
}
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
{
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_)
{
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
}
else
{
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx)
{
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
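// Elementwise scaling helpers: each accumulator element is multiplied by the product of a
// per-column (channel) scale and a per-row (token) scale; the second variant uses a single
// scalar column scale for the case where per-channel quantization is disabled.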
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
{
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i)
{
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@@ -0,0 +1,282 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
namespace detail
{
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
ThreadMap>
{
using WarpTileIterator
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
using SharedLoadIterator
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
static int const kFragmentsPerIteration = 2;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
{
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = int32_t;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Byte-level pointer
LoadType const* pointers_[kLoadsPerAccess];
/// Stride along adjacent rows in units of LoadType
int stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements))
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointers
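// Stagger each pointer's starting column by an offset derived from its 128-byte line index so
// that the kLoadsPerAccess loads issued per access land in different shared-memory banks --
// the conflict avoidance this mixed iterator exists to provide.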
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
col_idx += (bank_offset + i) % kLoadsPerAccess;
pointers_[i] += thread_offset.row() * stride_ + col_idx;
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i] += pointer_offset / LoadType::kElements;
}
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& offset)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i)
{
pointers_[i]
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
{
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
{
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
{
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
{
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
+ pointer_offset / LoadType::kElements;
int frag_row_idx
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
{
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v)
{
int vector_idx
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
}
}
}
}
}
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const
{
load_with_pointer_offset(frag, 0);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,141 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file epilogue_helpers.h
*
* This file includes types for the epilogues. The empty structs exist so we can signal to template
* code the type of epilogue we want to run, and let the underlying code specify the details such as
* element types, accumulator type and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#include <cutlass/epilogue/fusion/operations.hpp>
namespace tensorrt_llm
{
namespace cutlass_extensions
{
struct EpilogueOpBiasSilu
{
};
struct EpilogueOpBiasReLU
{
};
struct EpilogueOpBiasFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpDefaultSilu
{
};
struct EpilogueOpDefaultReLU
{
};
struct EpilogueOpDefaultFtGelu
{
};
struct EpilogueOpDefault
{
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue
{
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};
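// Illustrative use (types chosen only for the example): a bias + SiLU epilogue producing
// half-precision output with 8-wide vector access and float accumulation would be selected as
//   using Op = Epilogue<cutlass::half_t, 8, float, EpilogueOpBiasSilu>::Op;
// Unrecognized tags fail at compile time via the static_assert above.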
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, BiasScaleMode>;
};
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};
} // namespace cutlass_extensions
} // namespace tensorrt_llm

View File

@@ -0,0 +1,221 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
{
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
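// Each stage buffers A, B, and a second copy of the gated operand (A when SwapAB, else B),
// hence the factor of 2 on that operand's tile bytes below.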
constexpr int stage_bytes = [&]() -> int
{
if constexpr (SwapAB)
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
}
else
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>)
&& not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated
{
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,59 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
template <class /* ElementCompute */> class Activation, bool SwapAB = false>
struct CollectiveMmaGated
{
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,642 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to a size-equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
using InternalElementAux = cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
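// The gate operand is expected immediately after the primary tensor in memory: its base
// pointer is the primary pointer advanced by the full M*K*L (or N*K*L) element extent.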
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<InternalElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
else
{
auto ptr_Aux = reinterpret_cast<InternalElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
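// Bytes TMA must deliver per pipeline stage: one (BLK_M,BLK_K) tile of A, one (BLK_N,BLK_K)
// tile of B, plus one tile of the gate operand.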
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
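// Hedged arithmetic example for TmaTransactionBytes (hypothetical 128x128x64
// tile with 16-bit A, B, and Aux): each pipeline stage expects
//   A:   128 * 64 * 16 / 8 = 16384 bytes
//   B:   128 * 64 * 16 / 8 = 16384 bytes
//   Aux: 128 * 64 * 16 / 8 = 16384 bytes
// i.e. 49152 bytes signalled on the TMA barrier per stage.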
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the kernel layer
/// is that the returned tuple contains at least two elements, the first two being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// gAux_xkl - The tma tensor, A or B (per SwapAB) after a local tile, so it has shape (BLK_M,BLK_K,m,k,l) or (BLK_N,BLK_K,n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in a Cluster.
* It waits for all stages to either be released (all
* Consumer UNLOCKs) or, if a stage was never used,
* to simply be acquired, since its phase is still
* inverted from make_producer_start_state.
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR: Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB)
{
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
}
else
{
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,665 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments
{
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
// Device side kernel params
struct Params
{
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
{
(void) workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB)
{
auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
else
{
auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
{
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
/* The MMA promotion interval must be a multiple of 4, since each mainloop iteration issues 4 MMA
* instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
if (!implementable)
{
CUTLASS_TRACE_HOST(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
/ 8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params)
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the kernel layer
/// is that the returned tuple contains at least two elements, the first two being:
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
/// gAux_xkl - The tma tensor, A or B (per SwapAB) after a local tile, so it has shape (BLK_M,BLK_K,m,k,l) or (BLK_N,BLK_K,n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
{
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB)
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
else
{
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors)
{
int lane_predicate = cute::elect_one_sync();
if (lane_predicate)
{
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n)
{
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
{
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m)
{
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB)
{
mcast_mask_aux = mcast_mask_a;
}
else
{
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
{
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate)
{
/* This helps avoid early exit of blocks in a Cluster.
* It waits for all stages to either be released (all
* Consumer UNLOCKs) or, if a stage was never used,
* to simply be acquired, since its phase is still
* inverted from make_producer_start_state.
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
Params const& mainloop_params)
{
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.partition_A(sAux);
}
else
{
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto
{
if constexpr (SwapAB)
{
return thread_mma.make_fragment_A(tCsAux);
}
else
{
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB)
{
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
}
else
{
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR: Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation0.prepare_if_needed())
{
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
{
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB)
{
cute::gemm(
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
}
else
{
cute::gemm(
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation0.promote_residue_if_needed();
accumulation1.promote_residue_if_needed();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
}
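// Hedged scalar analogy of the GmmaFP8Accumulation pattern used above: FP8
// GMMAs accumulate in limited precision, so partial sums are periodically
// promoted into a main fp32 accumulator every mma_promotion_interval k-blocks
// to bound error growth. Variable names are illustrative, not the library API:
//
//   float main_acc = 0.f, partial = 0.f;
//   for (int k = 0; k < K; ++k)
//   {
//       partial += a[k] * b[k];                        // fast-accum MMA step
//       if ((k + 1) % interval == 0)
//       {
//           main_acc += partial;                       // promote_if_needed()
//           partial = 0.f;
//       }
//   }
//   main_acc += partial;                               // promote_residue_if_needed()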
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
{
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count)
{
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,438 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
// #include <limits>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel-level APIs for mixed-dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel-level APIs.
Note: While CUTLASS 3.x supports stream-K, none of the kernels in the extensions folder support
that feature at the moment.
*/
template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
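// Worked example of the split-K fix-up above (hypothetical sizes): for 16-bit
// A and B, kAlignK = max(128 / 16, 128 / 16) = 8. With K = 4096 and
// batch_count = 3 split-K slices, gemm_k_size = round_up(ceil_div(4096, 3), 8)
// = round_up(1366, 8) = 1368, giving grid_tiled_shape.k()
// = ceil_div(4096, 1368) = 3 partitions along K.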
public:
/// Constructs the GEMM.
GemmUniversalBaseCompat() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
{
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
// Split-K parallel always requires a temporary workspace
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
}
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
{
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
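// Hedged sizing example (hypothetical values): for kGemmSplitKParallel with
// ElementC = half_t, batch_stride_D = 4096 * 4096 elements, and 4 K-partitions,
// the workspace holds 2 * 16777216 * 4 = 134217728 bytes (128 MiB) of partial
// products, whereas serial split-K only needs sizeof(int) * tiled_m * tiled_n
// bytes of semaphore storage.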
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10))
{
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess)
{
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
}
else
{
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
if (smem_capacity < 0)
{
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess)
{
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes)
{
if (!workspace)
{
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm)
{
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
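// Hedged usage sketch for the compat device layer above; the kernel type and
// the argument/allocation helpers are hypothetical:
//
//   using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>;
//   Gemm gemm_op;
//   Gemm::Arguments args = build_arguments(...);        // hypothetical helper
//   if (Gemm::can_implement(args) == cutlass::Status::kSuccess)
//   {
//       void* workspace = device_alloc(Gemm::get_workspace_size(args));
//       if (gemm_op.initialize(args, workspace, stream) == cutlass::Status::kSuccess)
//       {
//           gemm_op.run(stream);
//       }
//   }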

View File

@@ -0,0 +1,542 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
int64_t* splitk_buffer_offsets)
{
// in_tensor: [problem_idx, k_partition, hidden_size]
// Note that different problems in in_tensor might have different hidden_size (= m * n),
// so we need splitk_buffer_offsets to locate each problem's partial sums.
// out_tensor: problem_idx * [hidden_size]
int const problem_idx = blockIdx.y;
GemmCoord problem = problem_sizes[problem_idx];
int const hidden_size = problem.m() * problem.n();
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
T_OUT* out_tensor_ = out_tensor[problem_idx];
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
{
float sum = 0.0f;
for (int k_idx = 0; k_idx < splitk; k_idx++)
{
sum += (float) in_tensor_[k_idx * hidden_size + i];
}
out_tensor_[i] = (T_OUT) (sum);
}
}
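// Hedged launch sketch for splitkReduction (names are illustrative): one
// grid.y slot per problem, with grid.x * blockDim.x striding over the largest
// hidden_size; the kernel's grid-stride loop covers smaller problems.
//
//   dim3 block(256);
//   dim3 grid((max_hidden_size + block.x - 1) / block.x, problem_count);
//   splitkReduction<float, float><<<grid, block, 0, stream>>>(
//       out_ptrs_dev, in_buf_dev, problem_sizes_dev, splitk, offsets_dev);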
/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params gemm_params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
{
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
{
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
{
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, std::vector<size_t> const& indices)
{
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size());
for (size_t i = 0; i < indices.size(); ++i)
{
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
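    // Illustrative example (not part of the original file): data = {A, B, C}
    // with indices = {2, 0, 1} produces {C, A, B}, i.e. the new data[i] is the
    // old data[indices[i]].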
public:
/// Constructs the GEMM.
BaseSplitkGrouped() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
{
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const& args)
{
if (args.host_problem_sizes == nullptr)
{
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
size_t total_mn = 0;
for (int i = 0; i < args.problem_count; i++)
{
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
}
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
args.host_problem_sizes, args.problem_count, args.threadblock_count);
}
return workSpaceSize;
}
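    // Illustrative example (not part of the original file): two problems of
    // 128x256 and 64x64 with split_k_slices = 4 and float accumulators need
    // (32768 + 4096) * 4 * 4 = 589824 bytes, plus the problem visitor's
    // workspace when kRequiresPrecomputation is set.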
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each pointer passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of problem-K dimension.
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
{
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
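    // Illustrative example (not part of the original file): problems with
    // K = {128, 512, 256} are reordered to {512, 256, 128}, and every
    // lda/ldb/ldc/ldd and offset array is permuted with the same indices so
    // entry i still describes the same problem after sorting.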
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
{
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
return 0;
}
int multiprocessor_count;
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
return 0;
}
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
if (override_sm_count)
{
available_sm_count = multiprocessor_count;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0)
{
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0)
{
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return total_tiles
// unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count)
{
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating through
// problem sizes to determine that they have no work to do. This competes for cycles
// with those threadblocks that are assigned tiles to compute.
return std::min(total_tiles, occupancy_based_block_count);
}
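    // Illustrative example (not part of the original file): on a GPU with 108
    // SMs and maximum_active_blocks() == 2, the full wave is 216 threadblocks;
    // a group totalling 50 tiles then launches min(50, 216) == 50 blocks,
    // while a 1000-tile group launches 216.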
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
}
else
{
gemm_params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_.update(args, workspace, tile_count);
}
else
{
gemm_params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
if (!gemm_params_.problem_visitor.problem_count)
{
return Status::kSuccess;
}
//
// Launch kernel
//
// Launch splitk grouped gemm
{
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Launch splitkReduction
{
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
dim3 block(256);
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
gemm_params_.splitk_buffer_offsets);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Initializes and runs the kernel.
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
using GemmKernel = GemmKernel_;
};
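// Illustrative usage sketch (not part of the original file; `MyGemmKernel` is
// a hypothetical kernel type), mirroring the operator() flow above:
//
//   SplitkGemmGrouped<MyGemmKernel> gemm;
//   size_t bytes = SplitkGemmGrouped<MyGemmKernel>::get_workspace_size(args);
//   void* workspace = nullptr;
//   cudaMalloc(&workspace, bytes);
//   if (gemm.initialize(args, workspace, stream) == cutlass::Status::kSuccess)
//   {
//       gemm.run(stream); // launches the grouped GEMM and the split-k reduction
//   }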
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,162 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
static_assert(dependent_false<arch>, "Unrecognised parameterization");
};
template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
static constexpr int ThreadblockK = 8;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
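// Illustrative note (not part of the original file):
// MixedGemmArchTraits<float, float, cutlass::arch::Sm80> resolves to the SIMT
// specialization above (scalar 1x1x1 "instruction", ThreadblockK = 8), while
// half_t/bfloat16_t activations select the tensor-op specializations below.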
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and activations will be cast to fp16;
// compute happens in fp16 and the result is then converted to bf16 for output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = typename LayoutDetails::Operator;
};
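// Illustrative arithmetic (not part of the original file): with
// TypeA = half_t (16 bits), ElementsPerAccessA = 128 / 16 = 8, i.e. one
// 128-bit vectorized access moves 8 elements; bfloat16_t gives the same count.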
// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ada FP8 Traits ==============================
// A/B are FP8; accumulation is fp32 and the output type (TypeC) follows
// HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA> (bf16 here).
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
    // Be careful: TypeC must stay aligned with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>.
using TypeC = __nv_bfloat16;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
using Operator = typename LayoutDetails::Operator;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,57 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename arch>
struct Int8GemmArchTraits
{
using OperatorClass = cutlass::arch::OpClassSimt;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};
// ======================= Turing Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};
// ======================= Ampere Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
using OperatorClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};
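// Illustrative note (not part of the original file): GemmShape<16, 8, 32>
// matches Ampere's int8 tensor-op mma (m16n8k32), just as GemmShape<8, 8, 16>
// above matches Turing's m8n8k16.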
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,207 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/permute.h"
#include "splitk_gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
/// Operation performed by GEMM
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
ElementC_, ElementAccumulator>::Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
///
typename Enable = void>
struct DefaultSplitkGemmGrouped;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Permute result D
typename PermuteDLayout>
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA, ElementB, LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
{
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
// Define the default GEMM kernel
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
false, /*GatherB*/
false, /*ScatterD*/
PermuteDLayout>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
};
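// Illustrative note (not part of the original file): when LayoutC is
// layout::ColumnMajor, kInternalTranspose is true and MapArguments exchanges
// the A/B operands with transposed layouts, so the row-major-epilogue kernel
// computes the column-major result described in the file header.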
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,566 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include <type_traits>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
template <typename>
inline constexpr bool dependent_false_v = false;
} // namespace detail
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct GemmFpAIntB
{
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Mma::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
    /// Argument structure
struct Arguments
{
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
int group_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments() {}
CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: problem_size(problem_size)
, group_size(group_size)
, ref_A(ref_A)
, ref_B(ref_B)
, ref_scale(ref_scale)
, ref_zero(ref_zero)
, ref_C(ref_C)
, ref_D(ref_D)
, batch_count(serial_split_k_factor)
, output_op(output_op)
, gather_A_indices(gather_A_indices)
, gather_B_indices(gather_B_indices)
, scatter_D_indices(scatter_D_indices)
{
}
};
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
int group_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::Params params_scale;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Mma::IteratorScale::TensorRef ref_zero;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename EpilogueOutputOp::Params output_op;
int* semaphore;
int gemm_k_size;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, semaphore(0)
, gemm_k_size(0)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
void* workspace = nullptr)
: problem_size(args.problem_size)
, group_size(args.group_size)
, grid_tiled_shape(grid_tiled_shape)
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
, params_A(args.ref_A.layout())
, ref_A(args.ref_A)
, params_B(args.ref_B.layout())
, ref_B(args.ref_B)
, params_scale(args.ref_scale.layout())
, ref_scale(args.ref_scale)
, ref_zero(args.ref_zero)
, params_C(args.ref_C.layout())
, ref_C(args.ref_C)
, params_D(args.ref_D.layout())
, ref_D(args.ref_D)
, output_op(args.output_op)
, semaphore(static_cast<int*>(workspace))
, gemm_k_size(gemm_k_size)
, gather_A_indices(args.gather_A_indices)
, gather_B_indices(args.gather_B_indices)
, scatter_D_indices(args.scatter_D_indices)
{
}
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
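    // Illustrative note (not part of the original file): the union lets the
    // mainloop and the epilogue reuse one shared-memory allocation, since
    // their smem lifetimes do not overlap; its size is the max of the two
    // members rather than their sum.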
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmFpAIntB() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(Arguments const& args)
{
static int const kAlignmentA
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
{
return Status::kErrorMisalignedOperand;
}
if (!args.ref_scale.good())
{
return Status::kErrorNotSupported;
}
if constexpr (hasZero(Mma::QuantOp))
{
if (!args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
else
{
if (args.ref_zero.good())
{
return Status::kErrorNotSupported;
}
}
if constexpr (isFinegrained(Mma::QuantOp))
{
if (args.group_size != 64 && args.group_size != 128)
{
return Status::kErrorNotSupported;
}
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
    // Initializes the fine-grained scale+bias iterator. Needed since the fine-grained iterator
    // has a different constructor signature than a regular cutlass iterator.
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
}
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
typename IteratorScale::TensorCoord extent, int thread_id,
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
{
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
}
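    // Illustrative note (not part of the original file): the two overloads are
    // tag-dispatched on the quantization mode via enable_if. A fine-grained op
    // forwards pointer_zero and group_size to the iterator; otherwise the
    // iterator is built from the scale pointer alone and pointer_zero is unused.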
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
params.gather_B_indices);
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0)
{
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k())
{
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
{
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else
{
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
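    // Illustrative note (not part of the original file): compiling this
    // translation unit for sm_80 instantiates only run_kernel<arch::Sm80>;
    // when KernelArch differs from the compilation arch, run_kernel collapses
    // to CUTLASS_NOT_IMPLEMENTED(), which keeps compile times down.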
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,218 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
#include <cutlass/trace.h>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
int TileK_, int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_sm80
{
static constexpr int kMaxTileM = MaxTileM_;
static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
static constexpr int kTileK = TileK_;
static constexpr int kStages = Stages_;
static constexpr Activation_Type activation_type = activation_type_;
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementOutput = ElementOutput_;
using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM, kTileN,
kTileK, kStages, activation_type>;
using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
using ProblemVisitor
= cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;
struct Arguments
{
Routine_Arguments routine_args;
int problem_count{};
int threadblock_count{};
};
struct Params
{
Routine_Params routine_params;
int threadblock_count{};
typename ProblemVisitor::Params problem_visitor_param;
};
using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
kTileK, kStages, activation_type>;
    static constexpr bool use_m16 = TileK_ >= 64; // use tile shape m = 16 when the original tile shape has k >= 64
static constexpr int kSmemSize = use_m16
? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
: BaseKernelTraits_m16::kSmemSize)
: BaseKernelTraits::kSmemSize;
static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;
    static constexpr bool can_implement(int const available_smem_size)
    {
        return BaseKernelTraits::can_implement(available_smem_size);
    }
static Params to_underlying_arguments(Arguments const& args)
{
return {
{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n,
args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast},
args.threadblock_count,
{args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
args.problem_count, nullptr, 0}};
}
CUTE_DEVICE
void run_device(Params const& params)
{
#define ROUTINE_PATH(kTileM_size) \
{ \
constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM, \
kTileN, kTileK, kStages, activation_type>; \
RoutineTraits routine{}; \
        int const block_m_idx = (block_m_idx_temp) * kMaxTileM / kTileM;                                               \
routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \
}
typename ProblemVisitor::SharedStorage dummy_storage{};
ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);
while (problem_visitor.next_tile())
{
auto problem_size = problem_visitor.problem_size();
auto grid_size = problem_visitor.grid_shape(problem_size);
auto problem_index = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
int const gemm_m = problem_size.m();
const int32_t block_m_idx_temp = cta_idx / grid_size.n();
const int32_t block_n_idx = cta_idx % grid_size.n();
int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
if (residue_m > kMaxTileM / 2)
{
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
kMaxTileM, kTileN, kTileK, kStages, activation_type>;
RoutineTraits routine{};
routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
}
else
{
if constexpr (kMaxTileM >= 128)
{
if (residue_m > 32)
{
ROUTINE_PATH(64);
}
else if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 64)
{
if (residue_m > 16)
{
ROUTINE_PATH(32);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
else if (kMaxTileM == 32)
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
else
{
// TODO: use cuda core gemm here
ROUTINE_PATH(16);
}
}
problem_visitor.advance(gridDim.x);
}
#undef ROUTINE_PATH
}
};
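// Illustrative dispatch example (not part of the original file): with
// kMaxTileM = 128 and gemm_m = 40, the single tile row has residue_m = 40,
// which is <= kMaxTileM / 2, so the kernel falls through to the smaller
// routines and ROUTINE_PATH(64) handles the tile (40 > 32).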
template <typename GemmType>
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
{
GemmType gemm;
gemm.run_device(params);
}
/// Computes the maximum number of active blocks per multiprocessor
template <typename GemmType>
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
constexpr int smem_size = GemmType::kSmemSize;
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} // namespace fused_moe

View File

@@ -0,0 +1,799 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
namespace fused_moe
{
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_, typename Enable = void>
struct Fused_Moe_Kernel_routine_sm80;
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
        typename KT::ElementWeight const* ptr_fc1_gate_
            = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we currently only handle gated activations.
        typename KT::ElementWeight const* ptr_fc1_
            = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we currently only handle gated activations.
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
: params.ptr_bias + (2 * row_jump + 1) * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_gate_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mBias_gate_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
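// (editor's note) the zero row stride above makes every logical row alias the same
// N-element bias vector, so thr_mma.partition_C can treat the bias like a full
// (M, N) tile without materializing M copies of it.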
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementOutput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
}
// Note: block_m_idx depends on the tile shape; it changes if a different TileShape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
= gmem_tensor_init(problem_index, gemm_m, params);
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_gate_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sfc1_gate_weight
= cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
// (1) first step, get the fc1_res and fc1_gate
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
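// (editor's note) e.g. with BLK_M = 32 and residue_m = 20, rows 20..31 of the tail
// tile are predicated off: copy_if skips those loads and the cleared smem below
// supplies zeros, so out-of-bounds tokens never reach the MMA.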
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::Tensor accum_gate
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
cute::clear(accum_gate);
// check the shapes
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
// (1.4) retile the smem and register fragments for the copy
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
});
// if (cute::thread0()) {
// cute::print(accum_gate(0, 0, 0));
// printf("\n");
// }
// (2) add bias if present
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
accum_gate(i) += tOgBiasg(i);
}
}
// (3) apply the gated activation (SwiGLU/GeGLU): accum = fn(accum_gate) * accum
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter);
}
// (4) push all the result to smem
// (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<typename KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile the register fragments and smem for the copy back
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
// (4.3) copy the register-file result to smem (TODO: maybe use a for loop for better performance)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow);
// remember: all threads in the same column share the same bias index.
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<typename KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type_>
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
{
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
Stages_, activation_type_>;
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
{
using X = cute::Underscore;
int const M = gemm_m;
int const N1 = params.gemm_n;
int const K1 = params.gemm_k;
bool const bias_is_broadcast = params.bias_is_broadcast;
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
? nullptr
: (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
cute::Tensor mInput_mk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mfc1_nk
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
cute::Tensor mBias_mn = cute::make_tensor(
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1,
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
cute::Tensor mOutput_mn
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementOutput*>(ptr_output_)),
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
}
// Note: block_m_idx depends on the tile shape; it changes if a different TileShape is used.
CUTE_DEVICE void run_routine(
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
{
extern __shared__ char smem_[];
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
int const thread_idx = threadIdx.x;
bool const bias_is_broadcast = params.bias_is_broadcast;
// gmem tensor partition ..
auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
auto const n_tile_count = cute::size<2>(gfc1_nk);
// smem tensor ..
cute::Tensor sInput = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
cute::Tensor sO = cute::make_tensor(
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
// (1) first step, compute the fc1 result
// (1.1) get partition for gmem -> smem
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
cute::Tensor tInputpInput
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
cute::Stride<cute::_1, cute::_0>{});
// Construct identity layout for sInput
cute::Tensor cInput = make_identity_tensor(
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
{
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
}
// (1.2) prefetch gmem -> smem
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // iterator starts from 0
int k_tile_count = cute::size<2>(gInput);
CUTLASS_PRAGMA_UNROLL
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
{
if (k_tile_count <= 0)
{
cute::clear(tInputpInput);
}
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
// use copy_if
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
cute::cp_async_fence();
k_tile_count--;
if (k_tile_count > 0)
{
++k_tile_iter;
}
}
// (1.3) get partition for rf
typename KT::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
cute::Tensor accum
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
cute::clear(accum);
// check the shapes
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
// (1.4) retile the smem and register fragments for the copy
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
// (1.5) mainloop
// Current pipe index in smem to read from
int smem_pipe_read = 0;
// Current pipe index in smem to write to
int smem_pipe_write = KT::Stages - 1;
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
// prefetch register pipeline
if constexpr (K_BLOCK_MAX > 1)
{
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
// Prefetch the first rmem from the first k-tile
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
}
// k loop for mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count)
{
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 2>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Copy gmem to smem before computing gemm on each k-pipe
if (k_block == 0)
{
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
cute::cp_async_fence();
if (k_tile_count - 1 > 0)
{
++k_tile_iter;
}
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
smem_pipe_write = smem_pipe_read;
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
accum);
});
}
// load tail
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
[&](auto WaitIndex)
{
k_tile_count--;
using WaitIndex_t = decltype(WaitIndex);
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
__syncthreads();
}
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
if (k_block == 0)
{
// only update smem_pipe_read
++smem_pipe_read;
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
}
// Thread-level register gemm for k_block
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
tOrfc1(cute::_, cute::_, k_block), accum);
});
});
// mma tail
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
[&](auto k_block)
{
// Load A, B shmem->regs for k_block+1
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
tOrInput_copy_view(cute::_, cute::_, k_block_next));
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
// Thread-level register gemm for k_block
cute::gemm(
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
});
// if (cute::thread0()) {
// cute::print(accum(0, 0, 0));
// printf("\n");
// }
// (2) add bias if present
if (params.ptr_bias != nullptr)
{
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
for (int i = 0; i < cute::size(accum); i++)
{
accum(i) += tOgBias(i);
}
}
// (3) apply the elementwise activation
using ActivationFn = typename KT::ActivationFn;
ActivationFn fn{};
CUTLASS_PRAGMA_UNROLL
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
{
accum(temp_iter) = fn(accum(temp_iter));
}
// (4) push all the result to smem
// (4.1) convert result from ElementAccum to ElementOutput
cute::Tensor temp_accum = util_convert_type<typename KT::ElementOutput>(accum);
// if (cute::thread0()) {
// cute::print(temp_accum(0, 0, 0));
// printf("\n");
// }
// (4.2) retile the register fragments and smem for the copy back
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
// cute::clear(sO);
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
// (4.3) copy the register-file result to smem (TODO: maybe use a for loop for better performance)
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
__syncthreads();
// (4.4) sO -> rO -> gO
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow);
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
auto tOsO = gmem_thr_copy_O.partition_S(sO);
auto tOgO = gmem_thr_copy_O.partition_D(gO);
cute::Tensor cOutput = cute::make_identity_tensor(
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
cute::Tensor tOrO = cute::make_tensor<typename KT::ElementOutput>(cute::shape(tOgO));
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(tOgO); ++m)
{
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
{
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
}
}
}
};
} // namespace fused_moe
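Editor's note: both specializations above share one software-pipeline skeleton: Stages - 1 k-tiles are prefetched with cp.async, the mainloop overlaps the smem-to-register copy for k_block + 1 with the MMA for k_block, and the two tail loops drain the remaining async stages and the register pipe. A minimal standalone sketch of that skeleton, with hypothetical load_tile/mma_tile callables standing in for the tiled copies and cute::gemm calls (an illustration, not code from this commit):

template <int Stages, class LoadFn, class MmaFn>
__device__ void pipelined_mainloop_sketch(int k_tile_count, LoadFn load_tile, MmaFn mma_tile)
{
int write_pipe = 0;
int read_pipe = 0;
// Prefetch: fill Stages - 1 smem buffers before the first compute.
for (int s = 0; s < Stages - 1; ++s)
{
if (s < k_tile_count)
{
load_tile(write_pipe); // cp.async gmem -> smem into stage `write_pipe`
}
cute::cp_async_fence(); // close this stage's async group (possibly empty)
write_pipe = (write_pipe + 1) % Stages;
}
// Mainloop: compute on stage `read_pipe` while refilling the stage just freed.
for (int k = 0; k < k_tile_count; ++k)
{
cute::cp_async_wait<Stages - 2>(); // the oldest in-flight stage has landed
__syncthreads();
mma_tile(read_pipe);
if (k + Stages - 1 < k_tile_count)
{
load_tile(write_pipe);
}
cute::cp_async_fence();
read_pipe = (read_pipe + 1) % Stages;
write_pipe = (write_pipe + 1) % Stages;
}
cute::cp_async_wait<0>(); // tail: drain every outstanding stage
__syncthreads();
}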

View File

@@ -0,0 +1,215 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass_extensions/epilogue_helpers.h>
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
namespace fused_moe
{
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Arguments
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
struct Routine_Params
{
ElementInput* ptr_input{};
ElementWeight* ptr_fc1{};
ElementInput* ptr_bias{};
ElementOutput* ptr_output{};
int64_t const* total_tokens_including_expert{};
int gemm_n{};
int gemm_k{};
int num_expert{};
bool bias_is_broadcast{};
};
enum class Activation_Type
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
InvalidType
};
constexpr bool isGateActivation(Activation_Type const& activation_type)
{
return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
}
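// (editor's note) compile-time sanity checks illustrating the gate / non-gate split:
static_assert(isGateActivation(Activation_Type::Swiglu) && isGateActivation(Activation_Type::Geglu),
"gated activations select the two-weight (fc1 + gate) routine");
static_assert(!isGateActivation(Activation_Type::Relu) && !isGateActivation(Activation_Type::Identity),
"plain activations select the single-weight routine");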
template <typename CutlassExtensionEpilogueTag>
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
{
return Activation_Type::InvalidType;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
{
return Activation_Type::Identity;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
{
return Activation_Type::Relu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
{
return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
}
template <>
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
{
return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
}
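// (editor's note) usage sketch: the dispatcher derives the runtime activation from
// the epilogue tag plus whether the layer is gated, e.g.:
static_assert(
EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(true) == Activation_Type::Swiglu,
"a SiLU epilogue on a gated MoE layer routes to SwiGLU");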
/* Fusing all three kernels has many limitations; this simpler version fuses only the first two. */
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
int Stages_, Activation_Type activation_type>
struct Fused_Moe_Kernel_traits_sm80
{
using ElementInput = ElementInput_;
using ElementWeight = ElementWeight_;
using ElementAccum = float;
using ElementOutput = ElementOutput_;
using index_t = uint32_t;
static_assert(TileM_ % 16 == 0);
static_assert(TileN_ % 32 == 0);
static_assert(TileK_ % 32 == 0);
static constexpr int Stages = Stages_;
static constexpr int kTileM = TileM_;
static constexpr int kTileN = TileN_;
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
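// (editor's note) e.g. TileM_ = 16, TileK_ = 32 widens kTileK to 64 to match the
// kBlcokKSmem = 64 shared-memory block chosen below for the narrow 16-row tile.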
// tile shape
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
static constexpr int kWarpsCount = 4;
static constexpr int kThreadCount = kWarpsCount * 32;
// MMA atom arch and layout
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
// using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
using ThreadLayoutMNK
= std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
cute::Tile<cute::_32, cute::_32, cute::_16>>;
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK,
ValLayoutMNK>; // 32x32x16 or 16x64x16 MMA for LDSM when kWarpsCount == 4
static constexpr int kAlignment = 8;
static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;
// A memory copy operand
using DefaultOperandA
= DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
// B memory copy operand
using DefaultOperandB
= DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
// Output memory copy operand
using SmemLayoutAtomO = SmemLayoutAtomA;
using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
using GmemLayoutAtomO
= cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
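// (editor's note) for 2-byte outputs kGmemElementPerLoad = 16 / 2 = 8, so each
// 128-bit store covers 8 elements and kGmemTrheadsPerRow threads cooperate on one
// kBlcokKSmem-wide row of the output tile.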
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
cute::make_shape(
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
cute::make_shape(
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
using SmemLayoutO = decltype(cute::tile_to_shape(
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
// at least 2 pipeline stages are required
static_assert(Stages >= 2);
struct SharedStorageNormal : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct SharedStorageGate : cute::aligned_struct<128>
{
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
};
using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;
using ActivationFn = std::conditional_t<activation_type == Activation_Type::Gelu
|| activation_type == Activation_Type::Geglu,
cutlass::epilogue::thread::GELU<float>,
std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
static constexpr bool can_implement(int const available_smem_size)
{
return available_smem_size > kSmemSize;
}
}
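// (editor's sketch) the host launcher is expected to pair this check with the
// opt-in dynamic smem limit, mirroring the cudaFuncSetAttribute/occupancy code
// earlier in this commit:
//   int smem_size = 0;
//   cudaDeviceGetAttribute(&smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
//   if (!KT::can_implement(smem_size)) { /* fall back to another tile/stage config */ }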
};
} // namespace fused_moe

View File

@@ -0,0 +1,73 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Scheduler for grouped GEMM
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
{
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
//
// Methods
//
CUTLASS_DEVICE
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, shared_storage_, block_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
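Editor's note: a hypothetical instantiation, shown only to make the template arguments concrete; the tile shape, schedule mode, and counts below are illustrative, not values taken from this commit:

using ExampleMoeVisitor = cutlass::gemm::kernel::GemmMoeProblemVisitor<
cutlass::gemm::GemmShape<128, 128, 32>, // threadblock tile
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, // schedule problems on the device
/*PrefetchTileCount=*/128, /*ThreadCount=*/128>;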

View File

@@ -0,0 +1,70 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
////////////////////////////////////////////////////////////////////////////////
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
class Enable = void>
class GemmUniversalGated;
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,585 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
namespace tk = tensorrt_llm::common;
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
tk::QuantMode quant_option;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments()
: mode(GemmUniversalMode::kGemm)
, batch_count(1)
{
}
/// constructs an arguments structure
Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(mode_)
, problem_size(problem_size_)
, batch_count(batch_count_)
, ref_A(ref_A_)
, ref_B(ref_B_)
, quant_option(quant_option_)
, ref_alpha_col(ref_alpha_col_)
, ref_alpha_row(ref_alpha_row_)
, ref_C(ref_C_)
, ref_D(ref_D_)
, batch_stride_A(batch_stride_A_)
, batch_stride_B(batch_stride_B_)
, batch_stride_D(0)
, epilogue_visitor(epilogue_visitor_)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void* ptr_A;
void* ptr_B;
tk::QuantMode quant_option;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
ElementC* ptr_C;
ElementC* ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0)
, params_A(0)
, params_B(0)
, params_alpha_col(0)
, params_C(0)
, params_D(0)
, mode(cutlass::gemm::GemmUniversalMode::kGemm)
, batch_count(0)
, gemm_k_size(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_alpha_col(nullptr)
, ptr_alpha_row(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, batch_stride_A(0)
, batch_stride_B(0)
{
}
Params(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
: problem_size(args.problem_size)
, swizzle_log_tile(0)
, params_A(args.ref_A.layout())
, params_B(args.ref_B.layout())
, params_alpha_col(args.ref_alpha_col.layout())
, params_alpha_row(args.ref_alpha_row.layout())
, params_C(args.ref_C.layout())
, params_D(args.ref_D.layout())
, mode(args.mode)
, batch_count(args.batch_count)
, gemm_k_size(args.problem_size.k())
, ptr_A(args.ref_A.data())
, ptr_B(args.ref_B.data())
, quant_option(args.quant_option)
, ptr_alpha_col(args.ref_alpha_col.data())
, ptr_alpha_row(args.ref_alpha_row.data())
, ptr_C(args.ref_C.data())
, ptr_D(args.ref_D.data())
, batch_stride_A(args.batch_stride_A)
, batch_stride_B(args.batch_stride_B)
, epilogue_visitor(args.epilogue_visitor)
{
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
{
int const kAlignK
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size)
{
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
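// (editor's note) split-K example: problem_size.k() = 4096, batch_count = 2 and
// 16-bit operands give kAlignK = 8, gemm_k_size = round_up(2048, 8) = 2048, and
// grid_tiled_shape.k() = 2, i.e. two CTAs split each output tile's K range.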
};
/// Shared memory storage structure
union SharedStorage
{
typename Mma::SharedStorage main_loop;
struct
{
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
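// Illustrative example: if Mma::IteratorA::AccessType::kElements == 16
// (e.g. 16 int8 values per 128-bit access) and LayoutA is RowMajor, then
// problem_size.k() must be a multiple of 16; otherwise the checks below
// return kErrorMisalignedOperand.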
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
{
isAMisaligned = problem_size.m() % kAlignmentA;
}
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
{
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value)
{
isBMisaligned = problem_size.n() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
{
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
{
isCMisaligned = problem_size.m() % kAlignmentC;
}
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
{
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned)
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
{
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched)
{
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray)
{
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute the number of threadblock-scoped matrix multiply-add iterations
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm)
{
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
{
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,143 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
This file exists so that we can use the same weight layout for MoE grouped gemm and regular gemm when the weight is
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
to be consumed by CUTLASS.
Note that for int4, ThreadBlockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
// for fast accumulation
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
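// Worked example (illustrative): for TypeA = half_t, ThreadblockK = 1024 / 16 = 64.
// With uint8_t weights, ElementsPerCacheLine = 1024 / 8 = 128, so ColumnsInterleaved
// = 128 / 64 = 2: two weight columns share each 128-byte cache line. The uint4b_t
// specialization below gives ColumnsInterleaved = 256 / 64 = 4 by the same math.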
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,185 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_conversion.h>
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;
template <typename Element, typename Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
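// Illustrative reading of the atoms above: 128 threads arranged as 16x8 (row-major)
// each issue one 128-bit cp.async of 8 half_t values, so one tiled copy covers a
// 16x64 gmem tile of A, and SM75_U32x4_LDSM_N then serves the MMA from the
// Swizzle<3, 3, 3> smem layout.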
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
};
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{
};
// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{
};
//
// F16: 128-by-128-by-32 (small k-block)
//
/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(cute::composition(
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
// Gmem
using GmemTiledCopy = decltype(cute::make_tiled_copy(
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
};
template <typename To_type, typename Engine, typename Layout>
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
{
using From_type = typename Engine::value_type;
constexpr int numel = decltype(cute::size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
}
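// Usage sketch (illustrative, hypothetical fragment name): given a register
// fragment `tCrB` of half_t, `auto tCrB_f32 = util_convert_type<float>(tCrB);`
// returns a tensor over the converted values with the original layout.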
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
CUTE_DEVICE void util_copy(
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
{
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < cute::size<1>(S); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < cute::size<2>(S); ++k)
{
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
}
}
}

View File

@@ -0,0 +1,553 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular gemms and dequantizing gemms.
// It will dispatch to the dequantizing gemm if the Mma type has an iterator for scales in global memory.
template <typename...>
using void_t = void;
template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};
template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};
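// Illustrative sketch of the detection idiom above: a hypothetical
// `struct DqMma { using IteratorScale = ...; };` matches the partial
// specialization, so use_dq_gemm<DqMma>::value == true, while an Mma without a
// nested IteratorScale type selects the primary template and yields false.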
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct MoeFCGemm
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
int problem_count;
int threadblock_count;
int group_size;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
bool C_is_broadcast;
int64_t const* total_tokens_including_expert;
int64_t gemm_n;
int64_t gemm_k;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, total_tokens_including_expert(nullptr)
, gemm_n(0)
, gemm_k(0)
, host_problem_sizes(nullptr)
, C_is_broadcast{true}
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n,
int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count)
, threadblock_count(threadblock_count)
, group_size(group_size)
, output_op(output_op)
, ptr_A(const_cast<ElementA*>(ptr_A))
, ptr_B(const_cast<ElementB*>(ptr_B))
, weight_scales(const_cast<ElementScale*>(weight_scales))
, ptr_C(const_cast<ElementC*>(ptr_C))
, C_is_broadcast{C_is_broadcast}
, ptr_D(ptr_D)
, total_tokens_including_expert(total_tokens_including_expert)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, host_problem_sizes(nullptr)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
assert(weight_scales);
}
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int group_size;
bool C_is_broadcast;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, C_is_broadcast(true)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(
args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
, threadblock_count(args.threadblock_count)
, group_size(args.group_size)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, weight_scales(args.weight_scales)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
, C_is_broadcast(args.C_is_broadcast)
{
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
args.gemm_k, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
C_is_broadcast = args.C_is_broadcast;
}
};
/// Shared memory storage structure
union SharedStorage
{
typename ProblemVisitor::SharedStorage problem_visitor;
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
MoeFCGemm() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
if (args.weight_scales == nullptr)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
return Status::kInvalid;
}
}
else if (args.weight_scales != nullptr)
{
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
return Status::kInvalid;
}
else if (args.group_size != args.gemm_k)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
return Status::kInvalid;
}
// Handle the case where the input is too short
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
return Status::kInvalid;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
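// Worked example (illustrative): with gemm_k = 4096, gemm_n = 14336 and uint4b_t
// weights, bytes_per_expert_matrix = (4096 * 14336 / 8) * 4 = 29360128, so each
// expert's packed weight matrix begins at that byte stride from params.ptr_B.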
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute the number of threadblock-scoped matrix multiply-add iterations
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
+ (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
// For LoRA, C is not broadcast, so layout_C must be constructed with stride gemm_n.
LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
if constexpr (platform::is_same<EpilogueOutputOp,
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
EpilogueOutputOp::kRound>>::value)
{
EpilogueOutputOp output_op(params.output_op, problem_idx);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
else
{
EpilogueOutputOp output_op(params.output_op);
epilogue(output_op, iterator_D, accumulators, iterator_C);
}
// Next tile
problem_visitor.advance(gridDim.x);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
if constexpr (isFp8)
{
run_kernel<arch::Sm89>(params, shared_storage);
}
else
{ // reuse sm80 kernel for other types, align with dispatchToArch
run_kernel<arch::Sm80>(params, shared_storage);
}
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,344 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Base scheduler for grouped problems, using MoE
*/
#pragma once
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor
{
using ThreadblockShape = ThreadblockShape_;
struct ProblemInfo
{
static int32_t const kNoPrefetchEntry = -1;
int32_t problem_idx;
int32_t problem_start;
CUTLASS_DEVICE
ProblemInfo()
: problem_idx(kNoPrefetchEntry)
, problem_start(kNoPrefetchEntry)
{
}
CUTLASS_DEVICE
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
: problem_idx(problem_idx_)
, problem_start(problem_start_)
{
}
};
struct Params
{
int64_t const* last_row_for_problem;
int64_t gemm_n;
int64_t gemm_k;
int32_t problem_count;
void const* workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: last_row_for_problem(nullptr)
, gemm_n(0)
, gemm_k(0)
, problem_count(0)
, workspace(nullptr)
, tile_count(0)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
void const* workspace = nullptr, int32_t tile_count = 0)
: last_row_for_problem(last_row_for_problem)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, problem_count(problem_count)
, workspace(workspace)
, tile_count(tile_count)
{
}
};
Params const& params;
int32_t tile_idx;
int32_t problem_tile_start;
int32_t problem_idx;
//
// Methods
//
CUTLASS_DEVICE
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
: params(params_)
, tile_idx(block_idx)
, problem_tile_start(0)
, problem_idx(0)
{
}
/// Get the grid shape
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
{
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
}
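// Illustrative example: with a 128x128 threadblock tile and a 300x512 problem,
// grid_shape() returns (ceil(300 / 128), ceil(512 / 128), 1) = (3, 4, 1).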
/// Gets the global tile index
CUTLASS_HOST_DEVICE
int32_t tile_index() const
{
return tile_idx;
}
/// Gets the index of the problem
CUTLASS_HOST_DEVICE
int32_t problem_index() const
{
return problem_idx;
}
CUTLASS_HOST_DEVICE
int32_t threadblock_idx() const
{
return tile_idx - problem_tile_start;
}
CUTLASS_DEVICE
void advance(int32_t grid_size)
{
tile_idx += grid_size;
}
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
{
ProblemSizeHelper::possibly_transpose_problem(problem);
}
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const
{
return problem_size(problem_idx);
}
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const
{
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
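// Illustrative example: if last_row_for_problem = {128, 128, 384}, then expert 0
// spans rows [0, 128) (gemm_m = 128), expert 1 receives no rows (gemm_m = 0),
// and expert 2 spans rows [128, 384) (gemm_m = 256).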
CUTLASS_HOST_DEVICE
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
{
return ProblemSizeHelper::tile_count(grid);
}
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
{
int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
auto problem = host_problem_sizes_ptr[i];
possibly_transpose_problem(problem);
auto grid = grid_shape(problem);
total_tiles += tile_count(grid);
}
return total_tiles;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor;
/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
{
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
using Params = typename Base::Params;
static int const kThreadCount = ThreadCount;
static bool const kRequiresPrecomputation = false;
static int const kThreadsPerWarp = 32;
struct SharedStorage
{
};
// Final tile of the problem loaded by this thread. Each thread will hold
// a separate value.
int32_t problem_ending_tile;
SharedStorage& shared_storage;
//
// Methods
//
CUTLASS_DEVICE
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, block_idx)
, problem_ending_tile(0)
, shared_storage(shared_storage_)
{
this->problem_idx = -1 * kThreadsPerWarp;
this->problem_tile_start = 0;
}
CUTLASS_DEVICE
bool next_tile()
{
// Check whether the tile to compute is within the range of the current problem.
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
if (this->tile_idx < problem_tile_end)
{
return true;
}
// Check whether the tile to compute is within the current group of problems fetched by the warp.
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
// register pressure. The starting problem for this group is simply the first problem
// in the group most recently fetched by the warp.
int32_t& group_problem_start = this->problem_idx;
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
// register pressure.
int32_t& group_tile_start = this->problem_tile_start;
// Each thread in the warp processes a separate problem to advance until
// reaching a problem whose starting tile is less than tile_idx.
while (group_tile_end <= this->tile_idx)
{
group_problem_start += kThreadsPerWarp;
if (group_problem_start > this->params.problem_count)
{
return false;
}
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
// is also set here is used later in `next_tile`.
group_tile_start = group_tile_end;
int lane_idx = threadIdx.x % kThreadsPerWarp;
int32_t lane_problem = group_problem_start + lane_idx;
// Compute the number of tiles in the problem assigned to each thread.
problem_ending_tile = 0;
if (lane_problem < this->params.problem_count)
{
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
problem_ending_tile = this->tile_count(grid);
}
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
// each thread's problem.
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
{
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i)
{
problem_ending_tile += val;
}
}
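// Illustrative example: if lanes 0..3 own problems with 3, 1, 4 and 2 tiles,
// the inclusive prefix sum leaves problem_ending_tile = 3, 4, 8 and 10 in
// those lanes before group_tile_start is added below.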
// The total tile count for this group is now in the final position of the prefix sum
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}
// The next problem to process is the first one whose ending tile position
// is greater than the tile index.
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
this->problem_idx = group_problem_start + problem_idx_in_group;
// The starting tile for this problem is the ending tile of the previous problem. In cases
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
// loop above.
if (problem_idx_in_group > 0)
{
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
}
return true;
}
static size_t get_workspace_size(
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
{
return 0;
}
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
int32_t block_count, void* host_workspace_ptr)
{
}
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,646 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<
cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>
&& CollectiveMainloop_::isGated>>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock
= CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Kernel level shared memory storage
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16>
{
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void* workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static Params to_underlying_arguments(Arguments const& args, void* workspace)
{
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0)
{
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = nullptr;
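// Resulting workspace layout (conceptually):
// [ tile-scheduler workspace | pad to MinWorkspaceAlignment | epilogue workspace ]
// with no slice reserved for the mainloop.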
// Precompute the number of epilogue subtiles and pass it to the tile scheduler, where it is used by
// the separate-reduction scheme in the stream-K case. NumEpilogueSubTiles defaults to 1, which means
// subtiles are not used and separate reduction is not enabled.
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{},
ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
return {args.mode, problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
scheduler, workspace};
}
static bool can_implement(Arguments const& args)
{
bool implementable = (args.mode == GemmUniversalMode::kGemm)
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable)
{
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const& args)
{
size_t workspace_size = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
{
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const& params)
{
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
{
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}
static dim3 get_block_shape()
{
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf)
{
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
        static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have a TiledMMA operating with 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole
{
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
enum class ProducerWarpRole
{
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate)
{
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = []()
{
            // We need this to guarantee that the Pipeline init is visible
            // to all producer and consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1)
{
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
}
else
{
__syncthreads();
return []() {}; // do nothing
}
}();
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.get_current_work();
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer)
{
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop)
{
bool do_load_order_arrive = true;
while (work_tile_info.is_valid())
{
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
{
work_tile_info = fetch_next_work(work_tile_info, scheduler);
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the
// work.
auto work_k_tile_count
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter
= cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster,
shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive)
{
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
{
while (work_tile_info.is_valid())
{
if (!TileScheduler::requires_separate_reduction(params.scheduler))
{
load_order_barrier.wait();
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline,
epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx());
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
            // Whether we may need to issue tail arrives for TMA stores, in case the epilogue load is waiting on them
bool do_store_tail = false;
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto work_k_tile_count
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
// Allocate the accumulators for the (M,N) blk_shape
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
{
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop,
params.mainloop);
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
// Perform reduction across splits, if needed
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx);
TileScheduler::fixup(
params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx);
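                // Gated fusion of the two GEMM results: acc0 = (scale_d0 * acc0) * act(scale_d1 * acc1),
                // where act is the mainloop's Activation functor (e.g. SiLU in a SwiGLU-style gated MLP)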
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++)
{
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
{
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
do_store_tail = true;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
if (do_store_tail)
{
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state);
}
} // Consumer Warp Groups End
#endif
}
private:
// Kernel helper function to get next work unit
CUTLASS_DEVICE
typename TileScheduler::WorkTileInfo fetch_next_work(
typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const
{
// Check whether we should continue on with the current work unit. If this is the case,
// the work unit will have been updated in continue_current_work to reflect the new
// tile to be computed.
if (scheduler.continue_current_work(work_tile_info))
{
return work_tile_info;
}
// Get next work tile
scheduler.advance_to_next_work();
return scheduler.get_current_work();
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
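
To make the workspace carving in to_underlying_arguments() and get_workspace_size() above concrete, here is a minimal host-side sketch of the same pattern. It assumes round_nearest rounds an offset up to the next multiple of the alignment; the 16-byte alignment and the region sizes are illustrative stand-ins, not values defined in this commit.

#include <cstddef>
#include <cstdio>

// Round x up to the next multiple of align (the assumed behavior of round_nearest above).
static std::size_t round_up(std::size_t x, std::size_t align)
{
    return (x + align - 1) / align * align;
}

int main()
{
    std::size_t const align = 16;            // illustrative stand-in for MinWorkspaceAlignment
    std::size_t const scheduler_bytes = 100; // illustrative tile-scheduler workspace size
    std::size_t const epilogue_bytes = 64;   // illustrative epilogue workspace size

    std::size_t const scheduler_off = 0;
    std::size_t const epilogue_off = round_up(scheduler_off + scheduler_bytes, align);
    std::size_t const total = round_up(epilogue_off + epilogue_bytes, align);

    // Prints: scheduler @ 0, epilogue @ 112, total = 176 bytes
    std::printf("scheduler @ %zu, epilogue @ %zu, total = %zu bytes\n", scheduler_off, epilogue_off, total);
    return 0;
}

The invariant is that every region starts at an aligned offset, so the total computed by get_workspace_size() agrees with the offsets carved out in to_underlying_arguments(); the mainloop workspace pointer stays null in the kernels above.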

View File

@@ -0,0 +1,621 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cute/util/debug.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel
{
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>
&& CollectiveMainloop_::isGated>>
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
"Ping-pong kernel does not currently support stream-K scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
// Kernel level shared memory storage
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128>
{
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16>
{
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params
{
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
static Params to_underlying_arguments(Arguments const& args, void* workspace)
{
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0)
{
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = nullptr;
return {args.mode, problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
}
static bool can_implement(Arguments const& args)
{
bool implementable = (args.mode == GemmUniversalMode::kGemm)
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable)
{
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const& args)
{
size_t workspace_size = 0;
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
{
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const& params)
{
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
{
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
}
static dim3 get_block_shape()
{
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void operator()(Params const& params, char* smem_buf)
{
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole
{
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
enum class ProducerWarpRole
{
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate)
{
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [&]()
{
            // We need this to guarantee that the Pipeline init is visible
            // to all producer and consumer thread blocks in the Cluster
if constexpr (size(ClusterShape{}) > 1)
{
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
}
else
{
__syncthreads();
return []() {}; // do nothing
}
}();
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
// In a warp specialized kernel, collectives expose data movement and compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
if (warp_group_role == WarpGroupRole::Consumer1)
{
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
}
auto work_tile_info = scheduler.get_current_work();
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer)
{
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop)
{
bool do_load_order_arrive = true;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster,
shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive)
{
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
{
load_order_barrier.wait();
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state
= collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL,
blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue);
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
{
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Allocate the accumulators for the (M,N) blk_shape
Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
// Order two Math WG's MMA one after the other, helps hide Epilogue
math_wg_order_barrier.wait();
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1,
k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
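                // Same gated fusion as in the cooperative kernel: acc0 = (scale_d0 * acc0) * act(scale_d1 * acc1)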
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++)
{
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
}
// Order two Math WG's Epilogue one after the other
math_wg_order_barrier.wait();
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue);
// TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels
// we need to wait for all TMA stores to complete before issuing consumer order barrier arrives
// to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer.
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_]
= collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next);
// Update starting load/store pipeline states for the next tile
                // The states were already incremented by one tile inside the collective calls;
                // advance once more to skip over the other ping-pong warp group's tile
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Get next work tile
scheduler.advance_to_next_work(NumMmaWarpGroups);
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
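
The ping-pong startup above pre-advances Consumer1's scheduler by one work tile and its pipeline states by one tile's worth of increments; each math warp group then advances by NumMmaWarpGroups tiles per iteration so the two groups leapfrog each other. A minimal host-side model of that bookkeeping follows; the tile counts are illustrative, and the model counts stages monotonically while the real PipelineState wraps modulo the pipeline depth.

#include <cstdio>

int main()
{
    int const num_mma_wgs = 2;    // NumMmaWarpGroups in the kernel above
    int const k_tile_count = 8;   // K tiles per work tile (illustrative)
    int const num_work_tiles = 6; // illustrative

    for (int wg = 0; wg < num_mma_wgs; ++wg)
    {
        // Consumer1 (wg == 1) starts k_tile_count stage increments ahead.
        int state = wg * k_tile_count;
        for (int tile = wg; tile < num_work_tiles; tile += num_mma_wgs)
        {
            std::printf("math WG %d: work tile %d consumes mainloop stages [%d, %d)\n",
                        wg, tile, state, state + k_tile_count);
            // advance(k_tile_count * NumMmaWarpGroups): skip over the other WG's tile
            state += k_tile_count * num_mma_wgs;
        }
    }
    return 0;
}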

View File

@@ -0,0 +1,494 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed = false>
struct SplitkGemmGrouped
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = Transposed;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementFinalOutput = typename MapArguments::ElementA;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmCoord* problem_sizes;
int problem_count;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
// splitK
int split_k_slices;
int64_t* splitk_buffer_offsets;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
int64_t* splitk_buffer_offsets)
: problem_sizes(problem_sizes)
, problem_count(problem_count)
, threadblock_count(threadblock_count)
, output_op(output_op)
, ptr_A(ptr_A)
, ptr_B(ptr_B)
, ptr_C(ptr_C)
, ptr_D(ptr_D)
, lda(lda)
, ldb(ldb)
, ldc(ldc)
, ldd(ldd)
, host_problem_sizes(host_problem_sizes)
, split_k_slices(split_k_slices)
, splitk_buffer_offsets(splitk_buffer_offsets)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
ElementC* ptr_C_split;
ElementC* ptr_D_split;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
//
// Methods
//
// splitk
GemmCoord grid_tiled_shape;
int swizzle_log_tile;
int gemm_k_size;
GemmCoord* host_problem_sizes;
int split_k_slices;
int64_t* splitk_buffer_offsets;
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, ptr_C_split(nullptr)
, ptr_D_split(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, swizzle_log_tile(0)
, gemm_k_size(0)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
            , threadblock_count(args.threadblock_count)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
, ptr_C_split((ElementC*) workspace)
, ptr_D_split((ElementC*) workspace)
, lda(args.lda)
, ldb(args.ldb)
, ldc(args.ldc)
            , ldd(args.ldd)
            , host_problem_sizes(args.host_problem_sizes)
, split_k_slices(args.split_k_slices)
, splitk_buffer_offsets(args.splitk_buffer_offsets)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
            // Only a uniform K across all grouped problems is supported, so derive the per-slice K size from problem 0
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
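            // e.g. K = 4096, Mma::Shape::kK = 64, split_k_slices = 4 (illustrative):
            // full_gemm_k_iterations = 64, gemm_k_iterations = 16, so gemm_k_size = 1024
            // and each of the 4 K-slices covers 1024 elements of K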
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor =
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
            ptr_C_split = static_cast<ElementC*>(workspace);
            ptr_D_split = static_cast<ElementC*>(workspace);
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
}
};
/// Shared memory storage structure
struct SharedStorage
{
union
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
SplitkGemmGrouped() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile())
{
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// Load element pointers. Exchange pointers and strides if working on the transpose
ElementA* ptr_A
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB* ptr_B
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k;
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
{
problem_size_k = problem_size.k();
}
else
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx_sync();
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = params.ptr_C_split;
ElementC* ptr_D = params.ptr_D_split;
LayoutC layout_C(params.ldc[problem_idx]);
LayoutC layout_D(params.ldd[problem_idx]);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// assume identity swizzle
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
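// Each k-slice writes its partial tile to a distinct region: slices of one problem are
// spaced problem_size.m() * problem_size.n() apart, and gridDim.z *
// splitk_buffer_offsets[problem_idx] skips the buffers of all preceding problems. The
// per-slice partials are then combined by a separate reduction pass (hence the
// ptr_C_split / ptr_D_split names).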
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
View File
@@ -0,0 +1,125 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
// We need to distinguish cases here because we want Volta support. It would be too
// much effort to write the shared memory iterators that Volta would likely need to
// function properly. As a result, we allow converters both after the LDG (for Volta)
// and after the LDS (for Turing+).
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Warp level Mma
typename MmaOperator,
    /// Math operation performed by the warp-level operator
typename MathOperator>
struct SetConverters
{
};
// Dequantize after LDG, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
{
using TransformAfterLDG
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element, IteratorB::Fragment::kElements>;
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
};
// Dequantize after LDS, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
{
using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
};
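// In short: with plain OpMultiplyAdd the quantized B fragment is converted right
// after the global load (TransformAfterLDG does the dequantize, TransformAfterLDS is
// an identity pass-through), whereas with OpMultiplyAddDequantizeInterleavedBToA the
// data stays quantized through shared memory and is only converted after the LDS.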
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale_,
/// Layout for the scale operand
typename LayoutScale_,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
///
typename Enable = void>
struct DqMma;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
View File
@@ -0,0 +1,302 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsMultistage;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
using SmemIteratorScale = IteratorScale;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
// ThreadMap for scale iterator
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale = IteratorScale;
};
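// Note: in the multistage mainloop the shared-memory scale iterator aliases the
// global-memory one (SmemIteratorScale = IteratorScale) for both quantization modes.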
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
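// Illustrative shape check: with MmaCore::Shape::kK == 64, kN == 128 and
// ColumnsInterleaved == 4 (so RowsPerTile == 64), the logical 64x128 tile of B is
// fetched as a 256x32 tile of the interleaved layout, and the warp arrangement is
// rescaled by the same factor so each thread still issues full-width accesses.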
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
View File
@@ -0,0 +1,284 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsPipelined;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
using SmemScaleType = half_t;
public:
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType, Layout, 0, Alignment>;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
// ThreadMap for scale iterator
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
using SmemScaleType = half_t;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
Layout, 0, IteratorScaleThreadMap, Alignment>;
};
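// Note: the pipelined mainloop stages scales in shared memory as half_t
// (SmemScaleType above) regardless of the global-memory scale type, so a
// conversion is implied on the smem store path.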
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
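// When dequantization happens after the LDG, shared memory already holds fp16 data,
// so the mma core is built with fp16 for both operands; otherwise B stays in its
// quantized type through shared memory.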
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
OperatorClass, 2, Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
View File
@@ -0,0 +1,351 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
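// Scales are read with 128-bit accesses, i.e. 8 fp16 elements per access.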
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#endif
// fp16 x fp16 specialization on Ampere that uses mma multistage even for 2 stages. This helps avoid
// register spills on large tiles when there is not enough shared memory for 3+ stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
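// Note: the mma core above is built with 3 stages purely to select the multistage
// components; the mainloop itself still runs with 2 stages.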
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
View File
@@ -0,0 +1,353 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
private:
    // Conversions are only needed pre-Ampere. This specialization triggers the mma pipelined path, so we convert before the STS.
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
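// bf16 tensor-core MMA requires SM80+; on earlier architectures both operands are
// converted to half_t before being stored to shared memory.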
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};
// bf16 x bf16 specialization on Ampere that uses mma multistage even for 2 stages. This helps avoid
// register spills on large tiles when there is not enough shared memory for 3+ stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
false, SharedMemoryClear, GatherA, GatherB>
{
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA, GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB, GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
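////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight, mma multistage
/// (stage>=3)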
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
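// Note: each specialization above (bf16 activations with int8 or int4 weights,
// in two-stage and multistage forms) simply delegates to DqMma and re-exports
// its MmaCore, operand iterators and threadblock-scoped mainloop under the
// DefaultMma interface.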
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,257 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/weight_only_quant_op.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
// SFINAE trick that lets the same mainloop code serve Volta and newer
// architectures while dispatching to the correct warp-level MMA. On Volta, all
// data is stored to shared memory as FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C);
}
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
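////////////////////////////////////////////////////////////////////////////////
// Overload resolution note: the first overload matches warp MMAs whose operands
// are plain FragmentA/FragmentB (the Volta path above, where B is already FP16
// in shared memory); the second matches warp MMAs exposing TransformedFragment
// operands plus kExpansionFactor, and forwards warp_tileB_k_offset so the MMA
// can index the correct slice of the wider B fragment.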
////////////////////////////////////////////////////////////////////////////////
/// Base structure for threadblock-scoped matrix product kernels that
/// dequantize the B operand, targeting Tensor Core MMA instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// The type of the scales
typename ElementScale_,
/// Number of stages,
int Stages,
/// The dequantizing op to be performed.
WeightOnlyQuantOp DequantOp,
/// Used for partial specialization,
typename Enable = bool>
class DqMmaBase
{
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
///< Type of the scale to be loaded
using ElementScale = ElementScale_;
static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");
// Finegrained scales get streamed in via cp.async
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
// We always have scales.
static constexpr int ScaleElementsPerStage = Shape::kN;
// We sometimes have a bias
static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;
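    // Illustrative sizing (assumed values, not asserted): with a fine-grained op
    // that has zero-points, Stages = 4 and Shape::kN = 128, this reserves 4
    // pipeline stages of 128 scales plus 128 zero-points each; per-column
    // quantization collapses to a single stage.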
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static constexpr int kNumKIterationsPerWarpBLoad
= Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
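    // Illustrative example (assumed shapes): WarpGemm::kK = 32 with
    // Policy::MmaShape::kK = 16 gives kWarpGemmIterations = 2; if one B load
    // spans two compute iterations (InstructionShape::kRow = 32 vs. kK = 16),
    // then kNumKIterationsPerWarpBLoad = 2 and kWarpGemmIterationsForB = 1.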
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage
{
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA
= MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB
= MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of the shared memory buffer for the scales for the B matrix.
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
/// Shape of the shared memory buffer for the biases of the B matrix.
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer to hold scales for threadblock
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;
        /// Buffer to hold zero-points for threadblock
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA()
{
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB()
{
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref()
{
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref()
{
return TensorRefB{operand_B.data(), LayoutB()};
}
};
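    // Rough footprint (assumed tile, padding ignored): for Shape = 128x128x64,
    // kStages = 3 and a 16-bit ElementA, operand_A spans 128 x (64 * 3)
    // elements, i.e. about 48 KiB, with operand_B sized analogously via ShapeB.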
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
, warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product, dequantizing the B operand on the
/// fly and targeting Tensor Core MMA instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
    /// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
    /// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = void>
class DqMmaMultistage;
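// Declared (not defined) here: the definitions live in the headers included
// below, selected via the Enable parameter. std::enable_if_t<isFinegrained(QuantOp_)>
// enables the fine-grained (group-wise scale/zero) mainloop and its negation
// enables the per-column variant.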
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"

View File

@@ -0,0 +1,708 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product with fine-grained (group-wise)
/// dequantization of the B operand, targeting Tensor Core MMA instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
/// Internal structure exposed for introspection.
struct Detail
{
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load on group of operand A
static int const kAccessesPerGroupA
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load on group of operand B
static int const kAccessesPerGroupB
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
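    // Example (assumed counts): AsyncCopyIterationsPerStageA = 8 with
    // Base::kWarpGemmIterations = 4 yields kAccessesPerGroupA = (8 + 3) / 4 = 2
    // cp.async instructions issued per warp-level MMA step.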
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
/// The group size for quantization
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
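    // Worked example of the mapping above (assumed WarpCount = 2x2x1):
    // warp_idx = 3 gives warp_idx_mn = 3, warp_idx_k = 0, warp_idx_m = 1 and
    // warp_idx_n = 1, i.e. the bottom-right warp tile of the threadblock.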
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
{
static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");
typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();
typename IteratorScale::AccessType* smem_scale_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
typename IteratorScale::AccessType* smem_zero_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());
int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
}
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
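    // Advancement summary for the logic above: the shared-memory scale iterator
    // advances on every call, while the global-memory iterator advances once per
    // exhausted quantization group. That is every call for group_size 64, and
    // (when Shape::kK == 64) only every second call for group_size 128, tracked
    // by the parity of row_groupsize64_.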
CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
{
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
{
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
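    // Out-of-bound handling above: with SharedMemoryClearOption::kZfill an
    // invalid access writes zeros into shared memory, whereas the predicated
    // cp.async path skips the copy and leaves the destination untouched.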
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum)
{
//
// Prologue
//
TransformBAfterLDS lds_converter;
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
{
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
{
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
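        // Example (assumed kStages = 4): the prologue committed 3 stages, so
        // cp_async_wait<2> blocks until at most two remain in flight, i.e. until
        // the stage-0 tiles are resident in shared memory.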
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
typename Dequantizer::FragmentZero warp_frag_zeros;
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
warp_dequantizer_.add_pointer_offset(Shape::kN);
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);)
{
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
                // Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
if (group_start_iteration_B == 0)
{
copy_scales_and_advance(iterator_scale);
}
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1))
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
{
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1))
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
smem_read_stage_idx = 0;
}
else
{
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_scale.clear_mask(gemm_k_iterations == 0);
}
}
// Load the scale needed for the next tile iteration.
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
// Update internal pointer to set of scales in shared memory.
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,647 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product with per-column dequantization of
/// the B operand, targeting Tensor Core MMA instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
    /// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
    /// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail
{
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load on group of operand A
static int const kAccessesPerGroupA
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load on group of operand B
static int const kAccessesPerGroupB
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
};
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
///< Group size for quantization. Not used by this main loop since it assumes per-column
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
{
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
{
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
else
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum)
{
//
// Prologue
//
TransformBAfterLDS lds_converter;
        // NOTE: consider switching this scale load to cp.async (LDGSTS).
        // Issue it first so that the prologue's cp.async.commit_group commits this
        // load too: we do not commit it separately here, so it lands in the same
        // group as the first load of A.
FragmentScale tb_frag_scales;
tb_frag_scales.clear();
iterator_scale.load(tb_frag_scales);
this->smem_iterator_scale_.store(tb_frag_scales);
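        // Per-column path: the scales are loaded exactly once into the
        // single-stage scale buffer (ScalebiasStages = 1 when not fine-grained)
        // and reused unchanged for the entire mainloop.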
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
{
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
{
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);)
{
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
using Converter
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
warp_tileB_k_compute_offset);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
}
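// On the second-to-last warp MMA of the group, issue the remaining copies,
// commit the stage, and advance all iterators to the next k tile.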
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
// #committed)
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1))
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
{
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1))
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
smem_read_stage_idx = 0;
}
else
{
++smem_read_stage_idx;
}
--gemm_k_iterations;
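// Once the final k tile has been fetched, mask off further global loads; the
// remaining iterations only consume stages already in flight.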
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
}
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
// Commit and drain all pending and predicated LDGSTS instructions from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,106 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Used for partial specialization
typename Enable = void>
class DqMmaPipelined;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"

View File

@@ -0,0 +1,486 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
using WarpFragmentZero = typename Dequantizer::FragmentZero;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< The group size for quantization
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale)
{
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
FragmentScale tb_frag_scales;
FragmentScale tb_frag_zeros;
tb_frag_scales.clear();
tb_frag_zeros.clear();
TransformScale transformScale;
using FragmentElement = typename FragmentScale::Element;
auto gmem_scale_ptr = iterator_scale.get_scale();
auto gmem_zero_ptr = iterator_scale.get_zero();
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
}
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
typename TransformScale::result_type tb_frag_zeros_fp16;
if (gmem_zero_ptr != nullptr)
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
if (iterator_scale.valid())
{
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
if (gmem_zero_ptr != nullptr)
{
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
}
}
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
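// Example: with group_size == 64 the scale tile advances on every k tile; with
// group_size == 128 and Shape::kK == 64 it advances on every second k tile,
// tracked by the parity of row_groupsize64_.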
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
// These transforms mainly handle the case of bfloat16 activations and weights in GMEM
// when we want to issue HMMA on architectures older than Ampere; we convert to FP16 before the STS.
TransformA transformA;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
copy_scales_and_advance(iterator_scale);
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentScale warp_frag_scales;
WarpFragmentZero warp_frag_zero;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
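// Point the dequantizer at the next stage's scales (stages are Shape::kN
// elements apart in the row-major shared-memory scale buffer).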
warp_dequantizer_.add_pointer_offset(Shape::kN);
Operator warp_mma;
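// The prologue stored the first k tile into stage 0 of the double buffer, so
// the next shared-memory write targets stage 1.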
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
iterator_scale.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping back to the first k offset
// when this is the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
}
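// Toggle between the two stages of the double buffer: the shared-memory write
// iterators wrap on one phase and the warp read iterators on the opposite
// phase, keeping producer and consumer one stage apart.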
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
copy_scales_and_advance(iterator_scale);
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
iterator_scale.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
// Load the scales needed for the next tile iteration
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Update internal pointer to the set of scales in shared memory
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,399 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< Unused; accepted only to match the fine-grained specialization's
///< constructor signature. This specialization of DqMmaPipelined is only
///< enabled for sm<80, so the argument has no effect on sm>=80 builds.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
// These transforms mainly handle the case of bfloat16 activations and weights in GMEM
// when we want to issue HMMA on architectures older than Ampere; we convert to FP16 before the STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping back to the first k offset
// when this is the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,107 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for m-by-n-by-kgroup
template <
/// Shape of one matrix product operation (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
{
private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
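// e.g. with a 16x8x16 FP16 compute instruction, int8 B yields a 16x8x16 load
// shape (128/8) and int4 B yields a 16x8x32 load shape (128/4).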
public:
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,306 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sm89.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool>
class MmaTensorOpComputeBWithF16
{
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
&& ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
static_assert(
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
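// e.g. int4 B loaded with k = 32 but computed with k = 16 HMMA gives
// kExpansionFactor == 2: each loaded B fragment feeds two compute MMAs.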
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
int const warp_tileB_k_offset) const
{
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
static_assert(
TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
"Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
"B");
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
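// B holds kExpansionFactor operand fragments per column iteration;
// warp_tileB_k_offset selects which slice of the expanded K dimension
// this compute MMA consumes.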
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,463 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Number of threads participating in one matrix operation
int Threads,
///
WeightOnlyQuantOp QuantOp_,
///
typename Enable = void>
class MmaTensorOpDequantizer;
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
///
WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_,
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = bfloat16_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
// We need 1 bf16 per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{
int const warp_offset = warp_idx_n * Shape::kN;
int const quad = lane_idx / 4;
int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp))
{
pointer_zero_ = smem_zeros.data() + thread_offset;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag)
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
{
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
}
CUTLASS_DEVICE
void dequantize(
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
__nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]);
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
{
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
}
}
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset)
{
static_assert(sizeof(ElementScale) > 1, "");
pointer_scale_ += offset;
pointer_zero_ += offset;
}
private:
ElementScale const* pointer_scale_;
ElementScale const* pointer_zero_;
};
////////////////////////////////////////////////////////////////////////////////
// Specialization for Turing & Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
///
WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 75
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = half_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{
int const warp_offset = warp_idx_n * Shape::kN;
int const quad = lane_idx / 4;
int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp))
{
pointer_zero_ = smem_zeros.data() + thread_offset;
}
}
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag)
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
multiplies<ExpandedMmaOperandB> mul_op;
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
{
if constexpr (hasZero(QuantOp))
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
}
}
}
CUTLASS_DEVICE
void dequantize(
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
== FragmentDequantizedOperand::kElements,
"");
multiplies<ExpandedMmaOperandB> mul_op;
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
if constexpr (hasZero(QuantOp))
{
plus<ExpandedMmaOperandB> plus_op;
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter]
= plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]);
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
{
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
}
// Adds a pointer offset in units of elements.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset)
{
static_assert(sizeof(ElementScale) > 1, "");
pointer_scale_ += offset;
pointer_zero_ += offset;
}
private:
ElementScale const* pointer_scale_;
ElementScale const* pointer_zero_;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,224 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
namespace tensorrt_llm
{
namespace cutlass_extensions
{
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// SiMT config
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=16
CtaShape16x128x64_WarpShape16x32x64,
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
// Warp configs for M=64
CtaShape64x128x64_WarpShape32x64x64,
CtaShape64x64x128_WarpShape32x64x64,
CtaShape64x128x64_WarpShape64x32x64,
// Warp configs for M=128
CtaShape128x64x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x64x64,
CtaShape128x128x64_WarpShape128x32x64,
CtaShape128x256x64_WarpShape64x64x64,
// Warp configs for M=256
CtaShape256x128x64_WarpShape64x64x64,
// TensorCore config CTA_N = 64, CTA_K = 128
CtaShape128x64x128_WarpShape64x32x128,
// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64,
// TensorCore config CTA_N = 256, CTA_K = 128
CtaShape16x256x128_WarpShape16x64x128
};
enum class SplitKStyle
{
NO_SPLIT_K,
SPLIT_K_SERIAL,
STREAM_K, // Sm80+
// SPLIT_K_PARALLEL // Not supported yet
};
enum class CutlassTileConfigSM90
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// CTA configs for M=64
CtaShape64x16x128B,
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
// CTA configs for M=128
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
// CTA configs for M=256
CtaShape256x128x128B,
};
enum class MainloopScheduleType
{
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
// defaults to the "legacy" main loop schedule.
};
enum class EpilogueScheduleType
{
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
// architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
};
enum class ClusterShape
{
ClusterShape_1x1x1,
ClusterShape_2x1x1,
ClusterShape_1x2x1,
ClusterShape_2x2x1,
ClusterShape_1x8x1,
ClusterShape_8x1x1
};
struct CutlassGemmConfig
{
enum CandidateConfigTypeParam : int
{
NONE = 0,
WEIGHT_ONLY = 1u << 0,
SIMT_ONLY = 1u << 1,
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3,
GROUPED_GEMM = 1u << 4,
FP8_ONLY = 1u << 5,
};
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;
// config options for sm90
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
bool is_sm90 = false;
CutlassGemmConfig() {}
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
, split_k_style(split_k_style)
, split_k_factor(split_k_factor)
, stages(stages)
, is_sm90(false)
{
}
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, is_sm90(true)
{
}
std::string toString() const
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";
if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=TMA"
<< "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
assert(!is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=compatible"
<< "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
<< "\n\tsplit k: " << (int) split_k_factor;
}
else
{
tactic << "\n\tundefined";
}
tactic << "\n";
return tactic.str();
}
};
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
// clang-format off
if (config.is_sm90)
{
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
<< ", cluster_shape_enum: " << int(config.cluster_shape);
}
else
{
out << "tile_config_enum: " << int(config.tile_config)
<< ", split_k_style_enum: " << int(config.split_k_style)
<< ", split_k_factor: " << config.split_k_factor
<< ", stages: " << config.stages;
}
// clang-format on
return out;
}
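// Illustrative sketch (not part of the upstream header): constructing a pre-Hopper
// config and rendering it with the stream operator above. The function name and the
// tile/stage choices are arbitrary demo values, not an endorsed configuration.
inline std::string exampleGemmConfigString()
{
    CutlassGemmConfig config(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
        SplitKStyle::SPLIT_K_SERIAL, /*split_k_factor=*/2, /*stages=*/3);
    std::stringstream ss;
    ss << config; // config.toString() gives an equivalent multi-line form
    return ss.str();
}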
} // namespace cutlass_extensions
} // namespace tensorrt_llm

View File

@@ -0,0 +1,447 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
namespace cutlass
{
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
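// Illustrative host/device reference for the 8-bit case (a sketch restating the comment
// above; it is not used by any kernel, and the function name is a demo-only placeholder).
// Logical elements 0,1,2,3 of a 4-element group are stored in bytes 0,2,1,3 of the
// register, each biased by 2**(8-1) = 128.
CUTLASS_HOST_DEVICE
inline void reference_uninterleave_u8(uint8_t const stored[4], int decoded[4])
{
    int const byte_of_logical[4] = {0, 2, 1, 3};
    for (int k = 0; k < 4; ++k)
    {
        decoded[k] = static_cast<int>(stored[byte_of_logical[k]]) - 128;
    }
}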
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
{
using result_type = Array<half_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
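// 0x6480 is fp16 1152 = 1024 + 128: the prmt above packs each byte b under high byte
// 0x64, i.e. fp16 1024 + b, so this subtraction leaves the signed value b - 128.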
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using result_type = Array<half_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
{
using result_type = Array<bfloat16_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t fp32_base = 0x4B000000;
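// 0x4B000000 is fp32 2**23 = 8388608; byte-perming byte b into the low mantissa byte
// yields the fp32 value 8388608 + b.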
float fp32_intermediates[4];
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
// Subtract out fp32_base + 128 to make the unsigned integer signed.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 4; ++ii)
{
fp32_intermediates[ii] -= 8388736.f;
}
// Truncate the fp32 representation and pack up as bfloat16s.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii)
{
bf16_result_ptr[ii]
= __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
result.clear(); // Suppress compiler warning
arch::device_breakpoint();
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
{
using result_type = Array<half_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
static constexpr uint32_t NEG_72 = 0xd480d480;
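// Derivation of the constants: a low nibble v is packed as fp16 0x6400|v = 1024 + v,
// so subtracting 1032 (= 1024 + 8) leaves the signed value v - 8. A high nibble sits
// 4 bits up the mantissa, i.e. 1024 + 16*v; multiplying by 1/16 gives 64 + v, and
// adding -72 (= -(64 + 8)) again leaves v - 8.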
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
{
static constexpr int VEC_WIDTH = 8;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
using result_type = Array<half_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
{
using result_type = Array<bfloat16_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
// No shift needed for first item.
uint32_t i4s = source_i4s;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
CUTLASS_PRAGMA_UNROLL
for (int ii = 1; ii < result_type::kElements / 2; ++ii)
{
i4s >>= sizeof_bits<typename source_type::Element>::value;
// (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
}
// This is the BF16 {-136, -136} represented as an integer.
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
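// Derivation: 0x4300|v is bf16 128 + v (one mantissa ulp equals 1 at this exponent),
// so h * 1.0 + (-136) leaves the signed value v - 8, since 128 - 136 = -8.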
// Finally, we construct the output numbers.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < result_type::kElements / 2; ++ii)
{
// Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
#else
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
// HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
arch::device_breakpoint();
result.clear(); // Suppress compiler warning.
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
{
static constexpr int VEC_WIDTH = 8;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s)
{
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,66 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines new layouts needed for MoE
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"
namespace cutlass
{
namespace layout
{
template <int RowsPerTile, int ColumnsInterleaved>
struct ColumnMajorTileInterleave
{
static constexpr int kRowsPerTile = RowsPerTile;
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};
template <class T>
struct IsColumnMajorTileInterleave
{
static constexpr bool value = false;
};
template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
{
static constexpr bool value = true;
};
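// Illustrative compile-time checks (demo only, not part of the upstream header):
// the trait detects the interleave tag and rejects any other type.
static_assert(IsColumnMajorTileInterleave<ColumnMajorTileInterleave<64, 2>>::value, "");
static_assert(!IsColumnMajorTileInterleave<int>::value, "");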
} // namespace layout
} // namespace cutlass

View File

@@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
quantization.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace transform
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
class FineGrainedScaleZeroIterator;
template <typename Shape_, typename Element_, int Alignment_>
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
{
public:
using Shape = Shape_;
using Element = Element_;
using Layout = layout::RowMajor;
static int const kAdvanceRank = 0;
static int const kAlignment = Alignment_;
static int const kAccessesPerVector = 1;
/// Row index of scales corresponding to the groupsize of 64
int row_groupsize64_;
int group_size_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using AccessType = AlignedArray<Element, kAlignment>;
using Fragment = cutlass::Array<Element, kAlignment>;
// For compatibility with existing iterator interface
struct Params
{
LongIndex stride_ = 0;
/// amount (in bytes) to increment pointer from first access of current tile
/// to first access of next tile
LongIndex inc_advance_ = 0;
// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: stride_(layout.stride(0))
{
inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
}
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
/// Parameters object with precomputed internal state
Params const params_;
/// Internal pointer to first access of tile
BytePointer pointer_scale_;
BytePointer pointer_zero_;
bool is_valid_ = false;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_DEVICE
FineGrainedScaleZeroIterator(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of scale tensor
Pointer pointer_scale,
///< Pointer to start of zero tensor
Pointer pointer_zero,
///< Extent of the scale and bias
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
///< Group size
int group_size)
: params_(params)
, pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
, pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
{
row_groupsize64_ = threadblock_offset.row();
group_size_ = group_size;
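// threadblock_offset.row() counts scale rows as if the group size were 64; dividing
// by (group_size / 64) rescales it to the actual group size (e.g. group_size == 128
// halves the row index).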
const LongIndex tb_row_byte_offset
= threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);
if (pointer_zero_ != nullptr)
{
pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
}
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
int const thread_row = thread_id / THREADS_PER_ROW;
int const thread_col = thread_id % THREADS_PER_ROW;
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
if (pointer_zero_ != nullptr)
{
pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
}
// For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
// a given iteration. The same threads will be responsible for issuing reads since the number of scales
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
// outside of the constructor.
int const global_row = threadblock_offset.row() + thread_row;
int const global_col = threadblock_offset.column() + thread_col * kAlignment;
bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
bool const col_in_bounds = global_col < extent.column();
is_valid_ = row_in_bounds && col_in_bounds;
}
/// Construct a FineGrainedScaleZeroIterator with zero threadblock offset
CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
Pointer pointer_scale, ///< Pointer to start of scale tensor
Pointer pointer_zero, ///< Pointer to start of zero tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
int group_size)
: FineGrainedScaleZeroIterator(
params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
{
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& tile_offset)
{
const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
pointer_scale_ += row_byte_offset + col_byte_offset;
if (pointer_zero_ != nullptr)
{
pointer_zero_ += row_byte_offset + col_byte_offset;
}
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
{
is_valid_ &= (!enable);
}
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() const
{
return is_valid_;
}
/// Returns a scale pointer
CUTLASS_HOST_DEVICE
AccessType* get_scale() const
{
return reinterpret_cast<AccessType*>(pointer_scale_);
}
/// Returns a zero pointer
CUTLASS_HOST_DEVICE
AccessType* get_zero() const
{
return reinterpret_cast<AccessType*>(pointer_zero_);
}
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass

View File

@@ -0,0 +1,181 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
using namespace cute;
/// Function object that applies an index to its argument
template <class Iter>
struct IndexedGather
{
CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
: indices_(indices)
{
}
template <typename I>
CUTE_HOST_DEVICE constexpr auto operator()(I i) const
{
return indices_[i];
}
CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
{
cute::print("Indexed{");
print(s.indices_);
print("}");
}
Iter indices_;
};
/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
: func_(func)
, stride_(stride)
{
}
template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
{
return s.func_(i) * s.stride_;
}
template <class I>
CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
{
return s.func_(i) * s.stride_;
}
CUTE_HOST_DEVICE friend void print(CustomStride const& s)
{
cute::print("Custom{");
print(s.func_);
cute::print(",");
print(s.stride_);
cute::print("}");
}
template <class Div>
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
}
// Circumvent the requirement on make_layout that shape and stride are integral
template <class Shape>
CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
{
return Layout<Shape, CustomStride>(shape, stride);
}
Func func_;
Stride stride_;
};
template <class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
{
// Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
return make_layout(
repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}
/// Helper function to optionally create a gather tensor
template <class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
{
Layout matrix_layout = make_identity_layout(shape);
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}
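// Illustrative usage (an assumed typical pattern, not part of the upstream header):
// view a row-major MxK matrix through an index array so that logical row i reads
// physical row indices[i]. The function name is a demo-only placeholder.
template <class T>
CUTE_HOST_DEVICE auto make_row_gathered_tensor_demo(T* ptr, int const* indices, int M, int K)
{
    // stride K on the M mode is replaced with CustomStride, so element (i, j)
    // resolves to indices[i] * K + j
    return make_gather_tensor(ptr, make_shape(M, K), make_stride(K, Int<1>{}), IndexedGather<int const*>{indices});
}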
namespace cute
{
template <int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
{
if constexpr (is_tuple<Shape>::value)
{
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
}
else if constexpr (is_scaled_basis<Stride>::value)
{
if constexpr (Stride::mode() == I)
{
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
}
else
{
return make_layout(shape, stride);
}
}
else
{
return upcast<N>(shape, stride);
}
CUTE_GCC_UNREACHABLE;
}
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(
ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
{
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
// Upcast the outer layout (works as expected)
auto outer = upcast<N>(layout.layout_a());
// Upcast the accumulated offset along stride-1 mode
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
// Upcast the inner layout's shape along stride-1 mode
auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());
return composition(outer, offset, inner);
}
} // namespace cute

View File

@@ -0,0 +1,58 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/
#pragma once
namespace cutlass
{
enum class WeightOnlyQuantOp
{
UNDEFINED,
PER_COLUMN_SCALE_ONLY,
FINEGRAINED_SCALE_ONLY,
FINEGRAINED_SCALE_AND_ZEROS
};
constexpr bool isFinegrained(WeightOnlyQuantOp op)
{
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}
constexpr bool hasZero(WeightOnlyQuantOp op)
{
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
}
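// Illustrative compile-time checks (demo only, not part of the upstream header):
// both helpers are constexpr, so the classification can be verified at compile time.
static_assert(isFinegrained(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "");
static_assert(hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "");
static_assert(!hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "");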
} // namespace cutlass

View File

@@ -223,3 +223,208 @@ BSD 3-Clause "New" License
3rdparty/cutlass
include/flashinfer/attention/hopper/block_sparse_gather.cuh
Notice for NVIDIA/TensorRT-LLM
-------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -39,6 +39,8 @@ cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
turbomind = root / "3rdparty" / "turbomind"
tensorrt_llm_parent = root / "3rdparty"
tensorrt_llm = root / "3rdparty" / "tensorrt_llm"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
@@ -51,6 +53,8 @@ include_dirs = [
"cublasLt",
turbomind.resolve(),
turbomind.resolve() / "src",
tensorrt_llm_parent.resolve(),
tensorrt_llm.resolve() / "cutlass_extensions" / "include",
]
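The two new entries give the compiler two ways into the vendored tree: tensorrt_llm_parent (the 3rdparty directory itself) resolves prefixed includes of the form tensorrt_llm/..., while the cutlass_extensions/include entry exposes the extension headers directly. A sketch of what a consuming source file could then write (header names are illustrative of the layout, not an exhaustive list):
// Resolved via tensorrt_llm_parent, i.e. the 3rdparty directory:
#include "tensorrt_llm/common/cudaUtils.h"
// Resolved via tensorrt_llm/cutlass_extensions/include:
#include "cutlass_extensions/weight_only_quant_op.h"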
nvcc_flags = [