revert the MoE dependence (#3230)
This commit is contained in:
@@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
bool initCheckDebug()
|
||||
{
|
||||
auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE";
|
||||
auto const debugEnabled = std::getenv(kDebugEnabled);
|
||||
return debugEnabled && debugEnabled[0] == '1';
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool DebugConfig::isCheckDebugEnabled()
|
||||
{
|
||||
static bool const debugEnabled = initCheckDebug();
|
||||
return debugEnabled;
|
||||
}
|
||||
92
sgl-kernel/3rdparty/tensorrt_llm/common/assert.h
vendored
92
sgl-kernel/3rdparty/tensorrt_llm/common/assert.h
vendored
@@ -1,92 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
class DebugConfig
|
||||
{
|
||||
public:
|
||||
static bool isCheckDebugEnabled();
|
||||
};
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
|
||||
#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
|
||||
#else
|
||||
#define TLLM_LIKELY(x) __builtin_expect((x), 1)
|
||||
#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
|
||||
#endif
|
||||
|
||||
#define TLLM_CHECK(val) \
|
||||
do \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
|
||||
do \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) \
|
||||
? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError( \
|
||||
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_DEBUG(val) \
|
||||
do \
|
||||
{ \
|
||||
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
|
||||
do \
|
||||
{ \
|
||||
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) \
|
||||
? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError( \
|
||||
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_THROW(...) \
|
||||
do \
|
||||
{ \
|
||||
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_WRAP(ex) \
|
||||
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
|
||||
@@ -1,360 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cublasVersionCheck.h"
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef CUDART_VERSION
|
||||
#error CUDART_VERSION Undefined!
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
|
||||
: mCublasHandle(cublasHandle)
|
||||
, mCublasLtHandle(cublasltHandle)
|
||||
, mStream(stream)
|
||||
, mCublasWorkspace(workspace)
|
||||
{
|
||||
}
|
||||
|
||||
CublasMMWrapper::~CublasMMWrapper() {}
|
||||
|
||||
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
|
||||
: mCublasHandle(wrapper.mCublasHandle)
|
||||
, mCublasLtHandle(wrapper.mCublasLtHandle)
|
||||
, mStream(wrapper.mStream)
|
||||
{
|
||||
}
|
||||
|
||||
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
|
||||
{
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
|
||||
{
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::destroyDescriptors()
|
||||
{
|
||||
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
|
||||
mOperationDesc = NULL;
|
||||
mADesc = NULL;
|
||||
mBDesc = NULL;
|
||||
mCDesc = NULL;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
|
||||
{
|
||||
bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
|
||||
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ usingCublasLt);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
|
||||
{
|
||||
half h_alpha = (half) (f_alpha);
|
||||
half h_beta = (half) (f_beta);
|
||||
|
||||
// TODO: default cublas libs
|
||||
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
|
||||
int batch_count = 1;
|
||||
// fp32 use cublas as default
|
||||
// fp16 use cublasLt as default
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
if (usingCublasLt)
|
||||
{
|
||||
if (hasAlgo)
|
||||
{
|
||||
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
|
||||
}
|
||||
|
||||
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
|
||||
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
|
||||
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
|
||||
// Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+
|
||||
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
|
||||
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
|
||||
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
|
||||
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
|
||||
float const f_beta)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
|
||||
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
|
||||
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
|
||||
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
|
||||
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setWorkspace(void* workspace)
|
||||
{
|
||||
mCublasWorkspace = workspace;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setFP32GemmConfig()
|
||||
{
|
||||
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
void CublasMMWrapper::setGemmConfig(
|
||||
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
|
||||
{
|
||||
mAType = aType;
|
||||
mBType = bType;
|
||||
mCType = cType;
|
||||
bool isFp16ComputeType = computeType == CUDA_R_16F;
|
||||
if (isFp16ComputeType)
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_16F;
|
||||
mScaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_32F;
|
||||
mScaleType = CUDA_R_32F;
|
||||
}
|
||||
}
|
||||
|
||||
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
|
||||
{
|
||||
if (data_type == CUDA_R_16F)
|
||||
{
|
||||
return HALF_DATATYPE;
|
||||
}
|
||||
else if (data_type == CUDA_R_32F)
|
||||
{
|
||||
return FLOAT_DATATYPE;
|
||||
}
|
||||
else if (data_type == CUDA_R_8I)
|
||||
{
|
||||
return INT8_DATATYPE;
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (data_type == CUDA_R_16BF)
|
||||
{
|
||||
return BFLOAT16_DATATYPE;
|
||||
}
|
||||
#endif
|
||||
return FLOAT_DATATYPE;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setStream(cudaStream_t stream)
|
||||
{
|
||||
mStream = stream;
|
||||
}
|
||||
|
||||
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heurResult;
|
||||
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
|
||||
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
|
||||
|
||||
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|
||||
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
|
||||
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
return heuristics;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
|
||||
{
|
||||
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
|
||||
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
|
||||
return {};
|
||||
#else
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
|
||||
cublasLtMatmulPreference_t preference;
|
||||
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
|
||||
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
|
||||
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
|
||||
// Restrict reduction algorithms for numerical stability and better determinism
|
||||
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
|
||||
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
|
||||
uint32_t pointer_mode_mask = 0;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
|
||||
#endif
|
||||
|
||||
int return_count = 0;
|
||||
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
|
||||
heuristics.size(), heuristics.data(), &return_count));
|
||||
heuristics.resize(return_count);
|
||||
|
||||
return heuristics;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,148 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
class CublasMMWrapper
|
||||
{
|
||||
protected:
|
||||
std::shared_ptr<cublasHandle_t> mCublasHandle;
|
||||
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
|
||||
|
||||
cudaDataType_t mAType{};
|
||||
cudaDataType_t mBType{};
|
||||
cudaDataType_t mCType{};
|
||||
cublasComputeType_t mComputeType{};
|
||||
cudaDataType_t mScaleType{};
|
||||
|
||||
cublasLtMatmulDesc_t mOperationDesc{NULL};
|
||||
cublasLtMatrixLayout_t mADesc{NULL};
|
||||
cublasLtMatrixLayout_t mBDesc{NULL};
|
||||
cublasLtMatrixLayout_t mCDesc{NULL};
|
||||
|
||||
cudaStream_t mStream;
|
||||
|
||||
void* mCublasWorkspace = nullptr;
|
||||
|
||||
private:
|
||||
bool descriptorsCreated() const
|
||||
{
|
||||
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
|
||||
}
|
||||
|
||||
public:
|
||||
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
|
||||
cudaStream_t stream, void* workspace);
|
||||
|
||||
~CublasMMWrapper();
|
||||
|
||||
CublasMMWrapper(CublasMMWrapper const& wrapper);
|
||||
|
||||
/********************** GEMMs **********************/
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
|
||||
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
|
||||
float const f_beta = 0.0f);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
|
||||
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
|
||||
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
|
||||
|
||||
/********************** Tactic selection helpers **********************/
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
|
||||
|
||||
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
|
||||
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
|
||||
|
||||
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
|
||||
|
||||
/********************** Utils **********************/
|
||||
void setWorkspace(void* workspace);
|
||||
|
||||
void setFP32GemmConfig();
|
||||
void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
|
||||
#ifdef ENABLE_BF16
|
||||
void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
|
||||
#endif
|
||||
|
||||
void setStream(cudaStream_t stream);
|
||||
|
||||
void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);
|
||||
|
||||
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
||||
|
||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
|
||||
void setScaleDescriptors(void* scale_a, void* scale_b);
|
||||
void destroyDescriptors();
|
||||
|
||||
cublasHandle_t getCublasHandle()
|
||||
{
|
||||
return *(this->mCublasHandle);
|
||||
}
|
||||
|
||||
cublasLtHandle_t getCublasLtHandle() const
|
||||
{
|
||||
return *(this->mCublasLtHandle);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,35 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro
|
||||
// definition which is not sufficient to determine if we include cublas.h,
|
||||
// cublas_v2.h or cublasLt.h.
|
||||
|
||||
#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH)
|
||||
#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
<= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
< TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
>= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
> TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
@@ -1,313 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float2 f_val;
|
||||
f_val.x = __low2float(val);
|
||||
f_val.y = __high2float(val);
|
||||
return f_val;
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float2 f_val;
|
||||
f_val.x = max(min(__low2float(val), 127.f), -128.f);
|
||||
f_val.y = max(min(__high2float(val), 127.f), -128.f);
|
||||
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
|
||||
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
|
||||
return int16;
|
||||
#else
|
||||
val = __hmin2(val, make_bfloat162(127., 127.));
|
||||
val = __hmax2(val, make_bfloat162(-128., -128.));
|
||||
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
|
||||
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
|
||||
return int16;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __floats2bfloat162_rn(val.x, val.y);
|
||||
#else
|
||||
return __float22bfloat162_rn(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
__nv_bfloat162 val2;
|
||||
val2.x = val;
|
||||
val2.y = val;
|
||||
return val2;
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
|
||||
#else
|
||||
return __hadd2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
|
||||
#else
|
||||
return __hadd(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
|
||||
#else
|
||||
return __hsub2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
|
||||
#else
|
||||
return __hsub(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
|
||||
#else
|
||||
return __hmul2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
|
||||
#else
|
||||
return __hmul(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh, fzl, fzh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
fzl = __low2float(z);
|
||||
fzh = __high2float(z);
|
||||
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
|
||||
#else
|
||||
return __hfma2(x, y, z);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
|
||||
#else
|
||||
return __hfma(x, y, z);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
;
|
||||
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
|
||||
#else
|
||||
return h2exp(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
|
||||
|
||||
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
__nv_bfloat162 t;
|
||||
t.x = x;
|
||||
t.y = y;
|
||||
return t;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
|
||||
#else
|
||||
return a + b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
|
||||
#else
|
||||
return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
|
||||
#else
|
||||
return a + b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
|
||||
#else
|
||||
return a * b * c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
|
||||
#else
|
||||
return a * b * c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
fdl = __low2float(d);
|
||||
fdh = __high2float(d);
|
||||
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
|
||||
#else
|
||||
return a * b * c + d;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
// Operator definitions intentionally in global namespace
|
||||
namespace
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
|
||||
|
||||
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
return tensorrt_llm::common::bf16hmul2(x, y);
|
||||
};
|
||||
|
||||
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
return tensorrt_llm::common::bf16hadd2(x, y);
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
} // namespace
|
||||
@@ -1,21 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
@@ -1,187 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define CUDA_LIB_NAME "cuda"
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#define dllOpen(name) LoadLibrary("nv" name ".dll")
|
||||
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
|
||||
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
|
||||
#else // For non-Windows platforms
|
||||
#include <dlfcn.h>
|
||||
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
|
||||
#define dllClose(handle) dlclose(handle)
|
||||
#define dllGetSym(handle, name) dlsym(handle, name)
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include "cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
{
|
||||
static std::mutex mutex;
|
||||
static std::weak_ptr<CUDADriverWrapper> instance;
|
||||
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
|
||||
if (result)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUDADriverWrapper::CUDADriverWrapper()
|
||||
: handle(dllOpen(CUDA_LIB_NAME))
|
||||
{
|
||||
|
||||
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
|
||||
|
||||
auto load_sym = [](void* handle, char const* name)
|
||||
{
|
||||
void* ret = dllGetSym(handle, name);
|
||||
return ret;
|
||||
};
|
||||
|
||||
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
|
||||
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
|
||||
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
|
||||
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
|
||||
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
|
||||
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
|
||||
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
{
|
||||
dllClose(handle);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorName)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorMessage)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
|
||||
{
|
||||
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
|
||||
{
|
||||
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
|
||||
{
|
||||
return (*_cuModuleUnload)(hmod);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
|
||||
{
|
||||
return (*_cuLinkDestroy)(state);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
|
||||
{
|
||||
return (*_cuModuleLoadData)(module, image);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
|
||||
{
|
||||
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetFunction)(hfunc, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
|
||||
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
|
||||
{
|
||||
return (*_cuLaunchCooperativeKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
|
||||
{
|
||||
return (*_cuLaunchKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
|
||||
{
|
||||
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
|
||||
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
|
||||
{
|
||||
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,138 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CUDA_DRIVER_WRAPPER_H
|
||||
#define CUDA_DRIVER_WRAPPER_H
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class CUDADriverWrapper
|
||||
{
|
||||
public:
|
||||
static std::shared_ptr<CUDADriverWrapper> getInstance();
|
||||
|
||||
~CUDADriverWrapper();
|
||||
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
|
||||
|
||||
CUresult cuModuleUnload(CUmodule hmod) const;
|
||||
|
||||
CUresult cuLinkDestroy(CUlinkState state) const;
|
||||
|
||||
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
|
||||
|
||||
CUresult cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
|
||||
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
|
||||
CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
|
||||
|
||||
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra) const;
|
||||
|
||||
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
|
||||
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
|
||||
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
|
||||
|
||||
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUDADriverWrapper();
|
||||
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
CUresult (*_cuLinkDestroy)(CUlinkState);
|
||||
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLinkAddData)(
|
||||
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, CUstream, void**);
|
||||
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra);
|
||||
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
|
||||
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void checkDriver(
|
||||
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
char const* errorName = nullptr;
|
||||
char const* errorMsg = nullptr;
|
||||
wrap.cuGetErrorName(result, &errorName);
|
||||
wrap.cuGetErrorMessage(result, &errorMsg);
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
/*
|
||||
* Macros compliant with TensorRT coding conventions
|
||||
*/
|
||||
#define TLLM_CU_CHECK(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::checkDriver( \
|
||||
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
#endif // CUDA_DRIVER_WRAPPER_H
|
||||
@@ -1,436 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cuda_fp16.h>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
constexpr int CTA_SIZE = 256;
|
||||
|
||||
template <bool QUANTIZE>
|
||||
__inline__ __device__ float scale(float a, float b)
|
||||
{
|
||||
return QUANTIZE ? a / b : a * b;
|
||||
}
|
||||
|
||||
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
|
||||
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
|
||||
{
|
||||
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
|
||||
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(1024);
|
||||
dim3 block(CTA_SIZE);
|
||||
if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(1024);
|
||||
dim3 block(CTA_SIZE);
|
||||
if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TENSOR, false>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template <typename T_FAKE, typename T_OUT, typename T_IN>
|
||||
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
|
||||
{
|
||||
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
|
||||
dst[tid] = (T_OUT) (static_cast<float>(tmp));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_FAKE, typename T_OUT, typename T_IN>
|
||||
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
|
||||
{
|
||||
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
|
||||
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
|
||||
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
|
||||
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
|
||||
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
template void invokeFakeQuantize<float, half, float>(
|
||||
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
__device__ float atomicMaxExtd(float* address, float val)
|
||||
{
|
||||
assert(val >= 0);
|
||||
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
|
||||
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
|
||||
return __uint_as_float(old);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T atomicMaxExtdV2(T* address, T val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
|
||||
// The address in 64 bits.
|
||||
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
|
||||
|
||||
// Pack the input value into 32 bits.
|
||||
union
|
||||
{
|
||||
T v[2];
|
||||
uint16_t u[2];
|
||||
} old, tmp = {};
|
||||
|
||||
int const loc = (address_u64 & 0x2) >> 1;
|
||||
tmp.v[loc] = val;
|
||||
|
||||
// 4B aligned pointer.
|
||||
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
|
||||
|
||||
if constexpr (std::is_same_v<T, half>)
|
||||
{
|
||||
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
|
||||
: "=h"(old.u[0]), "=h"(old.u[1])
|
||||
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
|
||||
}
|
||||
if constexpr (std::is_same_v<T, __nv_bfloat16>)
|
||||
{
|
||||
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
|
||||
: "=h"(old.u[0]), "=h"(old.u[1])
|
||||
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
|
||||
}
|
||||
|
||||
// Return the correct half.
|
||||
return old.v[loc];
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ half atomicMaxExtd(half* address, half val)
|
||||
{
|
||||
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
|
||||
unsigned short int old = *address_as_u, assumed;
|
||||
|
||||
while (val > __ushort_as_half(old))
|
||||
{
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
|
||||
}
|
||||
|
||||
return __ushort_as_half(old);
|
||||
}
|
||||
|
||||
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
|
||||
unsigned short int old = *address_as_u, assumed;
|
||||
|
||||
while (val > __ushort_as_bfloat16(old))
|
||||
{
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
|
||||
}
|
||||
|
||||
return __ushort_as_bfloat16(old);
|
||||
#else
|
||||
assert(0);
|
||||
asm volatile("brkpt;\n" ::);
|
||||
return __nv_bfloat16(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
|
||||
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
|
||||
{
|
||||
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
if constexpr (std::is_same_v<T_S, float>)
|
||||
{
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
|
||||
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
else
|
||||
atomicMaxExtdV2(quant_ptr + col, scale);
|
||||
}
|
||||
#else // Vector atomics require __CUDA_ARCH__ >= 900
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
auto const nrows = size / n;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[row * n + i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
quant_ptr[row] = scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
atomicMaxExtd(quant_ptr, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_S, typename T_W>
|
||||
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid(numel / lda);
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
|
||||
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(1024);
|
||||
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
|
||||
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
|
||||
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
|
||||
#ifdef ENABLE_BF16
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
|
||||
#endif
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
__global__ void dynamicQuantizeMatrixPerToken(
|
||||
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
|
||||
{
|
||||
extern __shared__ __align__(sizeof(float)) char _shmem[];
|
||||
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
|
||||
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
auto const nrows = numel / lda;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
auto const in = input[row * lda + i];
|
||||
shmem[i] = in;
|
||||
auto val = fabs(static_cast<float>(in));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
|
||||
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
// true means we are quantizing
|
||||
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
quant_ptr[row] = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
|
||||
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
dim3 grid(numel / lda);
|
||||
bool use_shmem = true;
|
||||
auto const shmem_size = lda * sizeof(T_IN);
|
||||
if (shmem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
|
||||
use_shmem = ret == cudaSuccess;
|
||||
}
|
||||
if (use_shmem)
|
||||
{
|
||||
// ensure the threadblock is as large as possible to increase occupancy
|
||||
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
|
||||
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
|
||||
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(1024);
|
||||
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
|
||||
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream); \
|
||||
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream); \
|
||||
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
|
||||
#ifdef ENABLE_BF16
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,239 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define FP8_MHA
|
||||
#define FUSE_GEMM_ACT
|
||||
#define FP8_GEMM_OUTPUT_QUANT_DISABLE
|
||||
|
||||
#ifdef FUSE_GEMM_ACT
|
||||
#define USE_QGMMA
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
constexpr float FP8_E4M3_MAX = 448.0f;
|
||||
|
||||
enum QuantizeMode
|
||||
{
|
||||
PER_CHANNEL,
|
||||
PER_TENSOR,
|
||||
PER_CHANNEL_WEIGHT_PER_TENSOR_ACT,
|
||||
PER_TOKEN,
|
||||
};
|
||||
|
||||
// Packed Data Type
|
||||
typedef struct __CUDA_ALIGN__(32)
|
||||
{
|
||||
float array[8];
|
||||
} float8;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(16)
|
||||
{
|
||||
half array[8];
|
||||
} half8;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
half2 array[2];
|
||||
} half2_2;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
half array[4];
|
||||
} half_4;
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
typedef struct __CUDA_ALIGN__(4)
|
||||
{
|
||||
__nv_bfloat16 array[2];
|
||||
} __nv_bfloat16_2;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
__nv_bfloat162 x, y;
|
||||
} __nv_bfloat162_2_xy;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
__nv_bfloat16 array[4];
|
||||
} __nv_bfloat164;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
__nv_bfloat162 array[2];
|
||||
} __nv_bfloat162_2;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(16)
|
||||
{
|
||||
__nv_bfloat16 array[8];
|
||||
} __nv_bfloat168;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(16)
|
||||
{
|
||||
__nv_bfloat162 array[4];
|
||||
} __nv_bfloat162_4;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(32)
|
||||
{
|
||||
__nv_bfloat16 array[16];
|
||||
} __nv_bfloat1616;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
typedef struct __CUDA_ALIGN__(2)
|
||||
{
|
||||
__nv_fp8_e4m3 array[2];
|
||||
} __nv_fp8_2_e4m3;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(4)
|
||||
{
|
||||
__nv_fp8_e4m3 array[4];
|
||||
} __nv_fp8_4_e4m3;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(4)
|
||||
{
|
||||
__nv_fp8x2_e4m3 array[2];
|
||||
} __nv_fp8x2_x2_e4m3;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
__nv_fp8_e4m3 array[8];
|
||||
} __nv_fp8_8_e4m3;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(8)
|
||||
{
|
||||
__nv_fp8x2_e4m3 array[4];
|
||||
} __nv_fp8x2_x4_e4m3;
|
||||
|
||||
typedef struct __CUDA_ALIGN__(16)
|
||||
{
|
||||
__nv_fp8_e4m3 array[16];
|
||||
} __nv_fp8x16_e4m3;
|
||||
#endif
|
||||
|
||||
// only BF16 and FP8
|
||||
template <typename T, int PACK_SIZE>
|
||||
struct PackType
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct PackType<__nv_bfloat16, 2>
|
||||
{
|
||||
using type = __nv_bfloat16_2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackType<__nv_bfloat16, 4>
|
||||
{
|
||||
using type = __nv_bfloat164;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackType<__nv_bfloat16, 8>
|
||||
{
|
||||
using type = __nv_bfloat168;
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template <>
|
||||
struct PackType<__nv_fp8_e4m3, 2>
|
||||
{
|
||||
using type = __nv_fp8_2_e4m3;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackType<__nv_fp8_e4m3, 4>
|
||||
{
|
||||
using type = __nv_fp8_4_e4m3;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackType<__nv_fp8_e4m3, 8>
|
||||
{
|
||||
using type = __nv_fp8_8_e4m3;
|
||||
};
|
||||
#endif
|
||||
|
||||
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
|
||||
{
|
||||
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
|
||||
*out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
*out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
|
||||
}
|
||||
|
||||
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
|
||||
{
|
||||
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
|
||||
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
return out;
|
||||
}
|
||||
|
||||
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
|
||||
{
|
||||
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
|
||||
*out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
*out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
|
||||
}
|
||||
|
||||
__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
|
||||
{
|
||||
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
|
||||
half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
template <typename T_FAKE, typename T_OUT, typename T_IN>
|
||||
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
template <typename T_S, typename T_W>
|
||||
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel,
|
||||
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
#endif // ENABLE_FP8
|
||||
@@ -1,752 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include <assert.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T ldg(T const* val)
|
||||
{
|
||||
return __ldg(val);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
#else
|
||||
return __ldg(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
#else
|
||||
return __ldg(val);
|
||||
#endif
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter
|
||||
{
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2>
|
||||
{
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half>
|
||||
{
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162>
|
||||
{
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16>
|
||||
{
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
|
||||
template <typename T>
|
||||
inline __device__ T hadd2(T a, T b)
|
||||
{
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hadd2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b)
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half2 add(half2 a, half2 b)
|
||||
{
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half add(half a, half b)
|
||||
{
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
|
||||
{
|
||||
return bf16hadd(a, b);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
|
||||
{
|
||||
return bf16hadd(a, __float2bfloat16(b));
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// applies to all 4 values addition
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b, T c)
|
||||
{
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hadd(a, b, c);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hadd2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// applies to all 4 values addition
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b, T c, T d)
|
||||
{
|
||||
return (T) ((float) a + (float) b + (float) c + (float) d);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
|
||||
{
|
||||
return bf16hadd(a, b, c, d);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hsub2(T a, T b)
|
||||
{
|
||||
return __hsub2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hsub2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hmul2(T a, T b)
|
||||
{
|
||||
return __hmul2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hmul2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hmul2(T a, T b, T c)
|
||||
{
|
||||
return a * b * c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hmul2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T mul(T a, T b, T c)
|
||||
{
|
||||
return a * b * c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hmul(a, b, c);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hmul2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T fma(T a, T b, T c, T d)
|
||||
{
|
||||
return a * b * c + d;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
|
||||
{
|
||||
return bf16hfma2(a, b, c, d);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T fma(T a, T b, T c)
|
||||
{
|
||||
return a * b + c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hfma2(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hfma(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hexp2(T a)
|
||||
{
|
||||
return h2exp(a);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
|
||||
{
|
||||
return bf16exp2(a);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
|
||||
{
|
||||
return make_float2(val.x, val.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, float>(float val)
|
||||
{
|
||||
return make_float2(val, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
|
||||
{
|
||||
return __half22float2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
|
||||
{
|
||||
return __float22half2_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float>(float val)
|
||||
{
|
||||
return __float2half2_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, half>(half val)
|
||||
{
|
||||
return __half2half2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
union
|
||||
{
|
||||
half fp16;
|
||||
int16_t int16_in;
|
||||
};
|
||||
|
||||
fp16 = val;
|
||||
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
|
||||
return int8[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = cuda_cast<int8_t>(val.x);
|
||||
int8[1] = cuda_cast<int8_t>(val.y);
|
||||
return int16;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
|
||||
return int8[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = cuda_cast<int8_t>(val.x);
|
||||
int8[1] = cuda_cast<int8_t>(val.y);
|
||||
return int16;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
return make_half2(int8[0], int8[1]);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
return make_float2(int8[0], int8[1]);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return bf1622float2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __float2half(__bfloat162float(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return bf1622int16(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
|
||||
{
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
|
||||
{
|
||||
return __float2bfloat16(__half2float(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return bf162bf162(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
|
||||
{
|
||||
return __float2bfloat162_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
|
||||
{
|
||||
return float22bf162(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
__nv_bfloat162 res;
|
||||
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
|
||||
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
|
||||
{
|
||||
return float22bf162(__half22float2(val));
|
||||
}
|
||||
|
||||
#endif // ENABLE BF16
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T cuda_abs(T val)
|
||||
{
|
||||
assert(false);
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_abs(float val)
|
||||
{
|
||||
return fabs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_abs(float2 val)
|
||||
{
|
||||
return make_float2(fabs(val.x), fabs(val.y));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_abs(half val)
|
||||
{
|
||||
return __habs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_abs(half2 val)
|
||||
{
|
||||
return __habs2(val);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
|
||||
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
|
||||
{
|
||||
return __habs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
|
||||
{
|
||||
return __habs2(val);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // ENABLE_FP16
|
||||
|
||||
template <typename To, typename Ti>
|
||||
__device__ inline To cuda_sum(Ti val)
|
||||
{
|
||||
return cuda_cast<To>(val);
|
||||
};
|
||||
|
||||
template <typename To>
|
||||
__device__ inline To cuda_sum(float2 val)
|
||||
{
|
||||
return cuda_cast<To>(val.x + val.y);
|
||||
};
|
||||
|
||||
// Unary maximum: compute the max of a vector type
|
||||
template <typename To, typename Ti>
|
||||
__device__ inline To cuda_max(Ti val)
|
||||
{
|
||||
return cuda_cast<To>(val);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_max(float2 val)
|
||||
{
|
||||
return fmaxf(val.x, val.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_max(half2 val)
|
||||
{
|
||||
return __hmax(val.x, val.y);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
return __hmax(val.x, val.y);
|
||||
#else
|
||||
assert(0);
|
||||
asm volatile("brkpt;\n" ::);
|
||||
return __nv_bfloat16(0);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
// Binary maximum: compute the max of two values.
|
||||
template <typename T>
|
||||
__device__ inline T cuda_max(T val1, T val2)
|
||||
{
|
||||
return (val1 > val2) ? val1 : val2;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_max(float2 val1, float2 val2)
|
||||
{
|
||||
float2 out;
|
||||
out.x = fmaxf(val1.x, val2.x);
|
||||
out.y = fmaxf(val1.y, val2.y);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_max(half2 val1, half2 val2)
|
||||
{
|
||||
return __hmax2(val1, val2);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2)
|
||||
{
|
||||
return __hmax2(val1, val2);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Binary maximum: compute the min of two values.
|
||||
template <typename T>
|
||||
__device__ inline T cuda_min(T val1, T val2)
|
||||
{
|
||||
return (val1 < val2) ? val1 : val2;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_min(float2 val1, float2 val2)
|
||||
{
|
||||
float2 out;
|
||||
out.x = fminf(val1.x, val2.x);
|
||||
out.y = fminf(val1.y, val2.y);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_min(half2 val1, half2 val2)
|
||||
{
|
||||
return __hmin2(val1, val2);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2)
|
||||
{
|
||||
return __hmin2(val1, val2);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Helper function of clamping the val into the given range.
|
||||
template <typename T>
|
||||
inline __device__ T cuda_clamp(T val, T minVal, T maxVal)
|
||||
{
|
||||
return cuda_min(cuda_max(val, minVal), maxVal);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return fp8x2_e4m3_to_half2(&val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
|
||||
{
|
||||
return (float) val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return fp8x2_e4m3_to_bfloat2(&val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
|
||||
{
|
||||
// no impl
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
|
||||
{
|
||||
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
|
||||
}
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
641
sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h
vendored
641
sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h
vendored
@@ -1,641 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <driver_types.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#ifndef _WIN32 // Linux
|
||||
#include <sys/sysinfo.h>
|
||||
#endif // not WIN32
|
||||
#include <vector>
|
||||
#ifdef _WIN32 // Windows
|
||||
#include <windows.h>
|
||||
#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without
|
||||
// this undef.
|
||||
#endif // WIN32
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
// workspace for cublas gemm : 32MB
|
||||
#define CUBLAS_WORKSPACE_SIZE 33554432
|
||||
|
||||
typedef struct __align__(4)
|
||||
{
|
||||
half x, y, z, w;
|
||||
}
|
||||
|
||||
half4;
|
||||
|
||||
/* **************************** type definition ***************************** */
|
||||
|
||||
enum CublasDataType
|
||||
{
|
||||
FLOAT_DATATYPE = 0,
|
||||
HALF_DATATYPE = 1,
|
||||
BFLOAT16_DATATYPE = 2,
|
||||
INT8_DATATYPE = 3,
|
||||
FP8_DATATYPE = 4
|
||||
};
|
||||
|
||||
enum TRTLLMCudaDataType
|
||||
{
|
||||
FP32 = 0,
|
||||
FP16 = 1,
|
||||
BF16 = 2,
|
||||
INT8 = 3,
|
||||
FP8 = 4
|
||||
};
|
||||
|
||||
enum class OperationType
|
||||
{
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
INT8,
|
||||
FP8
|
||||
};
|
||||
|
||||
/* **************************** debug tools ********************************* */
|
||||
static char const* _cudaGetErrorEnum(cudaError_t error)
|
||||
{
|
||||
return cudaGetErrorString(error);
|
||||
}
|
||||
|
||||
static char const* _cudaGetErrorEnum(cublasStatus_t error)
|
||||
{
|
||||
switch (error)
|
||||
{
|
||||
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
|
||||
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
|
||||
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
|
||||
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
|
||||
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "<unknown>";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check(T result, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void checkEx(T result, std::initializer_list<T> const& validReturns, char const* const func, char const* const file,
|
||||
int const line)
|
||||
{
|
||||
if (std::all_of(std::begin(validReturns), std::end(validReturns), [&result](T const& t) { return t != result; }))
|
||||
{
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result)));
|
||||
}
|
||||
}
|
||||
|
||||
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
|
||||
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
|
||||
|
||||
inline std::optional<bool> isCudaLaunchBlocking()
|
||||
{
|
||||
static bool firstCall = true;
|
||||
static std::optional<bool> result = std::nullopt;
|
||||
|
||||
if (firstCall)
|
||||
{
|
||||
char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
|
||||
if (env != nullptr && std::string(env) == "1")
|
||||
{
|
||||
result = true;
|
||||
}
|
||||
else if (env != nullptr && std::string(env) == "0")
|
||||
{
|
||||
result = false;
|
||||
}
|
||||
firstCall = false;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
inline bool doCheckError()
|
||||
{
|
||||
auto const cudaLaunchBlocking = isCudaLaunchBlocking();
|
||||
#ifndef NDEBUG
|
||||
bool const checkError = cudaLaunchBlocking.value_or(true);
|
||||
#else
|
||||
bool const checkError = cudaLaunchBlocking.value_or(false);
|
||||
#endif
|
||||
|
||||
return checkError;
|
||||
}
|
||||
|
||||
inline void syncAndCheck(char const* const file, int const line)
|
||||
{
|
||||
if (doCheckError())
|
||||
{
|
||||
cudaDeviceSynchronize();
|
||||
check(cudaGetLastError(), "cudaGetLastError", file, line);
|
||||
}
|
||||
}
|
||||
|
||||
#define sync_check_cuda_error() tensorrt_llm::common::syncAndCheck(__FILE__, __LINE__)
|
||||
|
||||
#define PRINT_FUNC_NAME_() \
|
||||
do \
|
||||
{ \
|
||||
std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \
|
||||
} while (0)
|
||||
|
||||
// clang-format off
|
||||
template<typename T> struct packed_type;
|
||||
template <> struct packed_type<float> { using type = float; }; // we don't need to pack float by default
|
||||
template <> struct packed_type<half> { using type = half2; };
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template<>
|
||||
struct packed_type<__nv_bfloat16> {
|
||||
using type = __nv_bfloat162;
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template<>
|
||||
struct packed_type<__nv_fp8_e4m3> {
|
||||
using type = __nv_fp8x2_e4m3;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T> struct num_elems;
|
||||
template <> struct num_elems<float> { static constexpr int value = 1; };
|
||||
template <> struct num_elems<float2> { static constexpr int value = 2; };
|
||||
template <> struct num_elems<float4> { static constexpr int value = 4; };
|
||||
template <> struct num_elems<half> { static constexpr int value = 1; };
|
||||
template <> struct num_elems<half2> { static constexpr int value = 2; };
|
||||
#ifdef ENABLE_BF16
|
||||
template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; };
|
||||
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; };
|
||||
template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; };
|
||||
#endif
|
||||
|
||||
template<typename T, int num> struct packed_as;
|
||||
template<typename T> struct packed_as<T, 1> { using type = T; };
|
||||
template<> struct packed_as<half, 2> { using type = half2; };
|
||||
template<> struct packed_as<float, 2> { using type = float2; };
|
||||
template<> struct packed_as<int8_t, 2> { using type = int16_t; };
|
||||
template<> struct packed_as<int32_t, 2> { using type = int2; };
|
||||
template<> struct packed_as<half2, 1> { using type = half; };
|
||||
template<> struct packed_as<float2, 1> { using type = float; };
|
||||
#ifdef ENABLE_BF16
|
||||
template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; };
|
||||
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; };
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; };
|
||||
template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; };
|
||||
template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; };
|
||||
template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; };
|
||||
#endif
|
||||
|
||||
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
|
||||
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
|
||||
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }
|
||||
|
||||
inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); }
|
||||
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); }
|
||||
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); }
|
||||
|
||||
// clang-format on
|
||||
|
||||
template <typename T>
|
||||
struct CudaDataType
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CudaDataType<float>
|
||||
{
|
||||
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CudaDataType<half>
|
||||
{
|
||||
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct CudaDataType<__nv_bfloat16>
|
||||
{
|
||||
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF;
|
||||
};
|
||||
#endif
|
||||
|
||||
inline int getSMVersion()
|
||||
{
|
||||
int device{-1};
|
||||
check_cuda_error(cudaGetDevice(&device));
|
||||
int sm_major = 0;
|
||||
int sm_minor = 0;
|
||||
check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
|
||||
check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
||||
return sm_major * 10 + sm_minor;
|
||||
}
|
||||
|
||||
inline int getDevice()
|
||||
{
|
||||
int current_dev_id = 0;
|
||||
check_cuda_error(cudaGetDevice(¤t_dev_id));
|
||||
return current_dev_id;
|
||||
}
|
||||
|
||||
inline int getDeviceCount()
|
||||
{
|
||||
int count = 0;
|
||||
check_cuda_error(cudaGetDeviceCount(&count));
|
||||
return count;
|
||||
}
|
||||
|
||||
/// @brief Identifies the memory type of the given pointer.
|
||||
template <typename T>
|
||||
cudaMemoryType getPtrCudaMemoryType(T* ptr)
|
||||
{
|
||||
cudaPointerAttributes attributes{};
|
||||
check_cuda_error(cudaPointerGetAttributes(&attributes, ptr));
|
||||
return attributes.type;
|
||||
}
|
||||
|
||||
/// Get the memory info
|
||||
/// \return The free and total amount of memory in bytes
|
||||
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
|
||||
{
|
||||
if (useUvm)
|
||||
{
|
||||
size_t freeSysMem = 0;
|
||||
size_t totalSysMem = 0;
|
||||
#ifndef _WIN32 // Linux
|
||||
struct sysinfo info
|
||||
{
|
||||
};
|
||||
|
||||
sysinfo(&info);
|
||||
totalSysMem = info.totalram * info.mem_unit;
|
||||
freeSysMem = info.freeram * info.mem_unit;
|
||||
#else // Windows
|
||||
MEMORYSTATUSEX memInfo;
|
||||
memInfo.dwLength = sizeof(memInfo);
|
||||
GlobalMemoryStatusEx(&memInfo);
|
||||
totalSysMem = memInfo.ullTotalPhys;
|
||||
freeSysMem = memInfo.ullAvailPhys;
|
||||
#endif // WIN32
|
||||
|
||||
TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
|
||||
((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9));
|
||||
return {freeSysMem, totalSysMem};
|
||||
}
|
||||
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
check_cuda_error(cudaMemGetInfo(&free, &total));
|
||||
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
|
||||
((double) total / 1e9), ((double) free / 1e9));
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
/// @brief Gets the memory allocation granularity for the current device.
|
||||
///
|
||||
/// @return size_t The size of the smallest difference in memory size supported by the current device.
|
||||
inline size_t getAllocationGranularity()
|
||||
{
|
||||
auto const currentDevice = getDevice();
|
||||
::CUmemAllocationProp prop = {};
|
||||
|
||||
prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = currentDevice;
|
||||
prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE;
|
||||
|
||||
// Get the minimum granularity supported for allocation with cuMemCreate()
|
||||
size_t granularity = 0;
|
||||
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
return granularity;
|
||||
}
|
||||
|
||||
inline int getMultiProcessorCount()
|
||||
{
|
||||
int device_id = 0;
|
||||
int multi_processor_count = 0;
|
||||
check_cuda_error(cudaGetDevice(&device_id));
|
||||
check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id));
|
||||
return multi_processor_count;
|
||||
}
|
||||
|
||||
inline int getMaxSharedMemoryPerBlockOptin()
|
||||
{
|
||||
int device_id = 0;
|
||||
int max_shared_memory_per_block = 0;
|
||||
check_cuda_error(cudaGetDevice(&device_id));
|
||||
check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
|
||||
return max_shared_memory_per_block;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
inline size_t divUp(const T1& a, const T2& n)
|
||||
{
|
||||
auto const tmp_a = static_cast<size_t>(a);
|
||||
auto const tmp_n = static_cast<size_t>(n);
|
||||
return (tmp_a + tmp_n - 1) / tmp_n;
|
||||
}
|
||||
|
||||
inline int roundUp(int a, int n)
|
||||
{
|
||||
return divUp(a, n) * n;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename = std::enable_if_t<std::is_integral<T>::value>,
|
||||
typename = std::enable_if_t<std::is_integral<U>::value>>
|
||||
auto constexpr ceilDiv(T numerator, U denominator)
|
||||
{
|
||||
return (numerator + denominator - 1) / denominator;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "")
|
||||
{
|
||||
if (buf == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str());
|
||||
return;
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
check_cuda_error(cudaGetLastError());
|
||||
T* h_tmp = new T[size];
|
||||
cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream);
|
||||
cudaDeviceSynchronize();
|
||||
check_cuda_error(cudaGetLastError());
|
||||
double sum = 0.0f;
|
||||
uint64_t zero_count = 0;
|
||||
float max_val = -1e10;
|
||||
bool find_inf = false;
|
||||
for (uint64_t i = 0; i < size; i++)
|
||||
{
|
||||
if (std::isinf((float) (h_tmp[i])))
|
||||
{
|
||||
find_inf = true;
|
||||
continue;
|
||||
}
|
||||
sum += abs((double) h_tmp[i]);
|
||||
if ((float) h_tmp[i] == 0.0f)
|
||||
{
|
||||
zero_count++;
|
||||
}
|
||||
max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i]));
|
||||
}
|
||||
TLLM_LOG_INFO("%20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s", name.c_str(), size, sum / size,
|
||||
sum, max_val, find_inf ? "true" : "false");
|
||||
delete[] h_tmp;
|
||||
cudaDeviceSynchronize();
|
||||
check_cuda_error(cudaGetLastError());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printToStream(T const* result, int const size, FILE* strm)
|
||||
{
|
||||
bool const split_rows = (strm == stdout);
|
||||
if (result == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("It is an nullptr, skip! \n");
|
||||
return;
|
||||
}
|
||||
T* tmp = reinterpret_cast<T*>(malloc(sizeof(T) * size));
|
||||
check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost));
|
||||
for (int i = 0; i < size; ++i)
|
||||
{
|
||||
fprintf(strm, "%f, ", static_cast<float>(tmp[i]));
|
||||
if (split_rows && ((i + 1) % 10) == 0)
|
||||
fprintf(strm, "\n");
|
||||
}
|
||||
if (!split_rows || (size % 10) != 0)
|
||||
{
|
||||
fprintf(strm, "\n");
|
||||
}
|
||||
free(tmp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printToScreen(T const* result, int const size)
|
||||
{
|
||||
printToStream(result, size, stdout);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm)
|
||||
{
|
||||
if (result == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("It is an nullptr, skip! \n");
|
||||
return;
|
||||
}
|
||||
for (int ri = 0; ri < r; ++ri)
|
||||
{
|
||||
T const* ptr = result + ri * stride;
|
||||
printToStream(ptr, c, strm);
|
||||
}
|
||||
fprintf(strm, "\n");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToScreen(T const* result, int const r, int const c, int const stride)
|
||||
{
|
||||
print2dToStream(result, r, c, stride, stdout);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride)
|
||||
{
|
||||
FILE* fp = fopen(fname.c_str(), "wt");
|
||||
if (fp != nullptr)
|
||||
{
|
||||
print2dToStream(result, r, c, stride, fp);
|
||||
fclose(fp);
|
||||
}
|
||||
}
|
||||
|
||||
inline void print_float_(float x)
|
||||
{
|
||||
printf("%7.3f ", x);
|
||||
}
|
||||
|
||||
inline void print_element_(float x)
|
||||
{
|
||||
print_float_(x);
|
||||
}
|
||||
|
||||
inline void print_element_(half x)
|
||||
{
|
||||
print_float_((float) x);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
inline void print_element_(__nv_bfloat16 x)
|
||||
{
|
||||
print_float_((float) x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
inline void print_element_(__nv_fp8_e4m3 x)
|
||||
{
|
||||
print_float_((float) x);
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void print_element_(uint32_t ul)
|
||||
{
|
||||
printf("%7" PRIu32, ul);
|
||||
}
|
||||
|
||||
inline void print_element_(uint64_t ull)
|
||||
{
|
||||
printf("%7" PRIu64, ull);
|
||||
}
|
||||
|
||||
inline void print_element_(int32_t il)
|
||||
{
|
||||
printf("%7" PRId32, il);
|
||||
}
|
||||
|
||||
inline void print_element_(int64_t ill)
|
||||
{
|
||||
printf("%7" PRId64, ill);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr)
|
||||
{
|
||||
T* tmp;
|
||||
if (is_device_ptr)
|
||||
{
|
||||
// k < stride ; stride = col-dimension.
|
||||
tmp = reinterpret_cast<T*>(malloc(m * stride * sizeof(T)));
|
||||
check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
else
|
||||
{
|
||||
tmp = const_cast<T*>(ptr);
|
||||
}
|
||||
|
||||
for (int ii = -1; ii < m; ++ii)
|
||||
{
|
||||
if (ii >= 0)
|
||||
{
|
||||
printf("%07d ", ii);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf(" ");
|
||||
}
|
||||
|
||||
for (int jj = 0; jj < k; jj += 1)
|
||||
{
|
||||
if (ii >= 0)
|
||||
{
|
||||
print_element_(tmp[ii * stride + jj]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%7d ", jj);
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
if (is_device_ptr)
|
||||
{
|
||||
free(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
#ifdef ENABLE_BF16
|
||||
template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template void printMatrix(__nv_fp8_e4m3 const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
#endif
|
||||
template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
/*
|
||||
* Macros compliant with TensorRT coding conventions
|
||||
*/
|
||||
#define TLLM_CUDA_CHECK(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
// We use singleton memory pool and the order of destructors depends on the compiler implementation. We find that the
|
||||
// cudaFree/cudaFreeHost is called after cudaruntime destruction on Windows. There will be an cudaErrorCudartUnloading
|
||||
// error. However, it is safe to ignore this error because the cuda runtime is already exited, we are no more worried
|
||||
// about the memory leaks.
|
||||
#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
@@ -1,70 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
Logger::Logger()
|
||||
{
|
||||
char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY");
|
||||
bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON");
|
||||
|
||||
auto const* levelName = std::getenv("TLLM_LOG_LEVEL");
|
||||
if (levelName != nullptr)
|
||||
{
|
||||
auto level = [levelName = std::string(levelName)]()
|
||||
{
|
||||
if (levelName == "TRACE")
|
||||
return TRACE;
|
||||
if (levelName == "DEBUG")
|
||||
return DEBUG;
|
||||
if (levelName == "INFO")
|
||||
return INFO;
|
||||
if (levelName == "WARNING")
|
||||
return WARNING;
|
||||
if (levelName == "ERROR")
|
||||
return ERROR;
|
||||
TLLM_THROW("Invalid log level: %s", levelName.c_str());
|
||||
}();
|
||||
// If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
|
||||
if (isFirstRankOnly)
|
||||
{
|
||||
auto const deviceId = getDevice();
|
||||
if (deviceId != 1)
|
||||
{
|
||||
level = ERROR;
|
||||
}
|
||||
}
|
||||
setLevel(level);
|
||||
}
|
||||
}
|
||||
|
||||
void Logger::log(std::exception const& ex, Logger::Level level)
|
||||
{
|
||||
log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what());
|
||||
}
|
||||
|
||||
Logger* Logger::getLogger()
|
||||
{
|
||||
thread_local Logger instance;
|
||||
return &instance;
|
||||
}
|
||||
} // namespace tensorrt_llm::common
|
||||
190
sgl-kernel/3rdparty/tensorrt_llm/common/logger.h
vendored
190
sgl-kernel/3rdparty/tensorrt_llm/common/logger.h
vendored
@@ -1,190 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class Logger
|
||||
{
|
||||
|
||||
// On Windows, the file wingdi.h is included which has
|
||||
// #define ERROR 0
|
||||
// This breaks everywhere ERROR is used in the Level enum
|
||||
#ifdef _WIN32
|
||||
#undef ERROR
|
||||
#endif // _WIN32
|
||||
|
||||
public:
|
||||
enum Level
|
||||
{
|
||||
TRACE = 0,
|
||||
DEBUG = 10,
|
||||
INFO = 20,
|
||||
WARNING = 30,
|
||||
ERROR = 40
|
||||
};
|
||||
|
||||
static Logger* getLogger();
|
||||
|
||||
Logger(Logger const&) = delete;
|
||||
void operator=(Logger const&) = delete;
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
template <typename... Args>
|
||||
void log(Level level, char const* format, Args const&... args);
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, int rank, char const* format, Args const&... args);
|
||||
#else
|
||||
template <typename... Args>
|
||||
void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
|
||||
#endif
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, std::string const& format, Args const&... args)
|
||||
{
|
||||
return log(level, format.c_str(), args...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level const level, int const rank, std::string const& format, Args const&... args)
|
||||
{
|
||||
return log(level, rank, format.c_str(), args...);
|
||||
}
|
||||
|
||||
void log(std::exception const& ex, Level level = Level::ERROR);
|
||||
|
||||
Level getLevel() const
|
||||
{
|
||||
return level_;
|
||||
}
|
||||
|
||||
void setLevel(Level const level)
|
||||
{
|
||||
level_ = level;
|
||||
log(INFO, "Set logger level to %s", getLevelName(level));
|
||||
}
|
||||
|
||||
bool isEnabled(Level const level) const
|
||||
{
|
||||
return level_ <= level;
|
||||
}
|
||||
|
||||
private:
|
||||
static auto constexpr kPREFIX = "[TensorRT-LLM]";
|
||||
|
||||
#ifndef NDEBUG
|
||||
Level const DEFAULT_LOG_LEVEL = DEBUG;
|
||||
#else
|
||||
Level const DEFAULT_LOG_LEVEL = INFO;
|
||||
#endif
|
||||
Level level_ = DEFAULT_LOG_LEVEL;
|
||||
|
||||
Logger(); // NOLINT(modernize-use-equals-delete)
|
||||
|
||||
static inline char const* getLevelName(Level const level)
|
||||
{
|
||||
switch (level)
|
||||
{
|
||||
case TRACE: return "TRACE";
|
||||
case DEBUG: return "DEBUG";
|
||||
case INFO: return "INFO";
|
||||
case WARNING: return "WARNING";
|
||||
case ERROR: return "ERROR";
|
||||
}
|
||||
|
||||
TLLM_THROW("Unknown log level: %d", level);
|
||||
}
|
||||
|
||||
static inline std::string getPrefix(Level const level)
|
||||
{
|
||||
return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
|
||||
}
|
||||
|
||||
static inline std::string getPrefix(Level const level, int const rank)
|
||||
{
|
||||
return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
void Logger::log(Logger::Level level, char const* format, Args const&... args)
|
||||
{
|
||||
if (isEnabled(level))
|
||||
{
|
||||
auto const fmt = getPrefix(level) + format;
|
||||
auto& out = level_ < WARNING ? std::cout : std::cerr;
|
||||
if constexpr (sizeof...(args) > 0)
|
||||
{
|
||||
out << fmtstr(fmt.c_str(), args...);
|
||||
}
|
||||
else
|
||||
{
|
||||
out << fmt;
|
||||
}
|
||||
out << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args)
|
||||
{
|
||||
if (isEnabled(level))
|
||||
{
|
||||
auto const fmt = getPrefix(level, rank) + format;
|
||||
auto& out = level_ < WARNING ? std::cout : std::cerr;
|
||||
if constexpr (sizeof...(args) > 0)
|
||||
{
|
||||
out << fmtstr(fmt.c_str(), args...);
|
||||
}
|
||||
else
|
||||
{
|
||||
out << fmt;
|
||||
}
|
||||
out << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#define TLLM_LOG(level, ...) \
|
||||
do \
|
||||
{ \
|
||||
auto* const logger = tensorrt_llm::common::Logger::getLogger(); \
|
||||
if (logger->isEnabled(level)) \
|
||||
{ \
|
||||
logger->log(level, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__)
|
||||
#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__)
|
||||
#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__)
|
||||
#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__)
|
||||
#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__)
|
||||
#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__)
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,55 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <float.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct QuantTypeStaticVals;
|
||||
|
||||
template <>
|
||||
struct QuantTypeStaticVals<int8_t>
|
||||
{
|
||||
static constexpr float MAX_VAL = 127.f;
|
||||
static constexpr float MIN_SCALING_FACTOR = 0.f;
|
||||
static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <>
|
||||
struct QuantTypeStaticVals<__nv_fp8_e4m3>
|
||||
{
|
||||
static constexpr float MAX_VAL = 448.f;
|
||||
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720
|
||||
static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f);
|
||||
static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f);
|
||||
};
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,358 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
class QuantMode
|
||||
{
|
||||
// [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py
|
||||
public:
|
||||
using BaseType = std::uint32_t;
|
||||
|
||||
explicit constexpr QuantMode(BaseType value) noexcept
|
||||
: mValue{value}
|
||||
{
|
||||
}
|
||||
|
||||
QuantMode() noexcept = default;
|
||||
|
||||
constexpr QuantMode(QuantMode const&) noexcept = default;
|
||||
|
||||
constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;
|
||||
|
||||
static constexpr QuantMode none() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(0));
|
||||
}
|
||||
|
||||
static constexpr QuantMode int4Weights() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 0);
|
||||
}
|
||||
|
||||
static constexpr QuantMode int8Weights() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 1);
|
||||
}
|
||||
|
||||
static constexpr QuantMode activations() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 2);
|
||||
}
|
||||
|
||||
static constexpr QuantMode perChannelScaling() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 3);
|
||||
}
|
||||
|
||||
static constexpr QuantMode perTokenScaling() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 4);
|
||||
}
|
||||
|
||||
static constexpr QuantMode perGroupScaling() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 5);
|
||||
}
|
||||
|
||||
static constexpr QuantMode int8KvCache() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 6);
|
||||
}
|
||||
|
||||
static constexpr QuantMode fp8KvCache() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 7);
|
||||
}
|
||||
|
||||
static constexpr QuantMode fp8Qdq() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 8);
|
||||
}
|
||||
|
||||
static constexpr QuantMode fp8RowWise() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
|
||||
}
|
||||
|
||||
static constexpr QuantMode w4a8QServe() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 10);
|
||||
}
|
||||
|
||||
constexpr BaseType value() const noexcept
|
||||
{
|
||||
return mValue;
|
||||
}
|
||||
|
||||
constexpr bool isSet(QuantMode const& mode) const noexcept
|
||||
{
|
||||
return (mValue & mode.value()) == mode.value();
|
||||
}
|
||||
|
||||
constexpr bool hasInt4Weights() const noexcept
|
||||
{
|
||||
return isSet(int4Weights());
|
||||
}
|
||||
|
||||
constexpr bool hasInt8Weights() const noexcept
|
||||
{
|
||||
return isSet(int8Weights());
|
||||
}
|
||||
|
||||
constexpr bool hasActivations() const noexcept
|
||||
{
|
||||
return isSet(activations());
|
||||
}
|
||||
|
||||
constexpr bool hasPerChannelScaling() const noexcept
|
||||
{
|
||||
return isSet(perChannelScaling());
|
||||
}
|
||||
|
||||
constexpr bool hasPerTokenScaling() const noexcept
|
||||
{
|
||||
return isSet(perTokenScaling());
|
||||
}
|
||||
|
||||
constexpr bool hasPerGroupScaling() const noexcept
|
||||
{
|
||||
return isSet(perGroupScaling());
|
||||
}
|
||||
|
||||
constexpr bool hasStaticActivationScaling() const noexcept
|
||||
{
|
||||
return !hasPerTokenScaling();
|
||||
}
|
||||
|
||||
constexpr bool hasInt8KvCache() const noexcept
|
||||
{
|
||||
return isSet(int8KvCache());
|
||||
}
|
||||
|
||||
constexpr bool hasFp8KvCache() const noexcept
|
||||
{
|
||||
return isSet(fp8KvCache());
|
||||
}
|
||||
|
||||
constexpr bool hasFp8Qdq() const noexcept
|
||||
{
|
||||
return isSet(fp8Qdq());
|
||||
}
|
||||
|
||||
constexpr bool hasFp8RowWise() const noexcept
|
||||
{
|
||||
return isSet(fp8RowWise());
|
||||
}
|
||||
|
||||
constexpr bool hasKvCacheQuant() const noexcept
|
||||
{
|
||||
return hasInt8KvCache() || hasFp8KvCache();
|
||||
}
|
||||
|
||||
static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
|
||||
bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
|
||||
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
|
||||
bool useW4a8QServe = false)
|
||||
{
|
||||
QuantMode quantMode{};
|
||||
if (quantizeWeights)
|
||||
{
|
||||
if (useInt4Weights)
|
||||
quantMode += int4Weights();
|
||||
else
|
||||
quantMode += int8Weights();
|
||||
}
|
||||
|
||||
if (quantizeActivations)
|
||||
{
|
||||
quantMode += activations();
|
||||
}
|
||||
|
||||
if (perChannel)
|
||||
{
|
||||
quantMode += QuantMode::perChannelScaling();
|
||||
}
|
||||
if (perToken)
|
||||
{
|
||||
quantMode += QuantMode::perTokenScaling();
|
||||
}
|
||||
if (perGroup)
|
||||
{
|
||||
quantMode += QuantMode::perGroupScaling();
|
||||
}
|
||||
|
||||
if (useInt8KvCache)
|
||||
{
|
||||
quantMode += int8KvCache();
|
||||
}
|
||||
|
||||
if (useFp8KvCache)
|
||||
{
|
||||
quantMode += fp8KvCache();
|
||||
}
|
||||
|
||||
if (useFp8Qdq)
|
||||
{
|
||||
quantMode += fp8Qdq();
|
||||
}
|
||||
|
||||
if (useFp8RowWise)
|
||||
{
|
||||
quantMode += fp8RowWise();
|
||||
}
|
||||
|
||||
if (useW4a8QServe)
|
||||
{
|
||||
quantMode += w4a8QServe();
|
||||
}
|
||||
|
||||
return quantMode;
|
||||
}
|
||||
|
||||
static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
|
||||
{
|
||||
return fromDescription(true, true, perToken, perChannel);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useQServe(bool perGroup)
|
||||
{
|
||||
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
|
||||
{
|
||||
return fromDescription(true, false, false, false, perGroup, useInt4Weights);
|
||||
}
|
||||
|
||||
static QuantMode const fromQuantAlgo(
|
||||
std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
|
||||
{
|
||||
QuantMode quantMode{};
|
||||
if (quantAlgo == "W8A16")
|
||||
{
|
||||
quantMode = useWeightOnly(false, false);
|
||||
}
|
||||
else if (quantAlgo == "W4A16")
|
||||
{
|
||||
quantMode = useWeightOnly(true, false);
|
||||
}
|
||||
else if (quantAlgo == "W4A16_AWQ")
|
||||
{
|
||||
quantMode = useWeightOnly(true, true);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_AWQ")
|
||||
{
|
||||
quantMode = useWeightOnly(true, true);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
|
||||
{
|
||||
quantMode = useQServe(false);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
|
||||
{
|
||||
quantMode = useQServe(true);
|
||||
}
|
||||
else if (quantAlgo == "W4A16_GPTQ")
|
||||
{
|
||||
quantMode = useWeightOnly(true, true);
|
||||
}
|
||||
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
|
||||
{
|
||||
quantMode = useSmoothQuant(false, true);
|
||||
}
|
||||
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
|
||||
{
|
||||
quantMode = useSmoothQuant(false, false);
|
||||
}
|
||||
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
|
||||
{
|
||||
quantMode = useSmoothQuant(true, true);
|
||||
}
|
||||
else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
|
||||
{
|
||||
quantMode = useSmoothQuant(false, true);
|
||||
}
|
||||
else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
|
||||
{
|
||||
quantMode = useSmoothQuant(true, false);
|
||||
}
|
||||
else if (quantAlgo == "FP8")
|
||||
{
|
||||
quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
|
||||
}
|
||||
else if (quantAlgo == "FP8_ROWWISE")
|
||||
{
|
||||
quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
|
||||
}
|
||||
|
||||
if (kvCacheQuantAlgo == "INT8")
|
||||
{
|
||||
quantMode += int8KvCache();
|
||||
}
|
||||
else if (kvCacheQuantAlgo == "FP8")
|
||||
{
|
||||
quantMode += fp8KvCache();
|
||||
}
|
||||
|
||||
return quantMode;
|
||||
}
|
||||
|
||||
constexpr QuantMode operator+(QuantMode const& other) const noexcept
|
||||
{
|
||||
return QuantMode(mValue | other.mValue);
|
||||
}
|
||||
|
||||
constexpr QuantMode& operator+=(QuantMode const& other) noexcept
|
||||
{
|
||||
return *this = *this + other;
|
||||
}
|
||||
|
||||
constexpr QuantMode operator-(QuantMode const& other) const noexcept
|
||||
{
|
||||
return QuantMode(mValue & ~other.mValue);
|
||||
}
|
||||
|
||||
constexpr QuantMode& operator-=(QuantMode const& other) noexcept
|
||||
{
|
||||
return *this = *this - other;
|
||||
}
|
||||
|
||||
constexpr bool operator==(QuantMode const& other) const noexcept
|
||||
{
|
||||
return mValue == other.mValue;
|
||||
}
|
||||
|
||||
constexpr bool operator!=(QuantMode const& other) const noexcept
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
private:
|
||||
BaseType mValue{0};
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,399 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <array>
|
||||
#include <assert.h>
|
||||
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#else
|
||||
#include <cooperative_groups.h>
|
||||
#endif
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <int VPT>
|
||||
struct BytesToType;
|
||||
|
||||
template <>
|
||||
struct BytesToType<1>
|
||||
{
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<2>
|
||||
{
|
||||
using type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<4>
|
||||
{
|
||||
using type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<8>
|
||||
{
|
||||
using type = uint64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<16>
|
||||
{
|
||||
using type = float4;
|
||||
};
|
||||
|
||||
template <int Bytes>
|
||||
__device__ inline void copy(void const* local, void* data)
|
||||
{
|
||||
using T = typename BytesToType<Bytes>::type;
|
||||
|
||||
T const* in = static_cast<T const*>(local);
|
||||
T* out = static_cast<T*>(data);
|
||||
*out = *in;
|
||||
}
|
||||
|
||||
static float constexpr HALF_FLT_MAX = 65504.F;
|
||||
#define FINAL_MASK 0xffffffff
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockReduceSum(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f);
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceMax(T val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the maximum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockReduceMax(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
val = warpReduceMax(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
|
||||
val = warpReduceMax(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the maximum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockAllReduceMax(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
val = warpReduceMax(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
|
||||
val = warpReduceMax(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T warpReduceSumV2(T* val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
|
||||
}
|
||||
return (T) (0.0f);
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T blockReduceSumV2(T* val)
|
||||
{
|
||||
static __shared__ T shared[NUM][33];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduceSumV2<T, NUM>(val);
|
||||
|
||||
if (lane == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
shared[i][wid] = val[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
|
||||
}
|
||||
warpReduceSumV2<T, NUM>(val);
|
||||
return (T) 0.0f;
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T warpReduceMaxV2(T* val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
|
||||
}
|
||||
return (T) (0.0f);
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T blockReduceMaxV2(T* val)
|
||||
{
|
||||
static __shared__ T shared[32][NUM];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
warpReduceMaxV2<T, NUM>(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
shared[wid][i] = val[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
val[i] = is_mask ? shared[lane][i] : (T) -1e20f;
|
||||
}
|
||||
warpReduceMaxV2<T, NUM>(val);
|
||||
|
||||
return (T) 0.0f;
|
||||
}
|
||||
|
||||
template <int NUM>
|
||||
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
|
||||
{
|
||||
cg::thread_block cta = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
|
||||
|
||||
int const tid = cta.thread_rank();
|
||||
int const blockz = blockDim.x;
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
||||
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
|
||||
#else
|
||||
// TODO Add implementation here
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
cg::sync(cta);
|
||||
if (tid == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
float beta = 0.0f;
|
||||
for (int j = 0; j < blockz; j += 32)
|
||||
{
|
||||
beta += cgBlockReduceSumElements_shm[i * blockz + j];
|
||||
}
|
||||
element_list[i] = beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
struct TopK
|
||||
{
|
||||
int p[MAX_K]; // index, being -1 at the tail if the array is not full
|
||||
T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid
|
||||
|
||||
__device__ __forceinline__ void insert(T const elem, int const elem_id)
|
||||
{
|
||||
if (elem_id < 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Condition of updating the array
|
||||
// 1. array is not full
|
||||
// 2. elem is greater than the smallest (last) element in the array
|
||||
// 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller
|
||||
bool const need_update
|
||||
= (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]);
|
||||
if (!need_update)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Find suitable index for the new element
|
||||
int i;
|
||||
for (i = MAX_K - 2; i >= 0; --i)
|
||||
{
|
||||
bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]);
|
||||
if (!need_decrease)
|
||||
break;
|
||||
}
|
||||
// Move elements to correct positions
|
||||
for (int k = MAX_K - 2; k >= i; --k)
|
||||
{
|
||||
p[k + 1] = p[k];
|
||||
u[k + 1] = u[k];
|
||||
}
|
||||
p[i] = elem_id;
|
||||
u[i] = elem;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void init()
|
||||
{
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
for (int i = 0; i < MAX_K; i++)
|
||||
{
|
||||
p[i] = -1;
|
||||
u[i] = -MAX_T_VAL;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
|
||||
{
|
||||
TopK<T, MAX_K> res = a;
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
res.insert(b.u[i], b.p[i]);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TopK_2
|
||||
{
|
||||
int p = -1;
|
||||
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
|
||||
|
||||
__device__ __forceinline__ void insert(T elem, int elem_id)
|
||||
{
|
||||
if (elem > u)
|
||||
{
|
||||
u = elem;
|
||||
p = elem_id;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void init()
|
||||
{
|
||||
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
|
||||
p = -1;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
|
||||
{
|
||||
return a.u > b.u ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T clamp_inf_for_half(float const input)
|
||||
{
|
||||
return input;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ half clamp_inf_for_half(float const input)
|
||||
{
|
||||
// clamp inf values to enable fp16 training
|
||||
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,76 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
#include <cerrno>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
std::string vformat(char const* fmt, va_list args)
|
||||
{
|
||||
va_list args0;
|
||||
va_copy(args0, args);
|
||||
auto const size = vsnprintf(nullptr, 0, fmt, args0);
|
||||
if (size <= 0)
|
||||
return "";
|
||||
|
||||
std::string stringBuf(size, char{});
|
||||
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
|
||||
|
||||
return stringBuf;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string fmtstr(char const* format, ...)
|
||||
{
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::string result = vformat(format, args);
|
||||
va_end(args);
|
||||
return result;
|
||||
};
|
||||
|
||||
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
|
||||
{
|
||||
std::unordered_set<std::string> values;
|
||||
if (!input.empty())
|
||||
{
|
||||
std::stringstream valStream(input);
|
||||
std::string val;
|
||||
while (std::getline(valStream, val, delimiter))
|
||||
{
|
||||
if (!val.empty())
|
||||
{
|
||||
values.insert(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
return values;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,113 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif // ENABLE_BF16
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <memory> // std::make_unique
|
||||
#include <sstream> // std::stringstream
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
#if ENABLE_BF16
|
||||
static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __nv_bfloat16 const& val)
|
||||
{
|
||||
stream << __bfloat162float(val);
|
||||
return stream;
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __half const& val)
|
||||
{
|
||||
stream << __half2float(val);
|
||||
return stream;
|
||||
}
|
||||
|
||||
inline std::string fmtstr(std::string const& s)
|
||||
{
|
||||
return s;
|
||||
}
|
||||
|
||||
inline std::string fmtstr(std::string&& s)
|
||||
{
|
||||
return s;
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
std::string fmtstr(char const* format, ...);
|
||||
#else
|
||||
std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2)));
|
||||
#endif
|
||||
|
||||
// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows
|
||||
// The alternative is __FUNCSIG__, which is similar but not identical
|
||||
#if defined(_WIN32)
|
||||
#define __PRETTY_FUNCTION__ __FUNCSIG__
|
||||
#endif
|
||||
|
||||
auto constexpr kDefaultDelimiter = ", ";
|
||||
|
||||
template <typename U, typename TStream, typename T>
|
||||
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
out << "(";
|
||||
if (size > 0)
|
||||
{
|
||||
for (size_t i = 0; i < size - 1; ++i)
|
||||
{
|
||||
out << static_cast<U>(arr[i]) << delim;
|
||||
}
|
||||
out << static_cast<U>(arr[size - 1]);
|
||||
}
|
||||
out << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename TStream, typename T>
|
||||
inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
return arr2outCasted<T>(out, arr, size, delim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
std::stringstream ss;
|
||||
return arr2out(ss, arr, size, delim).str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string vec2str(std::vector<T> const& vec, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
return arr2str(vec.data(), vec.size(), delim);
|
||||
}
|
||||
|
||||
inline bool strStartsWith(std::string const& str, std::string const& prefix)
|
||||
{
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
/// @brief Split a string into a set of strings using a delimiter
|
||||
std::unordered_set<std::string> str2set(std::string const& input, char delimiter);
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,105 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#if !defined(_MSC_VER)
|
||||
#include <cxxabi.h>
|
||||
#include <dlfcn.h>
|
||||
#include <execinfo.h>
|
||||
#endif
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
|
||||
}
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
|
||||
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
|
||||
: std::runtime_error{""}
|
||||
{
|
||||
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
|
||||
auto const trace = getTrace();
|
||||
std::runtime_error::operator=(
|
||||
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
|
||||
}
|
||||
#else
|
||||
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
|
||||
: mNbFrames{}
|
||||
, std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
TllmException::~TllmException() noexcept = default;
|
||||
|
||||
std::string TllmException::getTrace() const
|
||||
{
|
||||
#if defined(_MSC_VER)
|
||||
return "";
|
||||
#else
|
||||
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
|
||||
std::ostringstream buf;
|
||||
for (auto i = 1; i < mNbFrames; ++i)
|
||||
{
|
||||
Dl_info info;
|
||||
if (dladdr(mCallstack[i], &info) && info.dli_sname)
|
||||
{
|
||||
auto const clearName = demangle(info.dli_sname);
|
||||
buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
|
||||
static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
|
||||
}
|
||||
else
|
||||
{
|
||||
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
|
||||
}
|
||||
if (i < mNbFrames - 1)
|
||||
buf << std::endl;
|
||||
}
|
||||
|
||||
if (mNbFrames == MAX_FRAMES)
|
||||
buf << std::endl << "[truncated]";
|
||||
|
||||
std::free(trace);
|
||||
return buf.str();
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string TllmException::demangle(char const* name)
|
||||
{
|
||||
#if defined(_MSC_VER)
|
||||
return name;
|
||||
#else
|
||||
std::string clearName{name};
|
||||
auto status = -1;
|
||||
auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
|
||||
if (status == 0)
|
||||
{
|
||||
clearName = demangled;
|
||||
std::free(demangled);
|
||||
}
|
||||
return clearName;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,48 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#define NEW_TLLM_EXCEPTION(...) \
|
||||
tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__))
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class TllmException : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
static auto constexpr MAX_FRAMES = 128;
|
||||
|
||||
explicit TllmException(char const* file, std::size_t line, std::string const& msg);
|
||||
|
||||
~TllmException() noexcept override;
|
||||
|
||||
[[nodiscard]] std::string getTrace() const;
|
||||
|
||||
static std::string demangle(char const* name);
|
||||
|
||||
private:
|
||||
std::array<void*, MAX_FRAMES> mCallstack{};
|
||||
int mNbFrames;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,87 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::uintptr_t constexpr kCudaMemAlign = 128;
|
||||
|
||||
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
|
||||
{
|
||||
uintptr_t addr = (uintptr_t) ptr;
|
||||
if (addr % to)
|
||||
{
|
||||
addr += to - addr % to;
|
||||
}
|
||||
return (int8_t*) addr;
|
||||
}
|
||||
|
||||
constexpr size_t alignSize(size_t size, size_t to)
|
||||
{
|
||||
if ((size % to) != 0U)
|
||||
{
|
||||
size += to - size % to;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
|
||||
{
|
||||
uintptr_t addr = (uintptr_t) ptr;
|
||||
addr += previousWorkspaceSize;
|
||||
return alignPtr((int8_t*) addr, alignment);
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
|
||||
{
|
||||
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtr(
|
||||
int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
uintptr_t curr_offset = offset;
|
||||
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
|
||||
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
|
||||
offset = next_offset;
|
||||
return newptr;
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtrWithAlignment(
|
||||
int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
|
||||
}
|
||||
|
||||
inline size_t calculateTotalWorkspaceSize(
|
||||
size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
size_t total = 0;
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
total += workspaces[i];
|
||||
if (workspaces[i] % alignment)
|
||||
{
|
||||
total += alignment - (workspaces[i] % alignment);
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
}; // namespace tensorrt_llm::common
|
||||
@@ -1,352 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
// Config
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
|
||||
#define CUTE_ARCH_RED_F16_SM70_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
|
||||
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//////////////////////////////////
|
||||
// Wrapper around CUDA's atomicAdd
|
||||
//////////////////////////////////
|
||||
|
||||
template <class T>
|
||||
struct TypedAtomicAdd
|
||||
{
|
||||
using SRegisters = T[1];
|
||||
using DRegisters = T[1];
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
|
||||
{
|
||||
atomicAdd(&dst, src);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Copy_Traits<TypedAtomicAdd<T>>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// F16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// BF16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
@@ -1,120 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates exposing architecture support for multiply-add operations
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace arch
|
||||
{
|
||||
|
||||
// Tag which triggers MMA which will trigger
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA;
|
||||
|
||||
/*
|
||||
Below we have extra tags to signal what kind of dequantization we want to do
|
||||
(per col, scale only fine grained, finegrained with zero). This still lets us
|
||||
the existing template infrastructure (incl. that in CUTLASS). However, we
|
||||
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
|
||||
with the quantization op before instantiating the GEMM pieces.
|
||||
|
||||
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
|
||||
code we need to duplicate.
|
||||
*/
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
|
||||
// The default just forwards the original operator
|
||||
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
|
||||
struct TagOperator
|
||||
{
|
||||
using TaggedOperator = MmaOp;
|
||||
};
|
||||
|
||||
// Specializations below attach more information to the operator
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
};
|
||||
|
||||
// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
|
||||
// operator + the extra information. If no extra info was tagged, the dequant op per column scaling
|
||||
// as a default.
|
||||
template <typename TaggedMmaOp>
|
||||
struct DetagOperator
|
||||
{
|
||||
using Operator = TaggedMmaOp;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
};
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
@@ -1,88 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
|
||||
template <typename GemmKernel, bool enable_cutlass_3x = false>
|
||||
inline int compute_occupancy_for_kernel()
|
||||
{
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
cudaFuncAttributes attr;
|
||||
int device = 0;
|
||||
int max_smem_per_block = 0;
|
||||
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
|
||||
}
|
||||
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
|
||||
{
|
||||
// This should mean that
|
||||
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
|
||||
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
|
||||
// configuration.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
|
||||
cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
|
||||
cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
|
||||
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
|
||||
}
|
||||
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,550 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing elementwise operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
|
||||
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
|
||||
class CopyOpR2G>
|
||||
class EpilogueMoeFusedFinalize
|
||||
{
|
||||
public:
|
||||
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
|
||||
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
|
||||
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using InternalStrideC = cute::remove_pointer_t<StrideC>;
|
||||
using ElementD = ElementD_;
|
||||
using StrideD = StrideD_;
|
||||
using InternalStrideD = cute::remove_pointer_t<StrideD>;
|
||||
|
||||
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
|
||||
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
|
||||
|
||||
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
|
||||
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
|
||||
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
|
||||
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
|
||||
|
||||
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
|
||||
};
|
||||
|
||||
struct TensorMapStorage
|
||||
{
|
||||
};
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C{};
|
||||
StrideC dC{};
|
||||
ElementD* ptr_D{};
|
||||
StrideD dD{};
|
||||
ElementBias const* ptr_bias;
|
||||
StrideBias dBias{};
|
||||
ElementScale const* ptr_scale;
|
||||
StrideScale dScale{};
|
||||
int64_t const* group_offset{};
|
||||
int32_t const* scatter_index{};
|
||||
cutlass::FastDivmod num_rows_in_final_output;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
|
||||
{
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
|
||||
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool can_implement(
|
||||
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
bool implementable = true;
|
||||
if (problem_shape.is_host_problem_shape_available())
|
||||
{
|
||||
// Check alignment for all problem sizes
|
||||
for (int i = 0; i < problem_shape.groups(); i++)
|
||||
{
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
|
||||
}
|
||||
}
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
|
||||
"reduction instruction.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
EpilogueMoeFusedFinalize(Params const& params_)
|
||||
: params(params_)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_source_needed()
|
||||
{
|
||||
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
|
||||
return params.ptr_C != nullptr
|
||||
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
|
||||
}
|
||||
|
||||
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
|
||||
class TiledMma, class ResidueMNK>
|
||||
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
|
||||
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
|
||||
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
||||
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
||||
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
|
||||
|
||||
auto synchronize = [&]()
|
||||
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
auto M = get<0>(problem_shape_mnkl);
|
||||
auto N = get<1>(problem_shape_mnkl);
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
|
||||
auto mma_tile_m = tile_size<0>(tiled_mma);
|
||||
auto mma_tile_n = tile_size<1>(tiled_mma);
|
||||
auto epi_tile_m = size<0>(EpilogueTile{});
|
||||
auto epi_tile_n = size<1>(EpilogueTile{});
|
||||
|
||||
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
// Batches are managed by using appropriate pointers to C and D matrices
|
||||
int32_t const mock_L = 1;
|
||||
int32_t const mock_l_coord = 0;
|
||||
|
||||
// Slice to get the tile this CTA is responsible for
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
||||
|
||||
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
|
||||
// we get the correct alpha/beta values for the current batch/group using group index.
|
||||
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
|
||||
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
|
||||
Tensor sD = as_position_independent_swizzle_tensor(sD_);
|
||||
|
||||
// Function to scatter output rows
|
||||
auto& num_rows = params.num_rows_in_final_output;
|
||||
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
|
||||
auto get_scatter_idx = [&](auto i)
|
||||
{
|
||||
auto scatter = read_scatter_map(i);
|
||||
int quot, rem;
|
||||
num_rows(quot, rem, scatter);
|
||||
return rem;
|
||||
};
|
||||
|
||||
// Represent the full output tensor
|
||||
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
|
||||
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
|
||||
Tensor mD_mnl = make_gather_tensor(
|
||||
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
|
||||
|
||||
// Use fake shape for bias, it doesn't matter
|
||||
bool const is_bias_needed = params.ptr_bias != nullptr;
|
||||
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
|
||||
Tensor mScale_mnl = make_tensor(
|
||||
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
|
||||
|
||||
Tensor gC_mnl
|
||||
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gD_mnl
|
||||
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor gBias_mnl
|
||||
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gScale_mnl
|
||||
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Get the smallest tiled copy we can use to retile the accumulators
|
||||
TiledCopy tiled_copy_C_atom
|
||||
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
|
||||
// Make a tiled copy vectorized along major direction of D
|
||||
auto tiled_s2r = [&]()
|
||||
{
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<AlignmentD>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
|
||||
Layout<Shape<Int<AlignmentD>, _1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
|
||||
}
|
||||
}();
|
||||
|
||||
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
|
||||
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// Allocate intermediate registers for a single subtile
|
||||
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
|
||||
// Make an identity coordinate tensor for predicating our output MN tile
|
||||
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
||||
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// epilogue subtile loop
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
|
||||
{
|
||||
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
|
||||
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
|
||||
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
|
||||
|
||||
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
|
||||
int r2s_v = epi_n_in_mma * size(tRS_rD);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
|
||||
{
|
||||
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
|
||||
}
|
||||
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD);
|
||||
synchronize();
|
||||
|
||||
copy(tiled_s2r, tSR_sD, tSR_rD);
|
||||
synchronize();
|
||||
|
||||
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
|
||||
|
||||
if (epilogue_op.is_source_needed())
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
};
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <class Element, class MaxVec>
|
||||
constexpr auto get_vectorized_atomic_add_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
auto constexpr MaxVecSize = size(MaxVec{});
|
||||
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
|
||||
class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
|
||||
struct EpilogueMoeFusedFinalizeBuilder
|
||||
{
|
||||
|
||||
// assuming cooperative kernel schedule
|
||||
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
|
||||
using EpilogueTile = Shape<_128, EpiTileN>;
|
||||
|
||||
// Output of linear combination is ElementCompute instead of ElementD
|
||||
// since we will be doing more computate on it, no need to cast yet.
|
||||
using ThreadEpilogueOp
|
||||
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using SmemLayoutAtomD
|
||||
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
|
||||
using CopyAtomS2R = DefaultCopy;
|
||||
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
|
||||
|
||||
template <class EpilogueOp>
|
||||
struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
|
||||
{
|
||||
// We need to override this one using declaration because otherwise we double up on the smem
|
||||
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
|
||||
|
||||
using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TmaWarpSpecializedAdapterWithSmemStorage(
|
||||
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
|
||||
: Base(params)
|
||||
{
|
||||
}
|
||||
|
||||
// These functions depend on the type of TensorMapStorage
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
|
||||
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
|
||||
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace collective
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,105 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b)
|
||||
{
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float tanh_opt(float x)
|
||||
{
|
||||
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
#else
|
||||
return fast_tanh(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <>
|
||||
struct GELU_taylor<float>
|
||||
{
|
||||
static bool const kIsHeavy = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const
|
||||
{
|
||||
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
|
||||
return float(cutlass::constants::half<float>() * z
|
||||
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const
|
||||
{
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,352 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
|
||||
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol
|
||||
{
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments()
|
||||
: batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_, int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(batch_stride_alpha_)
|
||||
, batch_stride_C(batch_stride_C_)
|
||||
, batch_stride_D(batch_stride_D_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise)
|
||||
, batch_stride_alpha(args.batch_stride_alpha)
|
||||
, batch_stride_C(args.batch_stride_C)
|
||||
, batch_stride_D(args.batch_stride_D)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
|
||||
int column_offset_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
|
||||
: params_(params)
|
||||
, shared_storage_(shared_storage)
|
||||
, extent_(problem_size)
|
||||
, elementwise_(params.elementwise)
|
||||
, per_token_quant_(quant_option.hasPerTokenScaling())
|
||||
, per_channel_quant_(quant_option.hasPerChannelScaling())
|
||||
, ptr_alpha_row_(ptr_alpha_row)
|
||||
, ptr_alpha_col_(ptr_alpha_col)
|
||||
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
|
||||
, extent_real_(problem_size_real)
|
||||
{
|
||||
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
|
||||
|
||||
if (beta_ == ElementAccumulator())
|
||||
{
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
|
||||
{
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
|
||||
{
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices)
|
||||
{ ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx)
|
||||
{
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue()
|
||||
{
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx)
|
||||
{
|
||||
fragment_D_.clear();
|
||||
fragment_C_.clear();
|
||||
|
||||
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
|
||||
{
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx)
|
||||
{
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_)
|
||||
{
|
||||
int thread_offset_row
|
||||
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
|
||||
{
|
||||
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx)
|
||||
{
|
||||
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(
|
||||
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(
|
||||
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
@@ -1,282 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
||||
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
|
||||
ThreadMap>
|
||||
{
|
||||
using WarpTileIterator
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
using SharedLoadIterator
|
||||
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
static int const kFragmentsPerIteration = 2;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator used to load output tile from shared memory in epilogue.
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
|
||||
{
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
using Shape = typename ThreadMap::Shape;
|
||||
|
||||
using Element = int32_t;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = MatrixCoord;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
|
||||
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<Element,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
|
||||
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
|
||||
|
||||
/// Vector type used for SMEM loads
|
||||
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
|
||||
const_min(16, kAlignment)>;
|
||||
|
||||
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Byte-level pointer
|
||||
LoadType const* pointers_[kLoadsPerAccess];
|
||||
|
||||
/// Stride along adjacent rows in units of LoadType
|
||||
int stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
|
||||
: stride_((ref.stride(0) / LoadType::kElements))
|
||||
{
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// Initialize pointers
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
|
||||
|
||||
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
|
||||
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
|
||||
|
||||
col_idx += (bank_offset + i) % kLoadsPerAccess;
|
||||
|
||||
pointers_[i] += thread_offset.row() * stride_ + col_idx;
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] += pointer_offset / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i]
|
||||
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
|
||||
{
|
||||
|
||||
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
|
||||
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
|
||||
+ pointer_offset / LoadType::kElements;
|
||||
|
||||
int frag_row_idx
|
||||
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
|
||||
{
|
||||
|
||||
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kLoadsPerAccess; ++v)
|
||||
{
|
||||
|
||||
int vector_idx
|
||||
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
|
||||
|
||||
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
|
||||
|
||||
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment& frag) const
|
||||
{
|
||||
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,141 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* @file epilogue_helpers.h
|
||||
*
|
||||
* This file includes types for the epilogues. The empty structs exist so we can signal to template
|
||||
* code the type of epilogue we want to run, and let the underlying code specify the details such as
|
||||
* element types, accumulator type and elements per vector access.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_silu.h"
|
||||
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
|
||||
#include <cutlass/epilogue/fusion/operations.hpp>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
|
||||
struct EpilogueOpBiasSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefault
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
|
||||
struct Epilogue
|
||||
{
|
||||
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
|
||||
};
|
||||
|
||||
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,221 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
|
||||
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
|
||||
{
|
||||
// 32 bytes to account for barriers etc.
|
||||
constexpr int stage_barrier_bytes = 32;
|
||||
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
|
||||
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
|
||||
constexpr int stage_bytes = [&]() -> int
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
|
||||
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
|
||||
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
|
||||
}
|
||||
}();
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
|
||||
Activation, SwapAB,
|
||||
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &¬ detail::
|
||||
is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm
|
||||
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|
||||
|| IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
|
||||
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages
|
||||
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
cute::conditional_t<IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
|
||||
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
|
||||
Activation, SwapAB,
|
||||
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(
|
||||
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm
|
||||
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK
|
||||
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|
||||
|| IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages
|
||||
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
|
||||
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,58 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
|
||||
bool SwapAB = false, class Enable = void>
|
||||
struct CollectiveBuilderGated
|
||||
{
|
||||
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,59 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/detail/dependent_false.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
|
||||
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
|
||||
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB = false>
|
||||
struct CollectiveMmaGated
|
||||
{
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,642 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
|
||||
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_,
|
||||
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
|
||||
{
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
|
||||
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
using InternalElementAux = cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments
|
||||
{
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params
|
||||
{
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
|
||||
{
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes
|
||||
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
|
||||
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
|
||||
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
|
||||
/ 8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
|
||||
{
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
|
||||
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate)
|
||||
{
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id
|
||||
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n)
|
||||
{
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m)
|
||||
{
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
|
||||
tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
|
||||
tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
|
||||
tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate)
|
||||
{
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
|
||||
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.partition_A(sAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
|
||||
{
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,665 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
|
||||
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
|
||||
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
|
||||
{
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
|
||||
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments
|
||||
{
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params
|
||||
{
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
|
||||
{
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA
|
||||
* instructions. */
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes
|
||||
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
|
||||
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
|
||||
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
|
||||
/ 8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
|
||||
{
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
|
||||
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate)
|
||||
{
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id
|
||||
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n)
|
||||
{
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m)
|
||||
{
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
|
||||
tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
|
||||
tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
|
||||
tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate)
|
||||
{
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
|
||||
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.partition_A(sAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation0.prepare_if_needed())
|
||||
{
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
if (accumulation0.prepare_if_needed())
|
||||
{
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation0.promote_residue_if_needed();
|
||||
accumulation1.promote_residue_if_needed();
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
|
||||
{
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,438 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
|
||||
batched array variants.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// #include <limits>
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace device
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
|
||||
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
|
||||
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
|
||||
|
||||
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
|
||||
that feature at the moment.
|
||||
*/
|
||||
|
||||
template <typename GemmKernel_>
|
||||
class GemmUniversalBaseCompat
|
||||
{
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
|
||||
using ElementA = typename GemmKernel::ElementA;
|
||||
using LayoutA = typename GemmKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
|
||||
|
||||
using ElementB = typename GemmKernel::ElementB;
|
||||
using LayoutB = typename GemmKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename GemmKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
||||
using Operator = typename GemmKernel::Operator;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
protected:
|
||||
/// Private helper to obtain the grid dimensions with fix-up for split-K
|
||||
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
|
||||
{
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
gemm_k_size = args.problem_size.k();
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
int const kAlignK
|
||||
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size)
|
||||
{
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalBaseCompat() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
|
||||
|
||||
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
|
||||
{
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
// Split-K parallel always requires a temporary workspace
|
||||
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
|
||||
}
|
||||
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Serial split-K only requires a temporary workspace if the number of partitions along the
|
||||
// GEMM K dimension is greater than one.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10))
|
||||
{
|
||||
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result == cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0)
|
||||
{
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes)
|
||||
{
|
||||
|
||||
if (!workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUDA grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t result
|
||||
= cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_.update(args, workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr)
|
||||
{
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess)
|
||||
{
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,542 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace device
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
|
||||
int64_t* splitk_buffer_offsets)
|
||||
{
|
||||
// in_tensor: [problem_idx, k_partition, hidden_size]
|
||||
// Note that different requests of in_tensor might have different hidden_size (=m*n)
|
||||
// so, we need to use splitk_buffer_offsets.
|
||||
// out_tensor: problem_idx * [hidden_size]
|
||||
|
||||
int const problem_idx = blockIdx.y;
|
||||
GemmCoord problem = problem_sizes[problem_idx];
|
||||
int const hidden_size = problem.m() * problem.n();
|
||||
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
|
||||
T_OUT* out_tensor_ = out_tensor[problem_idx];
|
||||
|
||||
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
float sum = 0.0f;
|
||||
for (int k_idx = 0; k_idx < splitk; k_idx++)
|
||||
{
|
||||
sum += (float) in_tensor_[k_idx * hidden_size + i];
|
||||
}
|
||||
out_tensor_[i] = (T_OUT) (sum);
|
||||
}
|
||||
}
|
||||
|
||||
/// GEMM Grouped
|
||||
template <typename BaseKernel_>
|
||||
class BaseSplitkGrouped
|
||||
{
|
||||
public:
|
||||
using BaseKernel = BaseKernel_;
|
||||
|
||||
using ElementA = typename BaseKernel::ElementA;
|
||||
using LayoutA = typename BaseKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
||||
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
||||
|
||||
using ElementB = typename BaseKernel::ElementB;
|
||||
using LayoutB = typename BaseKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
||||
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
||||
|
||||
using ElementC = typename BaseKernel::ElementC;
|
||||
using LayoutC = typename BaseKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
||||
|
||||
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
|
||||
|
||||
using Operator = typename BaseKernel::Operator;
|
||||
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
||||
using WarpShape = typename BaseKernel::WarpShape;
|
||||
using InstructionShape = typename BaseKernel::InstructionShape;
|
||||
static int const kStages = BaseKernel::Mma::kStages;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename BaseKernel::Arguments;
|
||||
|
||||
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename BaseKernel::Params gemm_params_;
|
||||
|
||||
private:
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
|
||||
{
|
||||
int32_t tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
|
||||
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
|
||||
tiles += problem_tile_count(problem);
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
/// Copy from `data` to `workspace`
|
||||
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
|
||||
{
|
||||
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
|
||||
if (cuda_error != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
cuda_error = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Precomputes scheduling information for the grouped GEMM
|
||||
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
|
||||
{
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
std::vector<uint8_t> host_workspace(workspace_bytes);
|
||||
BaseKernel::ProblemVisitor::host_precompute(
|
||||
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
|
||||
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
|
||||
}
|
||||
|
||||
/// Reorder `data` according to `indices`
|
||||
template <typename T>
|
||||
static void reorder_array(T* data, std::vector<size_t> const& indices)
|
||||
{
|
||||
// For now, simply create a copy of the data and then copy over to the original.
|
||||
std::vector<T> copy(indices.size());
|
||||
for (size_t i = 0; i < indices.size(); ++i)
|
||||
{
|
||||
copy.at(i) = data[indices[i]];
|
||||
}
|
||||
|
||||
memcpy(data, copy.data(), indices.size() * sizeof(T));
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
BaseSplitkGrouped() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
|
||||
return BaseKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Get the number of tiles in a problem
|
||||
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
|
||||
{
|
||||
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
|
||||
return BaseKernel::ProblemVisitor::tile_count(grid);
|
||||
}
|
||||
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(Arguments const& args)
|
||||
{
|
||||
if (args.host_problem_sizes == nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return group_tile_count(args.host_problem_sizes, args.problem_count);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t total_mn = 0;
|
||||
for (int i = 0; i < args.problem_count; i++)
|
||||
{
|
||||
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
|
||||
}
|
||||
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
|
||||
args.host_problem_sizes, args.problem_count, args.threadblock_count);
|
||||
}
|
||||
return workSpaceSize;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args)
|
||||
{
|
||||
|
||||
return dim3(args.threadblock_count, 1, 1);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
|
||||
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Sorts each pointer passed in according to the indices that sort
|
||||
/// `problem_sizes_ptr` in descending order of problem-K dimension.
|
||||
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
|
||||
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
|
||||
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
|
||||
{
|
||||
std::vector<size_t> indices(problem_count);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::stable_sort(indices.begin(), indices.end(),
|
||||
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
|
||||
|
||||
reorder_array(problem_sizes_ptr, indices);
|
||||
reorder_array(lda_host_ptr, indices);
|
||||
reorder_array(ldb_host_ptr, indices);
|
||||
reorder_array(ldc_host_ptr, indices);
|
||||
reorder_array(ldd_host_ptr, indices);
|
||||
reorder_array(offset_A_ptr, indices);
|
||||
reorder_array(offset_B_ptr, indices);
|
||||
reorder_array(offset_C_ptr, indices);
|
||||
reorder_array(offset_D_ptr, indices);
|
||||
}
|
||||
|
||||
/// Computes the number of threadblocks to launch for the grouped kernel
|
||||
static int sufficient(
|
||||
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
|
||||
{
|
||||
// Determine the number of blocks that would be launched to fill up a single
|
||||
// wave on the GPU with each SM having maximum occupancy.
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int multiprocessor_count;
|
||||
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
|
||||
if (override_sm_count)
|
||||
{
|
||||
available_sm_count = multiprocessor_count;
|
||||
}
|
||||
|
||||
int max_active_blocks = maximum_active_blocks();
|
||||
if (max_active_blocks <= 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int occupancy_based_block_count = available_sm_count * max_active_blocks;
|
||||
|
||||
if (problem_sizes_ptr == nullptr || problem_count == 0)
|
||||
{
|
||||
return occupancy_based_block_count;
|
||||
}
|
||||
|
||||
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
|
||||
|
||||
// If the group contains a single problem, launching the exact number of
|
||||
// threadblocks needed to cover the problem minimizes the work performed
|
||||
// per threadblock in finding the next tile to compute. We return total_tiles
|
||||
// unless the user has provided the SM count.
|
||||
if (problem_count == 1 && override_sm_count)
|
||||
{
|
||||
return total_tiles;
|
||||
}
|
||||
|
||||
// Choose between the full wave of threadblocks and the tile count. If there
|
||||
// are fewer tiles in the group than threadblocks in the full wave, only
|
||||
// some threadblocks will be assigned tiles. Those threadblocks
|
||||
// which are not assigned tiles still need to perform the work of iterating through
|
||||
// problem sizes to determine that they have no work to do. This competes for cycles
|
||||
// with those threadblocks that are assigned tiles to compute.
|
||||
return std::min(total_tiles, occupancy_based_block_count);
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Workspace
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_params_ = typename BaseKernel::Params(args, workspace);
|
||||
}
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t result
|
||||
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr)
|
||||
{
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
gemm_params_.update(args, workspace, tile_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_params_.update(args, workspace);
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr)
|
||||
{
|
||||
if (!gemm_params_.problem_visitor.problem_count)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
// Launch splitk grouped gemm
|
||||
{
|
||||
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
|
||||
dim3 block(BaseKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
// Launch splitkReduction
|
||||
{
|
||||
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
|
||||
dim3 block(256);
|
||||
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
|
||||
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
|
||||
gemm_params_.splitk_buffer_offsets);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr)
|
||||
{
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Initializes and runs the kernel.
|
||||
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess)
|
||||
{
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM Grouped
|
||||
template <typename GemmKernel_>
|
||||
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
|
||||
{
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,162 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
|
||||
struct MixedGemmArchTraits
|
||||
{
|
||||
static_assert(dependent_false<arch>, "Unrecognised parameterization");
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct MixedGemmArchTraits<float, float, Arch>
|
||||
{
|
||||
static constexpr int Stages = 2;
|
||||
using OperatorClass = cutlass::arch::OpClassSimt;
|
||||
using AccType = float;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 1;
|
||||
static constexpr int ElementsPerAccessB = 1;
|
||||
static constexpr int ElementsPerAccessC = 1;
|
||||
static constexpr int ThreadblockK = 8;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
// ======================= Turing Traits ==============================
|
||||
// Note that turing does not have native bfloat support so weights and activations will be casted to fp16
|
||||
// and compute will happen in fp16 then will be converted for bf16 output.
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// ======================= Ada Traits ==============================
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// FP8 A/B = fp8, C/D = fp32
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
// be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
|
||||
using TypeC = __nv_bfloat16;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,57 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename arch>
|
||||
struct Int8GemmArchTraits
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassSimt;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
};
|
||||
|
||||
// ======================= Turing Traits ==============================
|
||||
template <>
|
||||
struct Int8GemmArchTraits<cutlass::arch::Sm75>
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
|
||||
template <>
|
||||
struct Int8GemmArchTraits<cutlass::arch::Sm80>
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,207 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
#include "splitk_gemm_grouped.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator>::Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout = layout::NoPermute,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DefaultSplitkGemmGrouped;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Real-valued GEMM kernels
|
||||
//
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout>
|
||||
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
|
||||
ComplexTransform::kNone, // transform A
|
||||
kAlignmentA, ElementB, LayoutB,
|
||||
ComplexTransform::kNone, // transform B
|
||||
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
|
||||
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
|
||||
{
|
||||
|
||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
|
||||
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
|
||||
|
||||
// Define the default GEMM kernel
|
||||
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
|
||||
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
|
||||
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
|
||||
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
|
||||
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
|
||||
false, /*GatherB*/
|
||||
false, /*ScatterD*/
|
||||
PermuteDLayout>::GemmKernel;
|
||||
|
||||
/// Define the kernel in terms of the default kernel
|
||||
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
|
||||
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,566 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
template <typename>
|
||||
inline constexpr bool dependent_false_v = false;
|
||||
}
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
>
|
||||
struct GemmFpAIntB
|
||||
{
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Element;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Mma::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformA;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
|
||||
/// Parameters structure
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode = GemmUniversalMode::kGemm;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int group_size;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Mma::IteratorScale::TensorRef ref_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_zero;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
|
||||
// Control serial split-k
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
// For gather+scatter operations
|
||||
int const* gather_A_indices;
|
||||
int const* gather_B_indices;
|
||||
int const* scatter_D_indices;
|
||||
|
||||
// Included so we can use Gemm Universal
|
||||
int batch_stride_D = 0;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
|
||||
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
|
||||
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
|
||||
int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
|
||||
int const* scatter_D_indices = nullptr)
|
||||
: problem_size(problem_size)
|
||||
, group_size(group_size)
|
||||
, ref_A(ref_A)
|
||||
, ref_B(ref_B)
|
||||
, ref_scale(ref_scale)
|
||||
, ref_zero(ref_zero)
|
||||
, ref_C(ref_C)
|
||||
, ref_D(ref_D)
|
||||
, batch_count(serial_split_k_factor)
|
||||
, output_op(output_op)
|
||||
, gather_A_indices(gather_A_indices)
|
||||
, gather_B_indices(gather_B_indices)
|
||||
, scatter_D_indices(scatter_D_indices)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int group_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Mma::IteratorScale::Params params_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_zero;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int* semaphore;
|
||||
int gemm_k_size;
|
||||
// For gather+scatter operations
|
||||
int const* gather_A_indices;
|
||||
int const* gather_B_indices;
|
||||
int const* scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0)
|
||||
, semaphore(0)
|
||||
, gemm_k_size(0)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
|
||||
void* workspace = nullptr)
|
||||
: problem_size(args.problem_size)
|
||||
, group_size(args.group_size)
|
||||
, grid_tiled_shape(grid_tiled_shape)
|
||||
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
|
||||
, params_A(args.ref_A.layout())
|
||||
, ref_A(args.ref_A)
|
||||
, params_B(args.ref_B.layout())
|
||||
, ref_B(args.ref_B)
|
||||
, params_scale(args.ref_scale.layout())
|
||||
, ref_scale(args.ref_scale)
|
||||
, ref_zero(args.ref_zero)
|
||||
, params_C(args.ref_C.layout())
|
||||
, ref_C(args.ref_C)
|
||||
, params_D(args.ref_D.layout())
|
||||
, ref_D(args.ref_D)
|
||||
, output_op(args.output_op)
|
||||
, semaphore(static_cast<int*>(workspace))
|
||||
, gemm_k_size(gemm_k_size)
|
||||
, gather_A_indices(args.gather_A_indices)
|
||||
, gather_B_indices(args.gather_B_indices)
|
||||
, scatter_D_indices(args.scatter_D_indices)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmFpAIntB() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
static int const kAlignmentA
|
||||
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB
|
||||
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!args.ref_scale.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
if constexpr (hasZero(Mma::QuantOp))
|
||||
{
|
||||
if (!args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (isFinegrained(Mma::QuantOp))
|
||||
{
|
||||
if (args.group_size != 64 && args.group_size != 128)
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
|
||||
// has a different constructor signature than a regular cutlass iterator
|
||||
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
|
||||
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
|
||||
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
|
||||
typename IteratorScale::TensorCoord extent, int thread_id,
|
||||
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
|
||||
{
|
||||
|
||||
return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
|
||||
}
|
||||
|
||||
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
|
||||
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
|
||||
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
|
||||
typename IteratorScale::TensorCoord extent, int thread_id,
|
||||
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
|
||||
{
|
||||
|
||||
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
||||
|
||||
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
|
||||
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
|
||||
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
|
||||
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
|
||||
params.gather_B_indices);
|
||||
|
||||
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
|
||||
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
|
||||
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
|
||||
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0)
|
||||
{
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k())
|
||||
{
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
|
||||
{
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ == 890)
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,218 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
|
||||
#include <cutlass/trace.h>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
|
||||
int TileK_, int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_sm80
|
||||
{
|
||||
static constexpr int kMaxTileM = MaxTileM_;
|
||||
static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
|
||||
static constexpr int kTileK = TileK_;
|
||||
static constexpr int kStages = Stages_;
|
||||
static constexpr Activation_Type activation_type = activation_type_;
|
||||
|
||||
using ElementInput = ElementInput_;
|
||||
using ElementWeight = ElementWeight_;
|
||||
using ElementOutput = ElementOutput_;
|
||||
using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM, kTileN,
|
||||
kTileK, kStages, activation_type>;
|
||||
using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
|
||||
using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
|
||||
using ProblemVisitor
|
||||
= cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
|
||||
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
|
||||
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
|
||||
BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
Routine_Arguments routine_args;
|
||||
int problem_count{};
|
||||
int threadblock_count{};
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
Routine_Params routine_params;
|
||||
int threadblock_count{};
|
||||
typename ProblemVisitor::Params problem_visitor_param;
|
||||
};
|
||||
|
||||
using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
|
||||
kTileK, kStages, activation_type>;
|
||||
static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64
|
||||
|
||||
static constexpr int kSmemSize = use_m16
|
||||
? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
|
||||
: BaseKernelTraits_m16::kSmemSize)
|
||||
: BaseKernelTraits::kSmemSize;
|
||||
static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;
|
||||
|
||||
static constexpr bool can_implement(int const avaliable_smem_size)
|
||||
{
|
||||
return BaseKernelTraits::can_implement(avaliable_smem_size);
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args)
|
||||
{
|
||||
return {
|
||||
{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
|
||||
args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n,
|
||||
args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast},
|
||||
args.threadblock_count,
|
||||
{args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
|
||||
args.problem_count, nullptr, 0}};
|
||||
}
|
||||
|
||||
CUTE_DEVICE
|
||||
void run_device(Params const& params)
|
||||
{
|
||||
#define ROUTINE_PATH(kTileM_size) \
|
||||
{ \
|
||||
constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \
|
||||
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM, \
|
||||
kTileN, kTileK, kStages, activation_type>; \
|
||||
RoutineTraits routine{}; \
|
||||
int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \
|
||||
routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \
|
||||
}
|
||||
typename ProblemVisitor::SharedStorage dummy_storage{};
|
||||
ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
auto problem_size = problem_visitor.problem_size();
|
||||
auto grid_size = problem_visitor.grid_shape(problem_size);
|
||||
auto problem_index = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
int const gemm_m = problem_size.m();
|
||||
const int32_t block_m_idx_temp = cta_idx / grid_size.n();
|
||||
const int32_t block_n_idx = cta_idx % grid_size.n();
|
||||
|
||||
int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
|
||||
if (residue_m > kMaxTileM / 2)
|
||||
{
|
||||
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
|
||||
kMaxTileM, kTileN, kTileK, kStages, activation_type>;
|
||||
RoutineTraits routine{};
|
||||
routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr (kMaxTileM >= 128)
|
||||
{
|
||||
if (residue_m > 32)
|
||||
{
|
||||
ROUTINE_PATH(64);
|
||||
}
|
||||
else if (residue_m > 16)
|
||||
{
|
||||
ROUTINE_PATH(32);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
else if (kMaxTileM == 64)
|
||||
{
|
||||
if (residue_m > 16)
|
||||
{
|
||||
ROUTINE_PATH(32);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
else if (kMaxTileM == 32)
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
#undef ROUTINE_PATH
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmType>
|
||||
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
|
||||
{
|
||||
GemmType gemm;
|
||||
gemm.run_device(params);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
template <typename GemmType>
|
||||
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
|
||||
|
||||
constexpr int smem_size = GemmType::kSmemSize;
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
} // namespace fused_moe
|
||||
@@ -1,799 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_, typename Enable = void>
|
||||
struct Fused_Moe_Kernel_routine_sm80;
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
|
||||
activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
|
||||
{
|
||||
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
|
||||
Stages_, activation_type_>;
|
||||
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
|
||||
|
||||
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
|
||||
{
|
||||
using X = cute::Underscore;
|
||||
|
||||
int const M = gemm_m;
|
||||
int const N1 = params.gemm_n;
|
||||
int const K1 = params.gemm_k;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
|
||||
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
|
||||
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
|
||||
typename KT::ElementWeight const* ptr_fc1_gate_
|
||||
= params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation..
|
||||
typename KT::ElementWeight const* ptr_fc1_
|
||||
= params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation..
|
||||
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
|
||||
typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
|
||||
: params.ptr_bias + (2 * row_jump + 1) * N1);
|
||||
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
|
||||
|
||||
cute::Tensor mInput_mk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
|
||||
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_gate_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mBias_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mBias_gate_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mOutput_mn
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
|
||||
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
|
||||
|
||||
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
|
||||
}
|
||||
|
||||
// be careful, m_idx will change when use another tile shape..
|
||||
CUTE_DEVICE void run_routine(
|
||||
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
|
||||
{
|
||||
extern __shared__ char smem_[];
|
||||
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
|
||||
int const thread_idx = threadIdx.x;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
// gmem tensor partition ..
|
||||
auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
|
||||
= gmem_tensor_init(problem_index, gemm_m, params);
|
||||
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
|
||||
auto const n_tile_count = cute::size<2>(gfc1_gate_nk);
|
||||
|
||||
// smem tensor ..
|
||||
cute::Tensor sInput = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sfc1_gate_weight
|
||||
= cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sO = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
|
||||
|
||||
// (1) first step, get the fc1_res and fc1_gate
|
||||
|
||||
// (1.1) get partition for gmem -> smem
|
||||
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
|
||||
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
|
||||
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
|
||||
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
|
||||
cute::Tensor tInputpInput
|
||||
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
|
||||
cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sInput
|
||||
cute::Tensor cInput = make_identity_tensor(
|
||||
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
|
||||
{
|
||||
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
|
||||
// (1.2) prefetch gmem -> smem
|
||||
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
|
||||
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0
|
||||
int k_tile_count = cute::size<2>(gInput);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
|
||||
{
|
||||
if (k_tile_count <= 0)
|
||||
{
|
||||
cute::clear(tInputpInput);
|
||||
}
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
// use copy_if
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::cp_async_fence();
|
||||
k_tile_count--;
|
||||
if (k_tile_count > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// (1.3) get partition for rf
|
||||
typename KT::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
|
||||
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
cute::Tensor accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::Tensor accum_gate
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::clear(accum);
|
||||
cute::clear(accum_gate);
|
||||
// checkout the shape
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
|
||||
|
||||
// (1.4)retiling the smem and rf for copy..
|
||||
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
|
||||
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
|
||||
cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K
|
||||
|
||||
// (1.5) mainloop
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = KT::Stages - 1;
|
||||
|
||||
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
|
||||
// prefetch register pipeline
|
||||
if constexpr (K_BLOCK_MAX > 1)
|
||||
{
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
}
|
||||
// k loop for mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
|
||||
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::cp_async_fence();
|
||||
if (k_tile_count - 1 > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
|
||||
accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
}
|
||||
|
||||
// load tail
|
||||
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
|
||||
[&](auto WaitIndex)
|
||||
{
|
||||
k_tile_count--;
|
||||
using WaitIndex_t = decltype(WaitIndex);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
// only update smem_pipe_read
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
});
|
||||
// mma tail
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(
|
||||
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(accum_gate(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (2) add bias if it has..
|
||||
if (params.ptr_bias != nullptr)
|
||||
{
|
||||
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
|
||||
cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate);
|
||||
for (int i = 0; i < cute::size(accum); i++)
|
||||
{
|
||||
accum(i) += tOgBias(i);
|
||||
accum_gate(i) += tOgBiasg(i);
|
||||
}
|
||||
}
|
||||
|
||||
// (3) calculate swiglu
|
||||
using ActivationFn = typename KT::ActivationFn;
|
||||
ActivationFn fn{};
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
|
||||
{
|
||||
accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter);
|
||||
}
|
||||
|
||||
// (4) push all the result to smem
|
||||
// (4.1) convert result from ElementAccum to ElementInput
|
||||
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(temp_accum(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (4.2) retile rf and smem for copy back..
|
||||
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// cute::clear(sO);
|
||||
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
|
||||
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
|
||||
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.4) sO -> rO -> gO
|
||||
|
||||
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
|
||||
// remember, for all the threads in the same col, they have the same idx for bias..
|
||||
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
|
||||
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
|
||||
auto tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
auto tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
|
||||
cute::Tensor cOutput = cute::make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
|
||||
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
|
||||
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(tOgO); ++m)
|
||||
{
|
||||
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
|
||||
{
|
||||
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
|
||||
activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
|
||||
{
|
||||
|
||||
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
|
||||
Stages_, activation_type_>;
|
||||
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
|
||||
|
||||
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
|
||||
{
|
||||
using X = cute::Underscore;
|
||||
|
||||
int const M = gemm_m;
|
||||
int const N1 = params.gemm_n;
|
||||
int const K1 = params.gemm_k;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
|
||||
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
|
||||
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
|
||||
typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
|
||||
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
|
||||
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
|
||||
|
||||
cute::Tensor mInput_mk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
|
||||
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mBias_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mOutput_mn
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
|
||||
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
|
||||
|
||||
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
|
||||
}
|
||||
|
||||
// be careful, m_idx will change when use another tile shape..
|
||||
CUTE_DEVICE void run_routine(
|
||||
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
|
||||
{
|
||||
extern __shared__ char smem_[];
|
||||
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
|
||||
int const thread_idx = threadIdx.x;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
// gmem tensor partition ..
|
||||
auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
|
||||
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
|
||||
auto const n_tile_count = cute::size<2>(gfc1_nk);
|
||||
|
||||
// smem tensor ..
|
||||
cute::Tensor sInput = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sO = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
|
||||
|
||||
// (1) first step, get the fc1_res and fc1_gate
|
||||
|
||||
// (1.1) get partition for gmem -> smem
|
||||
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
|
||||
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
|
||||
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
|
||||
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
|
||||
cute::Tensor tInputpInput
|
||||
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
|
||||
cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sInput
|
||||
cute::Tensor cInput = make_identity_tensor(
|
||||
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
|
||||
{
|
||||
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
|
||||
// (1.2) prefetch gmem -> smem
|
||||
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
|
||||
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0
|
||||
int k_tile_count = cute::size<2>(gInput);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
|
||||
{
|
||||
if (k_tile_count <= 0)
|
||||
{
|
||||
cute::clear(tInputpInput);
|
||||
}
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
// use copy_if
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::cp_async_fence();
|
||||
k_tile_count--;
|
||||
if (k_tile_count > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// (1.3) get partition for rf
|
||||
typename KT::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
|
||||
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
cute::Tensor accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::clear(accum);
|
||||
// checkout the shape
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
|
||||
|
||||
// (1.4)retiling the smem and rf for copy..
|
||||
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
|
||||
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
|
||||
|
||||
// (1.5) mainloop
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = KT::Stages - 1;
|
||||
|
||||
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
|
||||
// prefetch register pipeline
|
||||
if constexpr (K_BLOCK_MAX > 1)
|
||||
{
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
}
|
||||
// k loop for mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
|
||||
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::cp_async_fence();
|
||||
if (k_tile_count - 1 > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
|
||||
accum);
|
||||
});
|
||||
}
|
||||
// load tail
|
||||
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
|
||||
[&](auto WaitIndex)
|
||||
{
|
||||
k_tile_count--;
|
||||
using WaitIndex_t = decltype(WaitIndex);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
// only update smem_pipe_read
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
});
|
||||
});
|
||||
// mma tail
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(
|
||||
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
});
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(accum_gate(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (2) add bias if it has..
|
||||
if (params.ptr_bias != nullptr)
|
||||
{
|
||||
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
|
||||
for (int i = 0; i < cute::size(accum); i++)
|
||||
{
|
||||
accum(i) += tOgBias(i);
|
||||
}
|
||||
}
|
||||
// (3) calculate swiglu
|
||||
using ActivationFn = typename KT::ActivationFn;
|
||||
ActivationFn fn{};
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
|
||||
{
|
||||
accum(temp_iter) = fn(accum(temp_iter));
|
||||
}
|
||||
|
||||
// (4) push all the result to smem
|
||||
// (4.1) convert result from ElementAccum to ElementInput
|
||||
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(temp_accum(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (4.2) retile rf and smem for copy back..
|
||||
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// cute::clear(sO);
|
||||
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
|
||||
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
|
||||
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.4) sO -> rO -> gO
|
||||
|
||||
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
|
||||
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
|
||||
auto tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
auto tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
cute::Tensor cOutput = cute::make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
|
||||
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
|
||||
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(tOgO); ++m)
|
||||
{
|
||||
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
|
||||
{
|
||||
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace fused_moe
|
||||
@@ -1,215 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass_extensions/epilogue_helpers.h>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
|
||||
struct Routine_Arguments
|
||||
{
|
||||
ElementInput* ptr_input{};
|
||||
ElementWeight* ptr_fc1{};
|
||||
ElementInput* ptr_bias{};
|
||||
ElementOutput* ptr_output{};
|
||||
int64_t const* total_tokens_including_expert{};
|
||||
int gemm_n{};
|
||||
int gemm_k{};
|
||||
int num_expert{};
|
||||
bool bias_is_broadcast{};
|
||||
};
|
||||
|
||||
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
|
||||
struct Routine_Params
|
||||
{
|
||||
ElementInput* ptr_input{};
|
||||
ElementWeight* ptr_fc1{};
|
||||
ElementInput* ptr_bias{};
|
||||
ElementOutput* ptr_output{};
|
||||
int64_t const* total_tokens_including_expert{};
|
||||
int gemm_n{};
|
||||
int gemm_k{};
|
||||
int num_expert{};
|
||||
bool bias_is_broadcast{};
|
||||
};
|
||||
|
||||
enum class Activation_Type
|
||||
{
|
||||
Gelu = 0,
|
||||
Relu,
|
||||
Silu,
|
||||
Swiglu,
|
||||
Geglu,
|
||||
Identity,
|
||||
InvalidType
|
||||
};
|
||||
|
||||
constexpr bool isGateActivation(Activation_Type const& activation_type)
|
||||
{
|
||||
return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
|
||||
}
|
||||
|
||||
template <typename CutlassExtensionEpilogueTag>
|
||||
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::InvalidType;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::Identity;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::Relu;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
|
||||
{
|
||||
return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
|
||||
{
|
||||
return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
|
||||
}
|
||||
|
||||
/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type>
|
||||
struct Fused_Moe_Kernel_traits_sm80
|
||||
{
|
||||
using ElementInput = ElementInput_;
|
||||
using ElementWeight = ElementWeight_;
|
||||
using ElementAccum = float;
|
||||
using ElementOutput = ElementOutput_;
|
||||
|
||||
using index_t = uint32_t;
|
||||
static_assert(TileM_ % 16 == 0);
|
||||
static_assert(TileN_ % 32 == 0);
|
||||
static_assert(TileK_ % 32 == 0);
|
||||
static constexpr int Stages = Stages_;
|
||||
static constexpr int kTileM = TileM_;
|
||||
static constexpr int kTileN = TileN_;
|
||||
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
|
||||
|
||||
// tile shape
|
||||
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
|
||||
static constexpr int kWarpsCount = 4;
|
||||
static constexpr int kThreadCount = kWarpsCount * 32;
|
||||
|
||||
// MMA atom arch and layout
|
||||
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
|
||||
cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
|
||||
// using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
|
||||
using ThreadLayoutMNK
|
||||
= std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
|
||||
cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
|
||||
using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
|
||||
cute::Tile<cute::_32, cute::_32, cute::_16>>;
|
||||
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK,
|
||||
ValLayoutMNK>; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4
|
||||
static constexpr int kAlignment = 8;
|
||||
static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;
|
||||
// A memory copy operand
|
||||
using DefaultOperandA
|
||||
= DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
|
||||
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
|
||||
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
||||
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
||||
|
||||
// B memory copy operand
|
||||
using DefaultOperandB
|
||||
= DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
|
||||
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
|
||||
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
||||
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
||||
|
||||
// Output memory copy operand
|
||||
using SmemLayoutAtomO = SmemLayoutAtomA;
|
||||
using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
|
||||
static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
|
||||
static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
|
||||
using GmemLayoutAtomO
|
||||
= cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
|
||||
cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
|
||||
using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
|
||||
GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
|
||||
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
|
||||
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
|
||||
|
||||
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
|
||||
cute::make_shape(
|
||||
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
|
||||
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
|
||||
cute::make_shape(
|
||||
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
|
||||
using SmemLayoutO = decltype(cute::tile_to_shape(
|
||||
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
|
||||
|
||||
// we need at least 2 stages..
|
||||
static_assert(Stages >= 2);
|
||||
|
||||
struct SharedStorageNormal : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
|
||||
struct SharedStorageGate : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
|
||||
using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;
|
||||
|
||||
using ActivationFn = std::conditional_t<activation_type == Activation_Type::Gelu
|
||||
|| activation_type == Activation_Type::Geglu,
|
||||
cutlass::epilogue::thread::GELU<float>,
|
||||
std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
|
||||
std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
|
||||
cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;
|
||||
|
||||
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
|
||||
|
||||
static constexpr bool can_implement(int const avaliable_smem_size)
|
||||
{
|
||||
return avaliable_smem_size > kSmemSize;
|
||||
}
|
||||
|
||||
// #endif
|
||||
};
|
||||
} // namespace fused_moe
|
||||
@@ -1,73 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped GEMM
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
|
||||
bool Transposed = false>
|
||||
struct GemmMoeProblemVisitor
|
||||
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
|
||||
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
|
||||
{
|
||||
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
||||
using Base
|
||||
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
||||
using Params = typename Base::Params;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, shared_storage_, block_idx)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,70 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
* Stateless universal device GEMM kernel type that treats GEMM as
|
||||
* a composition of a collective mainloop and a collective epilogue.
|
||||
*
|
||||
* Supports both the 2.x and 3.x APIs based on whether the first type is
|
||||
* a cute::tuple<> or not.
|
||||
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
|
||||
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
|
||||
*
|
||||
* In the following declaration, the name preceding the 'Or' refers to
|
||||
* 3.x API type argument order, and the name succeeding the 'Or' refers to
|
||||
* 2.x API type argument order. Template arguments without two names
|
||||
* belong to the 3.x API only.
|
||||
**/
|
||||
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
|
||||
class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
|
||||
class Enable = void>
|
||||
class GemmUniversalGated;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
|
||||
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,585 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief GEMM kernel to support the epilogue visitor model
|
||||
for customized softmax partial reduction epilogue fusion.
|
||||
|
||||
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
|
||||
its usage has been stabilized. For now, it is included in this example to demonstrate
|
||||
some basic output fusion options.
|
||||
|
||||
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment
|
||||
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
tk::QuantMode quant_option;
|
||||
TensorRefAlphaCol ref_alpha_col;
|
||||
TensorRefAlphaRow ref_alpha_row;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments()
|
||||
: mode(GemmUniversalMode::kGemm)
|
||||
, batch_count(1)
|
||||
{
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
|
||||
TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
|
||||
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
|
||||
int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
|
||||
: mode(mode_)
|
||||
, problem_size(problem_size_)
|
||||
, batch_count(batch_count_)
|
||||
, ref_A(ref_A_)
|
||||
, ref_B(ref_B_)
|
||||
, quant_option(quant_option_)
|
||||
, ref_alpha_col(ref_alpha_col_)
|
||||
, ref_alpha_row(ref_alpha_row_)
|
||||
, ref_C(ref_C_)
|
||||
, ref_D(ref_D_)
|
||||
, batch_stride_A(batch_stride_A_)
|
||||
, batch_stride_B(batch_stride_B_)
|
||||
, batch_stride_D(0)
|
||||
, epilogue_visitor(epilogue_visitor_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void* ptr_A;
|
||||
void* ptr_B;
|
||||
tk::QuantMode quant_option;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0)
|
||||
, params_A(0)
|
||||
, params_B(0)
|
||||
, params_alpha_col(0)
|
||||
, params_C(0)
|
||||
, params_D(0)
|
||||
, batch_count(0)
|
||||
, gemm_k_size(0)
|
||||
, mode(cutlass::gemm::GemmUniversalMode::kGemm)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_alpha_col(nullptr)
|
||||
, ptr_alpha_row(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, batch_stride_A(0)
|
||||
, batch_stride_B(0)
|
||||
{
|
||||
}
|
||||
|
||||
Params(
|
||||
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
|
||||
: problem_size(args.problem_size)
|
||||
, swizzle_log_tile(0)
|
||||
, params_A(args.ref_A.layout())
|
||||
, params_B(args.ref_B.layout())
|
||||
, params_alpha_col(args.ref_alpha_col.layout())
|
||||
, params_alpha_row(args.ref_alpha_col.layout())
|
||||
, params_C(args.ref_C.layout())
|
||||
, params_D(args.ref_D.layout())
|
||||
, mode(args.mode)
|
||||
, batch_count(args.batch_count)
|
||||
, gemm_k_size(args.problem_size.k())
|
||||
, ptr_A(args.ref_A.data())
|
||||
, ptr_B(args.ref_B.data())
|
||||
, quant_option(args.quant_option)
|
||||
, ptr_alpha_col(args.ref_alpha_col.data())
|
||||
, ptr_alpha_row(args.ref_alpha_row.data())
|
||||
, ptr_C(args.ref_C.data())
|
||||
, ptr_D(args.ref_D.data())
|
||||
, batch_stride_A(args.batch_stride_A)
|
||||
, batch_stride_B(args.batch_stride_B)
|
||||
, epilogue_visitor(args.epilogue_visitor)
|
||||
{
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
int const kAlignK
|
||||
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size)
|
||||
{
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
|
||||
typename Mma::SharedStorage main_loop;
|
||||
|
||||
struct
|
||||
{
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmWithEpilogueVisitor() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
|
||||
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
|
||||
{
|
||||
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched)
|
||||
{
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kArray)
|
||||
{
|
||||
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
#endif
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Construct the epilogue visitor
|
||||
//
|
||||
|
||||
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
|
||||
params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
|
||||
params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
|
||||
params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm)
|
||||
{
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
|
||||
{
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
|
||||
run_kernel<arch::Sm72>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,143 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
|
||||
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
|
||||
to be consumed by CUTLASS.
|
||||
|
||||
Note that for int4, ThreadBlockK MUST be 64.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
|
||||
struct LayoutDetailsB
|
||||
{
|
||||
};
|
||||
|
||||
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
|
||||
// TODO - Switch this to column major for weights since gemms should be more performant.
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA>
|
||||
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
// for fast accumulation
|
||||
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
|
||||
};
|
||||
|
||||
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
|
||||
// which signals that we want to dequantize after loading from smem.
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint8_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint4b_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,185 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA;
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB;
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
/// Operand A - Column-major (M-major)
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
|
||||
|
||||
// Operand B - Column-Major (K-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
// Operand B - Row-Major (N-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
//
|
||||
// F16: 128-by-128-by-32 (small k-block)
|
||||
//
|
||||
|
||||
/// Operand A - Row-major (K-Major)
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
|
||||
{
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(cute::size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
|
||||
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
CUTE_DEVICE void util_copy(
|
||||
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(S); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < cute::size<2>(S); ++k)
|
||||
{
|
||||
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,553 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms.
|
||||
// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
|
||||
template <typename...>
|
||||
using void_t = void;
|
||||
|
||||
template <typename Mma, typename = void>
|
||||
struct use_dq_gemm : platform::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Mma>
|
||||
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
|
||||
>
|
||||
struct MoeFCGemm
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = false;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
static_assert(!kTransposed, "Transpose problem not supported");
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
bool C_is_broadcast;
|
||||
|
||||
int64_t const* total_tokens_including_expert;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments()
|
||||
: problem_count(0)
|
||||
, threadblock_count(0)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, total_tokens_including_expert(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
, C_is_broadcast{true}
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
|
||||
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
|
||||
bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n,
|
||||
int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count)
|
||||
, threadblock_count(threadblock_count)
|
||||
, group_size(group_size)
|
||||
, output_op(output_op)
|
||||
, ptr_A(const_cast<ElementA*>(ptr_A))
|
||||
, ptr_B(const_cast<ElementB*>(ptr_B))
|
||||
, weight_scales(const_cast<ElementScale*>(weight_scales))
|
||||
, ptr_C(const_cast<ElementC*>(ptr_C))
|
||||
, C_is_broadcast{C_is_broadcast}
|
||||
, ptr_D(ptr_D)
|
||||
, total_tokens_including_expert(total_tokens_including_expert)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, host_problem_sizes(nullptr)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
assert(weight_scales);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
bool C_is_broadcast;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, C_is_broadcast(true)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(
|
||||
args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, group_size(args.group_size)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, weight_scales(args.weight_scales)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
, C_is_broadcast(args.C_is_broadcast)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
|
||||
args.gemm_k, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
weight_scales = args.weight_scales;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
C_is_broadcast = args.C_is_broadcast;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MoeFCGemm() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
if (args.weight_scales == nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
else if (args.weight_scales != nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
else if (args.group_size != args.gemm_k)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
// Handle the case the input is too short
|
||||
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
int loop = 0;
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
loop++;
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
const int64_t rows_to_jump
|
||||
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B
|
||||
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
|
||||
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
auto CreateMMA = [&]()
|
||||
{
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
else
|
||||
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
};
|
||||
Mma mma = CreateMMA();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
|
||||
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
{
|
||||
const MatrixCoord scale_extent = {1, problem_size.n()};
|
||||
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
|
||||
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
|
||||
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
|
||||
+ (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
|
||||
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
|
||||
|
||||
// lora need to set as layout_C(gemm_n)
|
||||
LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
|
||||
LayoutC layout_D(gemm_n);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
if constexpr (platform::is_same<EpilogueOutputOp,
|
||||
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
|
||||
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
|
||||
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
|
||||
EpilogueOutputOp::kRound>>::value)
|
||||
{
|
||||
EpilogueOutputOp output_op(params.output_op, problem_idx);
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
}
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
|
||||
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|
||||
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
|
||||
if constexpr (isFp8)
|
||||
{
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{ // reuse sm80 kernel for other types, align with dispatchToArch
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
}
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,344 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Base scheduler for grouped problems, using MoE
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape_>
|
||||
struct BaseMoeProblemVisitor
|
||||
{
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
|
||||
struct ProblemInfo
|
||||
{
|
||||
static int32_t const kNoPrefetchEntry = -1;
|
||||
int32_t problem_idx;
|
||||
int32_t problem_start;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo()
|
||||
: problem_idx(kNoPrefetchEntry)
|
||||
, problem_start(kNoPrefetchEntry)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
|
||||
: problem_idx(problem_idx_)
|
||||
, problem_start(problem_start_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
int64_t const* last_row_for_problem;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
int32_t problem_count;
|
||||
void const* workspace;
|
||||
int32_t tile_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: last_row_for_problem(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, problem_count(0)
|
||||
, workspace(nullptr)
|
||||
, tile_count(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
|
||||
void const* workspace = nullptr, int32_t tile_count = 0)
|
||||
: last_row_for_problem(last_row_for_problem)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, problem_count(problem_count)
|
||||
, workspace(workspace)
|
||||
, tile_count(tile_count)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
Params const& params;
|
||||
int32_t tile_idx;
|
||||
int32_t problem_tile_start;
|
||||
int32_t problem_idx;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
|
||||
: params(params_)
|
||||
, tile_idx(block_idx)
|
||||
, problem_tile_start(0)
|
||||
, problem_idx(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Get the grid shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
|
||||
{
|
||||
|
||||
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
|
||||
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
|
||||
}
|
||||
|
||||
/// Gets the global tile index
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t tile_index() const
|
||||
{
|
||||
return tile_idx;
|
||||
}
|
||||
|
||||
/// Gets the index of the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t problem_index() const
|
||||
{
|
||||
return problem_idx;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t threadblock_idx() const
|
||||
{
|
||||
return tile_idx - problem_tile_start;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance(int32_t grid_size)
|
||||
{
|
||||
tile_idx += grid_size;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
|
||||
{
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
}
|
||||
|
||||
/// Returns the problem size for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size() const
|
||||
{
|
||||
return problem_size(problem_idx);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size(int idx) const
|
||||
{
|
||||
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
|
||||
const int64_t current_problem_row = params.last_row_for_problem[idx];
|
||||
const int64_t gemm_m = current_problem_row - prev_problem_row;
|
||||
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
|
||||
{
|
||||
return ProblemSizeHelper::tile_count(grid);
|
||||
}
|
||||
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
|
||||
{
|
||||
int32_t total_tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
auto problem = host_problem_sizes_ptr[i];
|
||||
possibly_transpose_problem(problem);
|
||||
auto grid = grid_shape(problem);
|
||||
total_tiles += tile_count(grid);
|
||||
}
|
||||
|
||||
return total_tiles;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
|
||||
int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ProblemVisitor that performs all scheduling on device
|
||||
//
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
|
||||
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
|
||||
{
|
||||
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
|
||||
using Params = typename Base::Params;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
static bool const kRequiresPrecomputation = false;
|
||||
static int const kThreadsPerWarp = 32;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
// Final tile of the problem loaded by this thread. Each thread will hold
|
||||
// a separate value.
|
||||
int32_t problem_ending_tile;
|
||||
|
||||
SharedStorage& shared_storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, block_idx)
|
||||
, problem_ending_tile(0)
|
||||
, shared_storage(shared_storage_)
|
||||
{
|
||||
this->problem_idx = -1 * kThreadsPerWarp;
|
||||
this->problem_tile_start = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool next_tile()
|
||||
{
|
||||
// Check whether the tile to compute is within the range of the current problem.
|
||||
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
|
||||
if (this->tile_idx < problem_tile_end)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether the tile to compute is within the current group of problems fetched by the warp.
|
||||
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
|
||||
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
|
||||
|
||||
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
|
||||
// register pressure. The starting problem for this group is simply the first problem
|
||||
// in the group most recently fetched by the warp.
|
||||
int32_t& group_problem_start = this->problem_idx;
|
||||
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
|
||||
|
||||
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
|
||||
// register pressure.
|
||||
int32_t& group_tile_start = this->problem_tile_start;
|
||||
|
||||
// Each thread in the warp processes a separate problem to advance until
|
||||
// reaching a problem whose starting tile is less less than tile_idx.
|
||||
while (group_tile_end <= this->tile_idx)
|
||||
{
|
||||
group_problem_start += kThreadsPerWarp;
|
||||
if (group_problem_start > this->params.problem_count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
|
||||
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
|
||||
// is also set here is used later in `next_tile`.
|
||||
group_tile_start = group_tile_end;
|
||||
|
||||
int lane_idx = threadIdx.x % kThreadsPerWarp;
|
||||
int32_t lane_problem = group_problem_start + lane_idx;
|
||||
|
||||
// Compute the number of tiles in the problem assigned to each thread.
|
||||
problem_ending_tile = 0;
|
||||
if (lane_problem < this->params.problem_count)
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
|
||||
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
|
||||
problem_ending_tile = this->tile_count(grid);
|
||||
}
|
||||
|
||||
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
|
||||
// each thread's problem.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
|
||||
{
|
||||
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
|
||||
if (lane_idx >= i)
|
||||
{
|
||||
problem_ending_tile += val;
|
||||
}
|
||||
}
|
||||
|
||||
// The total tile count for this group is now in the final position of the prefix sum
|
||||
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
|
||||
|
||||
problem_ending_tile += group_tile_start;
|
||||
group_tile_end += tiles_in_group;
|
||||
}
|
||||
|
||||
// The next problem to process is the first one that does not have ending tile position
|
||||
// that is greater than or equal to tile index.
|
||||
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
|
||||
|
||||
this->problem_idx = group_problem_start + problem_idx_in_group;
|
||||
|
||||
// The starting tile for this problem is the ending tile of the previous problem. In cases
|
||||
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
|
||||
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
|
||||
// loop above.
|
||||
if (problem_idx_in_group > 0)
|
||||
{
|
||||
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(
|
||||
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
|
||||
int32_t block_count, void* host_workspace_ptr)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,646 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>
|
||||
&& CollectiveMainloop_::isGated>>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock
|
||||
= CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16>
|
||||
{
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
void* workspace{nullptr};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = nullptr;
|
||||
// Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used
|
||||
// in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means
|
||||
// subtile will not be used, therefore separate reduction will not be enabled.
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{},
|
||||
ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
|
||||
return {args.mode, problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
|
||||
scheduler, workspace};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args)
|
||||
{
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm)
|
||||
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t workspace_size = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
|
||||
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
|
||||
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const& params)
|
||||
{
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
|
||||
{
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape()
|
||||
{
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
|
||||
static_assert(size<0>(TileShape{}) >= 128,
|
||||
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
|
||||
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
|
||||
enum class WarpGroupRole
|
||||
{
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2
|
||||
};
|
||||
enum class ProducerWarpRole
|
||||
{
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate)
|
||||
{
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = size(TiledMma{});
|
||||
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = []()
|
||||
{
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1)
|
||||
{
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
}
|
||||
else
|
||||
{
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
|
||||
{
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the
|
||||
// work.
|
||||
auto work_k_tile_count
|
||||
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
|
||||
auto k_tile_iter
|
||||
= cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(work_k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive)
|
||||
{
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
|
||||
{
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
if (!TileScheduler::requires_separate_reduction(params.scheduler))
|
||||
{
|
||||
load_order_barrier.wait();
|
||||
}
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx());
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it
|
||||
bool do_store_tail = false;
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
auto work_k_tile_count
|
||||
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
|
||||
{
|
||||
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop,
|
||||
params.mainloop);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count);
|
||||
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++)
|
||||
{
|
||||
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
|
||||
}
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
|
||||
{
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
|
||||
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
|
||||
do_store_tail = true;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
if (do_store_tail)
|
||||
{
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state);
|
||||
}
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
// Kernel helper function to get next work unit
|
||||
CUTLASS_DEVICE
|
||||
typename TileScheduler::WorkTileInfo fetch_next_work(
|
||||
typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const
|
||||
{
|
||||
// Check whether we should continue on with the current work unit. If this is the case,
|
||||
// the work unit will have been updated in continue_current_work to reflect the new
|
||||
// tile to be computed.
|
||||
if (scheduler.continue_current_work(work_tile_info))
|
||||
{
|
||||
return work_tile_info;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
return scheduler.get_current_work();
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -1,621 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cute/util/debug.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>
|
||||
&& CollectiveMainloop_::isGated>>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
|
||||
"Ping-pong kernel does not currently support stream-K scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t MaxThreadsPerBlock
|
||||
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16>
|
||||
{
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = nullptr;
|
||||
|
||||
return {args.mode, problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
|
||||
TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args)
|
||||
{
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm)
|
||||
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t workspace_size = 0;
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
|
||||
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
|
||||
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const& params)
|
||||
{
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
|
||||
{
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape()
|
||||
{
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
enum class WarpGroupRole
|
||||
{
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2
|
||||
};
|
||||
enum class ProducerWarpRole
|
||||
{
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate)
|
||||
{
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(
|
||||
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&]()
|
||||
{
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1)
|
||||
{
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
}
|
||||
else
|
||||
{
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
}
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive)
|
||||
{
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
|
||||
{
|
||||
load_order_barrier.wait();
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state
|
||||
= collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL,
|
||||
blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue);
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Order two Math WG's MMA one after the other, helps hide Epilogue
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1,
|
||||
k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++)
|
||||
{
|
||||
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
|
||||
}
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
|
||||
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue);
|
||||
|
||||
// TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels
|
||||
// we need to wait for all TMA stores to complete before issuing consumer order barrier arrives
|
||||
// to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer.
|
||||
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_]
|
||||
= collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next);
|
||||
|
||||
// Update starting load/store pipeline states for the next tile
|
||||
// state has already been incremented by 1 tile in collective calls, advance once again for ping pong
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work(NumMmaWarpGroups);
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -1,494 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
|
||||
bool Transposed = false>
|
||||
struct SplitkGemmGrouped
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
|
||||
using ElementFinalOutput = typename MapArguments::ElementA;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord* problem_sizes;
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA** ptr_A;
|
||||
ElementB** ptr_B;
|
||||
ElementFinalOutput** ptr_C;
|
||||
ElementFinalOutput** ptr_D;
|
||||
|
||||
typename LayoutA::Stride::LongIndex* lda;
|
||||
typename LayoutB::Stride::LongIndex* ldb;
|
||||
typename LayoutC::Stride::LongIndex* ldc;
|
||||
typename LayoutC::Stride::LongIndex* ldd;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
// splitK
|
||||
int split_k_slices;
|
||||
int64_t* splitk_buffer_offsets;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments()
|
||||
: problem_count(0)
|
||||
, threadblock_count(0)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, lda(nullptr)
|
||||
, ldb(nullptr)
|
||||
, ldc(nullptr)
|
||||
, ldd(nullptr)
|
||||
, host_problem_sizes(nullptr)
|
||||
, split_k_slices(1)
|
||||
, splitk_buffer_offsets(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
|
||||
typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
|
||||
ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
|
||||
typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
|
||||
typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
|
||||
int64_t* splitk_buffer_offsets)
|
||||
: problem_sizes(problem_sizes)
|
||||
, problem_count(problem_count)
|
||||
, threadblock_count(threadblock_count)
|
||||
, output_op(output_op)
|
||||
, ptr_A(ptr_A)
|
||||
, ptr_B(ptr_B)
|
||||
, ptr_C(ptr_C)
|
||||
, ptr_D(ptr_D)
|
||||
, lda(lda)
|
||||
, ldb(ldb)
|
||||
, ldc(ldc)
|
||||
, ldd(ldd)
|
||||
, host_problem_sizes(host_problem_sizes)
|
||||
, split_k_slices(split_k_slices)
|
||||
, splitk_buffer_offsets(splitk_buffer_offsets)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA** ptr_A;
|
||||
ElementB** ptr_B;
|
||||
ElementFinalOutput** ptr_C;
|
||||
ElementFinalOutput** ptr_D;
|
||||
ElementC* ptr_C_split;
|
||||
ElementC* ptr_D_split;
|
||||
|
||||
typename LayoutA::Stride::LongIndex* lda;
|
||||
typename LayoutB::Stride::LongIndex* ldb;
|
||||
typename LayoutC::Stride::LongIndex* ldc;
|
||||
typename LayoutC::Stride::LongIndex* ldd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// splitk
|
||||
GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
int gemm_k_size;
|
||||
GemmCoord* host_problem_sizes;
|
||||
int split_k_slices;
|
||||
int64_t* splitk_buffer_offsets;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, ptr_C_split(nullptr)
|
||||
, ptr_D_split(nullptr)
|
||||
, lda(nullptr)
|
||||
, ldb(nullptr)
|
||||
, ldc(nullptr)
|
||||
, ldd(nullptr)
|
||||
, swizzle_log_tile(0)
|
||||
, gemm_k_size(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
, split_k_slices(1)
|
||||
, splitk_buffer_offsets(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
|
||||
, host_problem_sizes(args.host_problem_sizes)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
, ptr_C_split((ElementC*) workspace)
|
||||
, ptr_D_split((ElementC*) workspace)
|
||||
, lda(args.lda)
|
||||
, ldb(args.ldb)
|
||||
, ldc(args.ldc)
|
||||
, ldd(args.ldd)
|
||||
, split_k_slices(args.split_k_slices)
|
||||
, splitk_buffer_offsets(args.splitk_buffer_offsets)
|
||||
{
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
|
||||
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
|
||||
|
||||
// only support same k
|
||||
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
|
||||
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor =
|
||||
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
ptr_C_split = workspace;
|
||||
ptr_D_split = workspace;
|
||||
|
||||
lda = args.lda;
|
||||
ldb = args.ldb;
|
||||
ldc = args.ldc;
|
||||
ldd = args.ldd;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
struct SharedStorage
|
||||
{
|
||||
union
|
||||
{
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
} kernel;
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
SplitkGemmGrouped() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
ElementA* ptr_A
|
||||
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
|
||||
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
|
||||
|
||||
ElementB* ptr_B
|
||||
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
|
||||
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
|
||||
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k;
|
||||
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
|
||||
{
|
||||
problem_size_k = problem_size.k();
|
||||
}
|
||||
else
|
||||
{
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C = params.ptr_C_split;
|
||||
ElementC* ptr_D = params.ptr_D_split;
|
||||
|
||||
LayoutC layout_C(params.ldc[problem_idx]);
|
||||
LayoutC layout_D(params.ldd[problem_idx]);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
|
||||
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,125 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We need to distinguish here, since we want volta support. It is too much effort
|
||||
// to write shared memory iterators that are probably needed for volta to function
|
||||
// properly. As a result, we allow converters both after the LDG (for volta) and after
|
||||
// the LDS for Turing+.
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Warp level Mma
|
||||
typename MmaOperator,
|
||||
/// Math operation perform by warp level operator
|
||||
typename MathOperator>
|
||||
struct SetConverters
|
||||
{
|
||||
};
|
||||
|
||||
// Dequantize after LDG, so set transforms accordingly
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Mma Policy
|
||||
typename MmaOperator>
|
||||
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
|
||||
{
|
||||
using TransformAfterLDG
|
||||
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename IteratorB::Element, IteratorB::Fragment::kElements>;
|
||||
|
||||
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
|
||||
};
|
||||
|
||||
// Dequantize after LDS, so set transforms accordingly
|
||||
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Mma Policy
|
||||
typename MmaOperator>
|
||||
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
|
||||
{
|
||||
using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
|
||||
IteratorB::Fragment::kElements>;
|
||||
|
||||
using TransformAfterLDS
|
||||
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale_,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale_,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DqMma;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,302 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsMultistage;
|
||||
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
Layout, 0, Alignment>;
|
||||
|
||||
using SmemIteratorScale = IteratorScale;
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
|
||||
private:
|
||||
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
|
||||
MmaShape::kN / Alignment, Alignment>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
|
||||
using SmemIteratorScale = IteratorScale;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
kStages, Operator_, SharedMemoryClear,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
||||
"Mma multistage must dequantize after ldsm");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
|
||||
Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
// Specialization to handle column major interleave B
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
kStages, Operator_, SharedMemoryClear,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
||||
"Mma multistage must dequantize after ldsm");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
private:
|
||||
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
|
||||
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
|
||||
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
||||
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
||||
|
||||
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
||||
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
||||
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
||||
|
||||
using GmemIteratorShape
|
||||
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
||||
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
|
||||
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
||||
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
||||
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
|
||||
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,284 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsPipelined;
|
||||
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
private:
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
Layout, 0, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
SmemScaleType, Layout, 0, Alignment>;
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
|
||||
private:
|
||||
// ThreadMap for scale iterator
|
||||
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
|
||||
MmaShape::kN / Alignment, Alignment>;
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
|
||||
Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
|
||||
Operator_, SharedMemoryClearOption::kNone,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
|
||||
Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
|
||||
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
|
||||
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
|
||||
};
|
||||
|
||||
// Specialization to handle column major interleave B
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
|
||||
Operator_, SharedMemoryClearOption::kNone,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
private:
|
||||
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
|
||||
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
|
||||
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
||||
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
||||
|
||||
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
||||
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
||||
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
||||
|
||||
using GmemIteratorShape
|
||||
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
||||
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
|
||||
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
||||
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
||||
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
|
||||
layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
|
||||
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
||||
using IteratorScaleThreadMap
|
||||
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
||||
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
|
||||
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
|
||||
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,351 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
|
||||
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
||||
// large tile when not enough shared mem is present to do 3+ stage
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
|
||||
SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
// Define the MmaCore components
|
||||
// 3 is used on purpose here to trigger components for mma multistage
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
|
||||
GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
|
||||
GatherB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,353 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
|
||||
SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
private:
|
||||
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
|
||||
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
||||
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
||||
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
|
||||
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;
|
||||
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
|
||||
layout::RowMajor, typename MmaCore::MmaPolicy>;
|
||||
};
|
||||
|
||||
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
||||
// large tile when not enough shared mem is present to do 3+ stage
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
|
||||
false, SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
// Define the MmaCore components
|
||||
// 3 is used on purpose here to trigger components for mma multistage
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA, GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB, GatherB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -1,257 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
|
||||
// correct warp level mma. On volta, all data is stored to shared memory as FP16.
|
||||
template <typename WarpMma, int kExpansionFactor = 1>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
|
||||
int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C);
|
||||
}
|
||||
|
||||
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
|
||||
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C, warp_tileB_k_offset);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// The type of the scales
|
||||
typename ElementScale_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// The dequantizing op to be performed.
|
||||
WeightOnlyQuantOp DequantOp,
|
||||
/// Used for partial specialization,
|
||||
typename Enable = bool>
|
||||
class DqMmaBase
|
||||
{
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
///< Type of the scale to be loaded
|
||||
using ElementScale = ElementScale_;
|
||||
|
||||
static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");
|
||||
|
||||
// Finegrained scales get streamed in via cp.async
|
||||
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
|
||||
// We always have scales.
|
||||
static constexpr int ScaleElementsPerStage = Shape::kN;
|
||||
// We sometimes have a bias
|
||||
static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
static constexpr int kNumKIterationsPerWarpBLoad
|
||||
= Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
|
||||
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
|
||||
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA
|
||||
= MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB
|
||||
= MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
/// Shape of the shared memory buffer for the scales for the B matrix.
|
||||
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
|
||||
/// Shape of the shared memory buffer for the biases of the B matrix.
|
||||
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer to hold scales for threadblock
|
||||
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;
|
||||
|
||||
/// Buffer to hold scales for threadblock
|
||||
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA()
|
||||
{
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB()
|
||||
{
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref()
|
||||
{
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref()
|
||||
{
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage& shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
|
||||
, warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,110 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type for the scales
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = void>
|
||||
class DqMmaMultistage;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
|
||||
@@ -1,708 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applied immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
|
||||
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
|
||||
SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
|
||||
LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail
|
||||
{
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA
|
||||
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB
|
||||
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
/// The group size for quantization
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
|
||||
{
|
||||
static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");
|
||||
|
||||
typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
|
||||
typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();
|
||||
|
||||
typename IteratorScale::AccessType* smem_scale_ptr
|
||||
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
|
||||
typename IteratorScale::AccessType* smem_zero_ptr
|
||||
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
|
||||
}
|
||||
|
||||
if (iterator_scale.group_size_ == 64)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
|
||||
this->smem_iterator_scale_.add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
||||
{
|
||||
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
|
||||
{
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
|
||||
{
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC& accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale operand in global memory
|
||||
IteratorScale iterator_scale,
|
||||
///< initial value of accumulator
|
||||
FragmentC const& src_accum)
|
||||
{
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
|
||||
{
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
//
|
||||
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
||||
// so that all accumulator elements outside the GEMM footprint are zero.
|
||||
//
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
|
||||
{
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
|
||||
typename IteratorA::AccessType zero_A;
|
||||
zero_A.clear();
|
||||
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
typename Dequantizer::FragmentScale warp_frag_scales;
|
||||
typename Dequantizer::FragmentZero warp_frag_zeros;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
|
||||
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
|
||||
if (group_start_iteration_B == 0)
|
||||
{
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
}
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
|
||||
// #committed)
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
|
||||
smem_read_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the scale needed for the next tile iteration.
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
|
||||
// Update internal pointer to set of scales in shared memory.
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,647 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
|
||||
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
|
||||
SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
|
||||
LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail
|
||||
{
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA
|
||||
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB
|
||||
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
///< Group size for quantization. Not used by this main loop since it assumes per-column
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
||||
{
|
||||
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
|
||||
{
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
|
||||
{
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC& accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale operand in global memory
|
||||
IteratorScale iterator_scale,
|
||||
///< initial value of accumulator
|
||||
FragmentC const& src_accum)
|
||||
{
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
// NOTE - switch to ldg.sts
|
||||
// Issue this first, so cp.async.commit_group will commit this load as well.
|
||||
// Note: we do not commit here and this load will commit in the same group as
|
||||
// the first load of A.
|
||||
FragmentScale tb_frag_scales;
|
||||
tb_frag_scales.clear();
|
||||
iterator_scale.load(tb_frag_scales);
|
||||
this->smem_iterator_scale_.store(tb_frag_scales);
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
|
||||
{
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
//
|
||||
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
||||
// so that all accumulator elements outside the GEMM footprint are zero.
|
||||
//
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
|
||||
{
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
|
||||
typename IteratorA::AccessType zero_A;
|
||||
zero_A.clear();
|
||||
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
typename Dequantizer::FragmentScale warp_frag_scales;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
warp_dequantizer_.load(warp_frag_scales);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
|
||||
// #committed)
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
smem_read_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,106 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Data type for the scales
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Used for partial specialization
|
||||
typename Enable = void>
|
||||
class DqMmaPipelined;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
|
||||
@@ -1,486 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
using WarpFragmentZero = typename Dequantizer::FragmentZero;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< The group size for quantization
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_scales_and_advance(IteratorScale& iterator_scale)
|
||||
{
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
FragmentScale tb_frag_scales;
|
||||
FragmentScale tb_frag_zeros;
|
||||
tb_frag_scales.clear();
|
||||
tb_frag_zeros.clear();
|
||||
|
||||
TransformScale transformScale;
|
||||
|
||||
using FragmentElement = typename FragmentScale::Element;
|
||||
|
||||
auto gmem_scale_ptr = iterator_scale.get_scale();
|
||||
auto gmem_zero_ptr = iterator_scale.get_zero();
|
||||
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
|
||||
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
|
||||
}
|
||||
|
||||
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
|
||||
typename TransformScale::result_type tb_frag_zeros_fp16;
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
|
||||
|
||||
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
|
||||
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
|
||||
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
|
||||
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
|
||||
|
||||
if (iterator_scale.valid())
|
||||
{
|
||||
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
|
||||
}
|
||||
}
|
||||
|
||||
if (iterator_scale.group_size_ == 64)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
|
||||
this->smem_iterator_scale_.add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
WarpFragmentZero warp_frag_zero;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
|
||||
// Load the scales needed for the next tile iteration
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
// Update internal pointer to the set of scales in shared memory
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,399 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<!isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
|
||||
///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
|
||||
///< argument is not added, it does not affect compilation for sm>=80.
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
TransformScale transformScale;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
FragmentScale tb_frag_scales;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
tb_frag_scales.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
iterator_scale.load(tb_frag_scales);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales);
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,107 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for m-by-n-by-kgroup
|
||||
template <
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A elements,
|
||||
typename ElementA,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
|
||||
{
|
||||
|
||||
private:
|
||||
// Shape for computing the FP16s
|
||||
using ComputeInstructionShape = InstructionShape_;
|
||||
|
||||
// Chosen so we get K=16 for int8 and K=32 for int4.
|
||||
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
|
||||
|
||||
// Shape for loading the narrow data type from shared memory
|
||||
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
|
||||
|
||||
public:
|
||||
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
|
||||
cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
|
||||
cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
|
||||
cutlass::MatrixShape<1, 1>>;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,306 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
|
||||
Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/mma_sm89.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
typename Policy_,
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
typename SharedMemoryInstructionShape_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaTensorOpComputeBWithF16
|
||||
{
|
||||
public:
|
||||
/// Shape of warp-level matrix operation (concept: GemmShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of multiplicand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Layout of multiplicand A
|
||||
using LayoutA = LayoutA_;
|
||||
|
||||
/// Data type of multiplicand B
|
||||
using ElementB = ElementB_;
|
||||
|
||||
/// Layout of multiplicand B
|
||||
using LayoutB = LayoutB_;
|
||||
|
||||
/// Data type of accumulator matrix C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Layout of accumulator matrix C
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Indicates math operator
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
|
||||
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
|
||||
|
||||
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
|
||||
|
||||
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
public:
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA
|
||||
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
|
||||
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
|
||||
kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
|
||||
|
||||
public:
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpComputeBWithF16() {}
|
||||
|
||||
/// Performs a warp-level matrix multiply-accumulate operation
|
||||
CUTLASS_DEVICE
|
||||
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
|
||||
int const warp_tileB_k_offset) const
|
||||
{
|
||||
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
static_assert(
|
||||
TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
|
||||
"Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
|
||||
"B");
|
||||
|
||||
D = C;
|
||||
|
||||
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
|
||||
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
|
||||
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
// Serpentine visitation order maximizing reuse of Rb
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
|
||||
|
||||
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
|
||||
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
|
||||
ptr_D[m_serpentine + n * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
// Serpentine visitation order maximizing reuse of Ra
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
|
||||
|
||||
int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
|
||||
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
|
||||
ptr_D[m + n_serpentine * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
assert(0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,463 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Matrix multiply operator
|
||||
typename MmaOperator_,
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand,
|
||||
/// Data type of Scale elements
|
||||
typename Element_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Number of threads participating in one matrix operation
|
||||
int Threads,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
///
|
||||
typename Enable = void>
|
||||
class MmaTensorOpDequantizer;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat specialization for Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_,
|
||||
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
|
||||
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
|
||||
public:
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture specific mma ooperator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
/// Type of the scales
|
||||
using ElementScale = bfloat16_t;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kColsPerMmaPerThread = 1;
|
||||
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<ElementScale, Layout>;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
pointer_zero_ = smem_zeros.data() + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
|
||||
|
||||
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
|
||||
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
|
||||
{
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
|
||||
|
||||
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
|
||||
__nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]);
|
||||
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
|
||||
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Adds a pointer offset in units of elements.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset)
|
||||
{
|
||||
static_assert(sizeof(ElementScale) > 1, "");
|
||||
pointer_scale_ += offset;
|
||||
pointer_zero_ += offset;
|
||||
}
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Specialization for Turing & Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
|
||||
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 75
|
||||
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
|
||||
public:
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture specific mma ooperator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
/// Type of the scales
|
||||
using ElementScale = half_t;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kColsPerMmaPerThread = 1;
|
||||
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<ElementScale, Layout>;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
pointer_zero_ = smem_zeros.data() + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
|
||||
{
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
plus<ExpandedMmaOperandB> plus_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter]
|
||||
= plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a pointer offset in units of elements.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset)
|
||||
{
|
||||
static_assert(sizeof(ElementScale) > 1, "");
|
||||
pointer_scale_ += offset;
|
||||
pointer_zero_ += offset;
|
||||
}
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,224 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
|
||||
// in the kernel layout details when doing weight only quantization.
|
||||
enum class CutlassTileConfig
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// SiMT config
|
||||
CtaShape128x128x8_WarpShape64x64x8,
|
||||
|
||||
// TensorCore configs CTA_N = 128, CTA_K = 64
|
||||
// Warp configs for M=16
|
||||
CtaShape16x128x64_WarpShape16x32x64,
|
||||
// Warp configs for M=32
|
||||
CtaShape32x128x64_WarpShape32x32x64,
|
||||
|
||||
// Warp configs for M=64
|
||||
CtaShape64x128x64_WarpShape32x64x64,
|
||||
CtaShape64x64x128_WarpShape32x64x64,
|
||||
CtaShape64x128x64_WarpShape64x32x64,
|
||||
|
||||
// Warp configs for M=128
|
||||
CtaShape128x64x64_WarpShape64x32x64,
|
||||
CtaShape128x128x64_WarpShape64x32x64,
|
||||
CtaShape128x128x64_WarpShape64x64x64,
|
||||
CtaShape128x128x64_WarpShape128x32x64,
|
||||
CtaShape128x256x64_WarpShape64x64x64,
|
||||
|
||||
// Warp configs for M=256
|
||||
CtaShape256x128x64_WarpShape64x64x64,
|
||||
|
||||
// TensorCore config CTA_N = 64, CTA_K = 128
|
||||
CtaShape128x64x128_WarpShape64x32x128,
|
||||
|
||||
// TensorCore config CTA_N = 256, CTA_K = 64
|
||||
CtaShape16x256x64_WarpShape16x64x64,
|
||||
|
||||
// TensorCore config CTA_N = 256, CTA_K = 128
|
||||
CtaShape16x256x128_WarpShape16x64x128
|
||||
|
||||
};
|
||||
|
||||
enum class SplitKStyle
|
||||
{
|
||||
NO_SPLIT_K,
|
||||
SPLIT_K_SERIAL,
|
||||
STREAM_K, // Sm80+
|
||||
// SPLIT_K_PARALLEL // Not supported yet
|
||||
};
|
||||
|
||||
enum class CutlassTileConfigSM90
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// CTA configs for M=64
|
||||
CtaShape64x16x128B,
|
||||
CtaShape64x32x128B,
|
||||
CtaShape64x64x128B,
|
||||
CtaShape64x128x128B,
|
||||
CtaShape64x256x128B,
|
||||
|
||||
// CTA configs for M=128
|
||||
CtaShape128x16x128B,
|
||||
CtaShape128x32x128B,
|
||||
CtaShape128x64x128B,
|
||||
CtaShape128x128x128B,
|
||||
CtaShape128x256x128B,
|
||||
|
||||
// CTA configs for M=128
|
||||
CtaShape256x128x128B,
|
||||
};
|
||||
|
||||
enum class MainloopScheduleType
|
||||
{
|
||||
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
|
||||
// defaults to the "legacy" main loop schedule.
|
||||
};
|
||||
|
||||
enum class EpilogueScheduleType
|
||||
{
|
||||
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
|
||||
// architectures older than hopper, the epilogue is always performed by the same thread block as the main loop.
|
||||
};
|
||||
|
||||
enum class ClusterShape
|
||||
{
|
||||
ClusterShape_1x1x1,
|
||||
ClusterShape_2x1x1,
|
||||
ClusterShape_1x2x1,
|
||||
ClusterShape_2x2x1,
|
||||
ClusterShape_1x8x1,
|
||||
ClusterShape_8x1x1
|
||||
};
|
||||
|
||||
struct CutlassGemmConfig
|
||||
{
|
||||
enum CandidateConfigTypeParam : int
|
||||
{
|
||||
NONE = 0,
|
||||
WEIGHT_ONLY = 1u << 0,
|
||||
SIMT_ONLY = 1u << 1,
|
||||
INT8_ONLY = 1u << 2,
|
||||
HOPPER = 1u << 3,
|
||||
GROUPED_GEMM = 1u << 4,
|
||||
FP8_ONLY = 1u << 5,
|
||||
};
|
||||
|
||||
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
|
||||
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
|
||||
int split_k_factor = -1;
|
||||
int stages = -1;
|
||||
|
||||
// config options for sm90
|
||||
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
|
||||
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
|
||||
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
|
||||
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
|
||||
bool is_sm90 = false;
|
||||
|
||||
CutlassGemmConfig() {}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
|
||||
: tile_config(tile_config)
|
||||
, split_k_style(split_k_style)
|
||||
, split_k_factor(split_k_factor)
|
||||
, stages(stages)
|
||||
, is_sm90(false)
|
||||
{
|
||||
}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
|
||||
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
|
||||
: tile_config_sm90(tile_config_sm90)
|
||||
, mainloop_schedule(mainloop_schedule)
|
||||
, epilogue_schedule(epilogue_schedule)
|
||||
, cluster_shape(cluster_shape)
|
||||
, is_sm90(true)
|
||||
{
|
||||
}
|
||||
|
||||
std::string toString() const
|
||||
{
|
||||
std::stringstream tactic;
|
||||
tactic << "Cutlass GEMM Tactic";
|
||||
if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=TMA"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
assert(!is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=compatible"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
|
||||
<< "\n\tsplit k: " << (int) split_k_factor;
|
||||
}
|
||||
else
|
||||
{
|
||||
tactic << "\n\tundefined";
|
||||
}
|
||||
tactic << "\n";
|
||||
return tactic.str();
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
|
||||
{
|
||||
// clang-format off
|
||||
if (config.is_sm90)
|
||||
{
|
||||
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
|
||||
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
|
||||
<< ", cluster_shape_enum: " << int(config.cluster_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
out << "tile_config_enum: " << int(config.tile_config)
|
||||
<< ", split_k_style_enum: " << int(config.split_k_style)
|
||||
<< ", split_k_factor: " << config.split_k_factor
|
||||
<< ", stages: " << config.stages;
|
||||
}
|
||||
// clang-format on
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,447 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
|
||||
// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
|
||||
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
|
||||
// This converter will uninterleave the data and subtract the bias while converting to the result type.
|
||||
template <typename T, typename S, int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
|
||||
{
|
||||
using result_type = Array<half_t, 4>;
|
||||
using source_type = Array<uint8_t, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
|
||||
|
||||
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 4;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
|
||||
|
||||
using result_type = Array<half_t, N>;
|
||||
using source_type = Array<uint8_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
|
||||
{
|
||||
using result_type = Array<bfloat16_t, 4>;
|
||||
using source_type = Array<uint8_t, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
float fp32_intermediates[4];
|
||||
|
||||
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
|
||||
|
||||
// Subtract out fp32_base + 128 to make the unsigned integer signed.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 4; ++ii)
|
||||
{
|
||||
fp32_intermediates[ii] -= 8388736.f;
|
||||
}
|
||||
|
||||
// Truncate the fp32 representation and pack up as bfloat16s.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii)
|
||||
{
|
||||
bf16_result_ptr[ii]
|
||||
= __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
|
||||
}
|
||||
#else
|
||||
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
|
||||
// HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
|
||||
result.clear(); // Suppress compiler warning
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 4;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
|
||||
|
||||
using result_type = Array<bfloat16_t, N>;
|
||||
using source_type = Array<uint8_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
|
||||
{
|
||||
using result_type = Array<half_t, 8>;
|
||||
using source_type = Array<uint4b_t, 8>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 8;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<half_t, N>;
|
||||
using source_type = Array<uint4b_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
|
||||
{
|
||||
using result_type = Array<bfloat16_t, 8>;
|
||||
using source_type = Array<uint4b_t, 8>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
|
||||
|
||||
// We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
|
||||
// No shift needed for first item.
|
||||
uint32_t i4s = source_i4s;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 1; ii < result_type::kElements / 2; ++ii)
|
||||
{
|
||||
i4s >>= sizeof_bits<typename source_type::Element>::value;
|
||||
// (i4s & 0x000f000f) | 0x43004300
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[ii])
|
||||
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
|
||||
}
|
||||
|
||||
// This is the BF16 {-136, -136} represented as an integer.
|
||||
static constexpr uint32_t BF16_BIAS = 0xC308C308;
|
||||
static constexpr uint32_t BF16_ONE = 0x3F803F80;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < result_type::kElements / 2; ++ii)
|
||||
{
|
||||
// Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
|
||||
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
|
||||
}
|
||||
#else
|
||||
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
|
||||
// HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
|
||||
arch::device_breakpoint();
|
||||
result.clear(); // Suppress compiler warning.
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 8;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<bfloat16_t, N>;
|
||||
using source_type = Array<uint4b_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,66 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines new layouts needed for MoE
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/pitch_linear_coord.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace layout
|
||||
{
|
||||
|
||||
template <int RowsPerTile, int ColumnsInterleaved>
|
||||
struct ColumnMajorTileInterleave
|
||||
{
|
||||
static constexpr int kRowsPerTile = RowsPerTile;
|
||||
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct IsColumnMajorTileInterleave
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <int U, int V>
|
||||
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
} // namespace layout
|
||||
} // namespace cutlass
|
||||
@@ -1,250 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
|
||||
quantization.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace transform
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
|
||||
class FineGrainedScaleZeroIterator;
|
||||
|
||||
template <typename Shape_, typename Element_, int Alignment_>
|
||||
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
|
||||
{
|
||||
public:
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::RowMajor;
|
||||
static int const kAdvanceRank = 0;
|
||||
static int const kAlignment = Alignment_;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
/// Row index of scales corresponding to the groupsize of 64
|
||||
int row_groupsize64_;
|
||||
int group_size_;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Pointer = Element*;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
||||
|
||||
using AccessType = AlignedArray<Element, kAlignment>;
|
||||
|
||||
using Fragment = cutlass::Array<Element, kAlignment>;
|
||||
|
||||
// For compatibility with existing iterator interface
|
||||
struct Params
|
||||
{
|
||||
LongIndex stride_ = 0;
|
||||
|
||||
/// amount (in byte) to increment pointer from first access of current tile
|
||||
/// to first access of next tile
|
||||
LongIndex inc_advance_ = 0;
|
||||
|
||||
// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Construct the Params object given a pitch-linear tensor's layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const& layout)
|
||||
: stride_(layout.stride(0))
|
||||
{
|
||||
inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
/// Internal pointer type permits fast address arithmetic
|
||||
using BytePointer = char*;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const params_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
BytePointer pointer_scale_;
|
||||
BytePointer pointer_zero_;
|
||||
|
||||
bool is_valid_ = false;
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_DEVICE
|
||||
FineGrainedScaleZeroIterator(
|
||||
///< Precomputed parameters object
|
||||
Params const& params,
|
||||
///< Pointer to start of scale tensor
|
||||
Pointer pointer_scale,
|
||||
///< Pointer to start of zero tensor
|
||||
Pointer pointer_zero,
|
||||
///< Extent of the scale and bias
|
||||
TensorCoord extent,
|
||||
///< ID of each participating thread
|
||||
int thread_id,
|
||||
///< Initial offset of threadblock
|
||||
TensorCoord const& threadblock_offset,
|
||||
///< Group size
|
||||
int group_size)
|
||||
: params_(params)
|
||||
, pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
|
||||
, pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
|
||||
{
|
||||
row_groupsize64_ = threadblock_offset.row();
|
||||
group_size_ = group_size;
|
||||
|
||||
const LongIndex tb_row_byte_offset
|
||||
= threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
|
||||
const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);
|
||||
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
|
||||
}
|
||||
|
||||
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
|
||||
|
||||
int const thread_row = thread_id / THREADS_PER_ROW;
|
||||
int const thread_col = thread_id % THREADS_PER_ROW;
|
||||
|
||||
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
|
||||
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
|
||||
}
|
||||
|
||||
// For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
|
||||
// a given iteration. The same threads will be responsible for issues reads since the number of scales
|
||||
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
|
||||
// outside of the constructor.
|
||||
int const global_row = threadblock_offset.row() + thread_row;
|
||||
int const global_col = threadblock_offset.column() + thread_col * kAlignment;
|
||||
|
||||
bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
|
||||
bool const col_in_bounds = global_col < extent.column();
|
||||
|
||||
is_valid_ = row_in_bounds && col_in_bounds;
|
||||
}
|
||||
|
||||
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
|
||||
Pointer pointer_scale, ///< Pointer to start of scale tensor
|
||||
Pointer pointer_zero, ///< Pointer to start of zero tensor
|
||||
TensorCoord extent, ///< Extent of tensor
|
||||
int thread_id, ///< ID of each participating thread
|
||||
int group_size)
|
||||
: FineGrainedScaleZeroIterator(
|
||||
params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& tile_offset)
|
||||
{
|
||||
const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
|
||||
const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += row_byte_offset + col_byte_offset;
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += row_byte_offset + col_byte_offset;
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
|
||||
{
|
||||
is_valid_ &= (!enable);
|
||||
}
|
||||
|
||||
/// Returns whether access is valid or not
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const
|
||||
{
|
||||
return is_valid_;
|
||||
}
|
||||
|
||||
/// Returns a scale pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType* get_scale() const
|
||||
{
|
||||
return reinterpret_cast<AccessType*>(pointer_scale_);
|
||||
}
|
||||
|
||||
/// Returns a zero pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType* get_zero() const
|
||||
{
|
||||
return reinterpret_cast<AccessType*>(pointer_zero_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace transform
|
||||
} // namespace cutlass
|
||||
@@ -1,181 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/util/print.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/// Function object that applies an index to its argument
|
||||
template <class Iter>
|
||||
struct IndexedGather
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
|
||||
: indices_(indices)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename I>
|
||||
CUTE_HOST_DEVICE constexpr auto operator()(I i) const
|
||||
{
|
||||
return indices_[i];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
|
||||
{
|
||||
cute::print("Indexed{");
|
||||
print(s.indices_);
|
||||
print("}");
|
||||
}
|
||||
|
||||
Iter indices_;
|
||||
};
|
||||
|
||||
/// Custom stride object that applies a function followed by a stride
|
||||
template <class Func, class Stride>
|
||||
struct CustomStride
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
|
||||
: func_(func)
|
||||
, stride_(stride)
|
||||
{
|
||||
}
|
||||
|
||||
template <class I>
|
||||
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
|
||||
{
|
||||
return s.func_(i) * s.stride_;
|
||||
}
|
||||
|
||||
template <class I>
|
||||
CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
|
||||
{
|
||||
return s.func_(i) * s.stride_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE friend void print(CustomStride const& s)
|
||||
{
|
||||
cute::print("Custom{");
|
||||
print(s.func_);
|
||||
cute::print(",");
|
||||
print(s.stride_);
|
||||
cute::print("}");
|
||||
}
|
||||
|
||||
template <class Div>
|
||||
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
|
||||
{
|
||||
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
|
||||
}
|
||||
|
||||
// Circumvent the requirement on make_layout that shape and stride are integral
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
|
||||
{
|
||||
return Layout<Shape, CustomStride>(shape, stride);
|
||||
}
|
||||
|
||||
Func func_;
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
template <class Stride, class Func>
|
||||
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
|
||||
{
|
||||
// Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
|
||||
auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
|
||||
constexpr int I = decltype(idx)::value;
|
||||
return make_layout(
|
||||
repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
|
||||
}
|
||||
|
||||
/// Helper function to optionally create a gather tensor
|
||||
template <class Iterator, class Shape, class Stride, class Func>
|
||||
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
|
||||
{
|
||||
Layout matrix_layout = make_identity_layout(shape);
|
||||
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
|
||||
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
|
||||
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
|
||||
}
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <int N, int I, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
|
||||
{
|
||||
if constexpr (is_tuple<Shape>::value)
|
||||
{
|
||||
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
|
||||
}
|
||||
else if constexpr (is_scaled_basis<Stride>::value)
|
||||
{
|
||||
if constexpr (Stride::mode() == I)
|
||||
{
|
||||
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_layout(shape, stride);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return upcast<N>(shape, stride);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr auto upcast(
|
||||
ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
|
||||
{
|
||||
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
|
||||
auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
|
||||
constexpr int I = decltype(idx)::value;
|
||||
|
||||
// Upcast the outer layout (works as expected)
|
||||
auto outer = upcast<N>(layout.layout_a());
|
||||
|
||||
// Upcast the accumulated offset along stride-1 mode
|
||||
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
|
||||
|
||||
// Upcast the inner layout's shape along stride-1 mode
|
||||
auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());
|
||||
|
||||
return composition(outer, offset, inner);
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
@@ -1,58 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
enum class WeightOnlyQuantOp
|
||||
{
|
||||
UNDEFINED,
|
||||
PER_COLUMN_SCALE_ONLY,
|
||||
FINEGRAINED_SCALE_ONLY,
|
||||
FINEGRAINED_SCALE_AND_ZEROS
|
||||
};
|
||||
|
||||
constexpr bool isFinegrained(WeightOnlyQuantOp op)
|
||||
{
|
||||
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
}
|
||||
|
||||
constexpr bool hasZero(WeightOnlyQuantOp op)
|
||||
{
|
||||
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
@@ -1,25 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
|
||||
typename EpilogueTag>
|
||||
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
|
||||
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
|
||||
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
|
||||
int* kernel_occupancy);
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include <cutlass_extensions/epilogue_helpers.h>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
|
||||
#include <tensorrt_llm/common/cudaUtils.h>
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
|
||||
typename EpilogueTag>
|
||||
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
|
||||
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
|
||||
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
|
||||
int* kernel_occupancy)
|
||||
{
|
||||
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
|
||||
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
|
||||
TileK_, Stages_, activation_type>;
|
||||
|
||||
// make sure GPU has enough resources..
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
constexpr int smem_size = GemmType::kSmemSize;
|
||||
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
cudaFuncAttributes attr{};
|
||||
int device = 0;
|
||||
int max_smem_per_block = 0;
|
||||
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
|
||||
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
|
||||
{
|
||||
// This should mean that
|
||||
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
|
||||
// heuristic to ignore this configuration.
|
||||
*kernel_occupancy = 0;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
|
||||
*kernel_occupancy = max_active_blocks;
|
||||
return;
|
||||
}
|
||||
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
|
||||
int const threadblock_count = multi_processor_count * occupancy;
|
||||
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
|
||||
using Arguments = typename GemmType::Arguments;
|
||||
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
|
||||
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
|
||||
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
|
||||
num_experts, threadblock_count};
|
||||
auto params = GemmType::to_underlying_arguments(args);
|
||||
if (GemmType::kSmemSize >= (48 << 10))
|
||||
{
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
|
||||
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
|
||||
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
|
||||
}
|
||||
dim3 grid(params.threadblock_count, 1, 1);
|
||||
dim3 block(GemmType::kThreadCount);
|
||||
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
|
||||
auto result = cudaGetLastError();
|
||||
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
|
||||
}
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
@@ -1,37 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
// Keep in sync with the signature generated by generate_kernels.py
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
|
||||
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,348 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
// Hopper helper class for defining all the cutlass helper types
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
|
||||
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
|
||||
struct HopperGroupedGemmInfo
|
||||
{
|
||||
using Arch = cutlass::arch::Sm90;
|
||||
|
||||
// TODO Update once mixed input support is added
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value,
|
||||
"CUTLASS does not currently have specialised SM90 support for quantized operations");
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
constexpr static bool IsFP8
|
||||
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
|
||||
#else
|
||||
constexpr static bool IsFP8 = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for bfloat16, half, float, fp8");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for half, float, fp8");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
|
||||
"Unexpected quantization type");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
|
||||
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// For legacy reasons we convert unsigned 8-bit to signed
|
||||
using CutlassWeightTypeMaybeUint8
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
|
||||
CutlassWeightTypeMaybeUint4>;
|
||||
using CutlassWeightType
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
|
||||
|
||||
using ElementA = ElementType;
|
||||
using ElementB = CutlassWeightType;
|
||||
|
||||
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
|
||||
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
|
||||
|
||||
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
|
||||
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
|
||||
using ElementC = void;
|
||||
using ElementCNoVoid = ElementD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using ElementBias = ElementFinalOutput;
|
||||
using ElementRouterScales = float;
|
||||
|
||||
// A matrix configuration - this is transposed and swapped with B
|
||||
using LayoutA = HopperGroupedGemmInput::LayoutA;
|
||||
constexpr static int AlignmentA
|
||||
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration - this is transposed and swapped with A
|
||||
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
|
||||
constexpr static int AlignmentB
|
||||
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
|
||||
using StrideC = HopperGroupedGemmInput::StrideC;
|
||||
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
|
||||
constexpr static int AlignmentC
|
||||
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
|
||||
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
|
||||
constexpr static int AlignmentD
|
||||
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
|
||||
// in units of elements (up to 16 bytes)
|
||||
|
||||
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
|
||||
"Hopper Grouped GEMM specialisation doesn't support fused activation");
|
||||
|
||||
using EpilogueOp
|
||||
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
||||
|
||||
// TODO Add mode for fused activation once CUTLASS adds support
|
||||
// using EpilogueSchedule = cutlass::platform::conditional_t<
|
||||
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
|
||||
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
|
||||
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
|
||||
// >;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
// Epilogue For Default Finalize
|
||||
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
TileShape, ClusterShape, //
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, //
|
||||
ElementAccumulator, ElementAccumulator, //
|
||||
ElementC, LayoutC*, AlignmentC, //
|
||||
ElementD, LayoutD*, AlignmentD, //
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
// Epilogue For Fused Finalize
|
||||
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
|
||||
TileShape, //
|
||||
ElementCNoVoid, StrideC*, //
|
||||
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
|
||||
ElementAccumulator, //
|
||||
ElementAccumulator, //
|
||||
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
|
||||
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue
|
||||
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
|
||||
|
||||
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>;
|
||||
|
||||
using KernelSchedule
|
||||
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
|
||||
ElementType, LayoutA*, AlignmentA, //
|
||||
ElementAccumulator, //
|
||||
TileShape, ClusterShape, //
|
||||
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
// Hopper specialised version
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename TileShape, typename ClusterShape, bool BIAS>
|
||||
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
|
||||
{
|
||||
#ifdef COMPILE_HOPPER_TMA_GEMMS
|
||||
using namespace cute;
|
||||
if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
|
||||
{
|
||||
using GemmInfo
|
||||
= HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;
|
||||
|
||||
using ElementAccumulator = typename GemmInfo::ElementAccumulator;
|
||||
using ElementA = typename GemmInfo::ElementA;
|
||||
using ElementB = typename GemmInfo::ElementB;
|
||||
using ElementC = typename GemmInfo::ElementC;
|
||||
using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
|
||||
using ElementD = typename GemmInfo::ElementD;
|
||||
|
||||
using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
|
||||
using GemmKernel = typename GemmInfo::GemmKernel;
|
||||
using GemmGrouped = typename GemmInfo::GemmGrouped;
|
||||
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
|
||||
return;
|
||||
}
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = multi_processor_count;
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
if (workspace_size != nullptr)
|
||||
{
|
||||
// Make a mock problem shape with just the minimal information actually required to get the workspace size
|
||||
// This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
|
||||
// catch future cutlass updates causing silent breakages, but that is not fool proof.
|
||||
// The alternative is to wait until we have data and then dynamically allocate the workspace
|
||||
typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};
|
||||
|
||||
typename GemmGrouped::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
|
||||
*workspace_size = gemm.get_workspace_size(args);
|
||||
return;
|
||||
}
|
||||
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
TLLM_CHECK(hopper_input.stride_a);
|
||||
TLLM_CHECK(hopper_input.stride_b);
|
||||
TLLM_CHECK(hopper_input.ptr_a);
|
||||
TLLM_CHECK(hopper_input.ptr_b);
|
||||
|
||||
MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
|
||||
hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};
|
||||
|
||||
typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
|
||||
ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
|
||||
epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
// TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
|
||||
auto make_epi_args = [&]()
|
||||
{
|
||||
if constexpr (FUSION == EpilogueFusion::NONE)
|
||||
{
|
||||
auto epi_params = hopper_input.default_epilogue;
|
||||
return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
|
||||
hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
|
||||
}
|
||||
else if constexpr (FUSION == EpilogueFusion::FINALIZE)
|
||||
{
|
||||
// Parameters for fused finalize
|
||||
auto epi_params = hopper_input.fused_finalize_epilogue;
|
||||
return EpilogueArguments{
|
||||
epilogue_scalars, // Parameters to underlying epilogue
|
||||
reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
|
||||
reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
|
||||
epi_params.stride_final_output, // D (output) params
|
||||
reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
|
||||
epi_params.stride_bias, // Bias params
|
||||
epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
|
||||
epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
|
||||
epi_params.ptr_source_token_index, // Index of the source token to sum into
|
||||
epi_params.num_rows_in_final_output // Number of tokens in the output buffer
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
|
||||
}
|
||||
};
|
||||
EpilogueArguments const epilogue_params = make_epi_args();
|
||||
|
||||
typename GemmKernel::TileScheduler::Arguments scheduler_args{
|
||||
1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};
|
||||
|
||||
typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
|
||||
mainloop_params, epilogue_params, hw_info, scheduler_args};
|
||||
|
||||
size_t calculated_ws_size = gemm.get_workspace_size(args);
|
||||
TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
|
||||
"Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);
|
||||
|
||||
auto can_implement = gemm.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
|
||||
"Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
|
||||
|
||||
auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
|
||||
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize cutlass SM90 grouped gemm. Error: "
|
||||
+ std::string(cutlassGetStatusString(init_status)));
|
||||
|
||||
auto run_status = gemm.run(stream);
|
||||
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
|
||||
"Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Configuration was disabled by FAST_BUILD");
|
||||
}
|
||||
|
||||
#else // COMPILE_HOPPER_TMA_GEMMS
|
||||
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
|
||||
#endif // COMPILE_HOPPER_TMA_GEMMS
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,131 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
// Order matters here, packed_stride.hpp is missing cute and convolution includes
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
|
||||
{
|
||||
size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts;
|
||||
size_t stride_a_size = sizeof(StrideA) * num_experts;
|
||||
size_t stride_b_size = sizeof(StrideB) * num_experts;
|
||||
size_t stride_c_size = sizeof(StrideC) * num_experts;
|
||||
size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts;
|
||||
|
||||
size_t ptr_buf_size = sizeof(void*) * num_experts;
|
||||
size_t scale_buf_size = sizeof(float*) * num_experts;
|
||||
|
||||
return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size,
|
||||
ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size};
|
||||
}
|
||||
|
||||
size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
|
||||
{
|
||||
auto buffers = workspaceBuffers(num_experts);
|
||||
return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size());
|
||||
}
|
||||
|
||||
void HopperGroupedGemmInput::configureWorkspace(
|
||||
int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
|
||||
{
|
||||
auto buffers = workspaceBuffers(num_experts);
|
||||
std::array<int8_t*, 10> pointers{};
|
||||
TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
|
||||
for (int i = 0; i < buffers.size(); i++)
|
||||
{
|
||||
pointers[i] = start_ptr;
|
||||
start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
|
||||
}
|
||||
|
||||
shape_info.num_groups = num_experts;
|
||||
shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
|
||||
shape_info.host_problem_shapes = nullptr;
|
||||
stride_a = reinterpret_cast<StrideA*>(pointers[1]);
|
||||
stride_b = reinterpret_cast<StrideB*>(pointers[2]);
|
||||
stride_c = reinterpret_cast<StrideC*>(pointers[3]);
|
||||
default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);
|
||||
|
||||
ptr_a = reinterpret_cast<void const**>(pointers[5]);
|
||||
ptr_b = reinterpret_cast<void const**>(pointers[6]);
|
||||
ptr_c = reinterpret_cast<void const**>(pointers[7]);
|
||||
default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);
|
||||
|
||||
alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);
|
||||
|
||||
this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
|
||||
this->gemm_workspace_size = gemm_workspace_size;
|
||||
}
|
||||
|
||||
void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
|
||||
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
|
||||
int num_output_tokens)
|
||||
{
|
||||
fused_finalize_epilogue.ptr_final_output = final_output;
|
||||
fused_finalize_epilogue.ptr_router_scales = router_scales;
|
||||
fused_finalize_epilogue.ptr_bias = bias;
|
||||
fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
|
||||
fused_finalize_epilogue.ptr_source_token_index = source_token_index;
|
||||
|
||||
fused_finalize_epilogue.stride_final_output
|
||||
= cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
|
||||
transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
|
||||
fused_finalize_epilogue.stride_bias
|
||||
= transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
|
||||
fused_finalize_epilogue.stride_router_scales = {};
|
||||
|
||||
fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
|
||||
}
|
||||
|
||||
std::string HopperGroupedGemmInput::toString() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
|
||||
if (isValid())
|
||||
{
|
||||
ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
|
||||
ss << "Epilogue Fusion: " << (int) fusion;
|
||||
if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
|
||||
{
|
||||
ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
|
||||
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
|
||||
ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
|
||||
ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
|
||||
ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
|
||||
ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
|
||||
ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
|
||||
ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
|
||||
}
|
||||
else
|
||||
{
|
||||
ss << ", Ptr D: " << default_epilogue.ptr_d;
|
||||
}
|
||||
ss << '\n';
|
||||
ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,230 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
|
||||
#include <array>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template <class T>
|
||||
constexpr auto transpose_stride(T const& t)
|
||||
{
|
||||
return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
|
||||
}
|
||||
|
||||
struct HopperGroupedGemmInput
|
||||
{
|
||||
template <class T>
|
||||
using TransposeStride = decltype(transpose_stride<T>(T{}));
|
||||
template <class Tag>
|
||||
using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
|
||||
cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
|
||||
|
||||
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
|
||||
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
|
||||
|
||||
// Layout for A and B is transposed and then swapped in the implementation
|
||||
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
|
||||
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
|
||||
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
|
||||
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
|
||||
|
||||
using StrideA
|
||||
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
|
||||
using StrideB
|
||||
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
|
||||
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
|
||||
|
||||
template <class T>
|
||||
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
|
||||
// Currently this should always just be T
|
||||
template <class T>
|
||||
using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;
|
||||
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
|
||||
|
||||
ProblemShape shape_info{};
|
||||
StrideA* stride_a = nullptr;
|
||||
StrideB* stride_b = nullptr;
|
||||
|
||||
void const** ptr_a = nullptr;
|
||||
void const** ptr_b = nullptr;
|
||||
|
||||
// C is currently the same in both epilogues
|
||||
StrideC* stride_c = nullptr;
|
||||
void const** ptr_c = nullptr;
|
||||
|
||||
struct DefaultEpilogue
|
||||
{
|
||||
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
|
||||
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
|
||||
|
||||
StrideD* stride_d = nullptr;
|
||||
void** ptr_d = nullptr;
|
||||
};
|
||||
|
||||
struct FusedFinalizeEpilogue
|
||||
{
|
||||
using StrideFinalOutput = DefaultEpilogue::StrideD;
|
||||
using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
|
||||
using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
|
||||
|
||||
void* ptr_final_output = nullptr;
|
||||
StrideFinalOutput stride_final_output{};
|
||||
|
||||
void const* ptr_bias = nullptr;
|
||||
StrideBias stride_bias{};
|
||||
|
||||
float const* ptr_router_scales = nullptr;
|
||||
StrideRouterScales stride_router_scales{};
|
||||
|
||||
int64_t const* ptr_expert_first_token_offset = nullptr;
|
||||
int const* ptr_source_token_index = nullptr;
|
||||
|
||||
size_t num_rows_in_final_output = 0;
|
||||
};
|
||||
|
||||
DefaultEpilogue default_epilogue;
|
||||
FusedFinalizeEpilogue fused_finalize_epilogue;
|
||||
|
||||
enum class EpilogueFusion
|
||||
{
|
||||
NONE,
|
||||
ACTIVATION,
|
||||
GATED_ACTIVATION,
|
||||
FINALIZE
|
||||
};
|
||||
EpilogueFusion fusion = EpilogueFusion::NONE;
|
||||
|
||||
float const** alpha_scale_ptr_array = nullptr;
|
||||
|
||||
uint8_t* gemm_workspace = nullptr;
|
||||
size_t gemm_workspace_size = 0;
|
||||
|
||||
static std::array<size_t, 10> workspaceBuffers(int num_experts);
|
||||
|
||||
static size_t workspaceSize(int num_experts);
|
||||
|
||||
void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);
|
||||
|
||||
bool isValid() const
|
||||
{
|
||||
return stride_a != nullptr && ptr_a != nullptr;
|
||||
}
|
||||
|
||||
void setFinalizeFusionParams(void* final_output, float const* router_scales,
|
||||
int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
|
||||
int num_output_tokens);
|
||||
|
||||
std::string toString() const;
|
||||
};
|
||||
|
||||
// Note update moe.py to match
|
||||
enum class ActivationType
|
||||
{
|
||||
Gelu = 0,
|
||||
Relu,
|
||||
Silu,
|
||||
Swiglu,
|
||||
Geglu,
|
||||
Identity,
|
||||
InvalidType
|
||||
};
|
||||
|
||||
constexpr bool isGatedActivation(ActivationType activation_type)
|
||||
{
|
||||
return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
|
||||
}
|
||||
|
||||
template <typename T, /*The type used for activations/scales/compute*/
|
||||
typename WeightType, /* The type for the MoE weights */
|
||||
typename OutputType, /* The output type for the GEMM */
|
||||
typename ScaleBiasType = OutputType /* The type for the scales/bias */
|
||||
>
|
||||
class MoeGemmRunner
|
||||
{
|
||||
public:
|
||||
MoeGemmRunner();
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
#else
|
||||
static constexpr bool use_fp8 = false;
|
||||
#endif
|
||||
|
||||
void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
|
||||
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
|
||||
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig chosen_conf);
|
||||
|
||||
void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
|
||||
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
|
||||
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
|
||||
|
||||
[[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
|
||||
[[nodiscard]] bool supportsHopperSpecialisation() const;
|
||||
[[nodiscard]] bool isFusedGatedActivation(
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
|
||||
[[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;
|
||||
|
||||
size_t getMaxWorkspaceSize(int num_experts) const;
|
||||
|
||||
[[nodiscard]] int getSM() const;
|
||||
|
||||
private:
|
||||
template <typename EpilogueTag>
|
||||
void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
|
||||
ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
|
||||
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
|
||||
cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
template <typename EpilogueTag>
|
||||
void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
|
||||
bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
|
||||
HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig chosen_conf);
|
||||
|
||||
private:
|
||||
int sm_{};
|
||||
int multi_processor_count_{};
|
||||
mutable int num_experts_ = 0;
|
||||
mutable size_t gemm_workspace_size_ = 0;
|
||||
size_t calcMaxWorkspaceSize(int num_experts) const;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,24 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,24 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,24 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_BF16
|
||||
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,22 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<half, half, half>;
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<half, uint8_t, half>;
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<float, float, float>;
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
|
||||
#ifdef ENABLE_BF16
|
||||
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
|
||||
#endif
|
||||
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,823 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#ifdef __GNUC__ // Restore GCC-specific diagnostics
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "moe_gemm_kernels_template_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
#include <tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// ============================= Variable batched Gemm things ===========================
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
|
||||
{
|
||||
#if defined(ENABLE_FP8)
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
|
||||
|| cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for fp8, bfloat16, half, float");
|
||||
#elif defined(ENABLE_BF16)
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for bfloat16, half, float");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for half, float");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
|
||||
static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
|
||||
"Sm90 architecture should use specialised kernels");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
|
||||
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
if (!use_fused_moe)
|
||||
{
|
||||
// We need separate config for each architecture since we will target different tensorcore instructions. For
|
||||
// float, we do not target TCs.
|
||||
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
||||
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
||||
|
||||
using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
|
||||
MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
|
||||
|
||||
typename EpilogueOp::Params epilogue_op(
|
||||
ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
if constexpr ((std::is_same_v<T, __nv_fp8_e4m3>
|
||||
|| std::is_same_v<T, __nv_fp8_e5m2>) &&std::is_same_v<EpilogueTag,
|
||||
cutlass_extensions::EpilogueOpDefault>)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
|
||||
"weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
|
||||
"Ada");
|
||||
epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Finally, set up the kernel.
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
|
||||
typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
|
||||
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
|
||||
typename GemmKernel_::ThreadblockSwizzle,
|
||||
arch, // Ensure top level arch is used for dispatch
|
||||
GemmKernel_::kGroupScheduleMode>;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
|
||||
return;
|
||||
}
|
||||
int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
|
||||
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
|
||||
int const threadblock_count = multi_processor_count * occupancy;
|
||||
|
||||
int const group_size = gemm_k;
|
||||
typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
|
||||
reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
|
||||
reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
|
||||
reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
|
||||
reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
auto can_implement = gemm.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
|
||||
"MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
|
||||
|
||||
auto init_status = gemm.initialize(args);
|
||||
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));
|
||||
|
||||
auto run_status = gemm.run(stream);
|
||||
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
|
||||
"Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
}
|
||||
else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
|
||||
&& (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
|
||||
|| std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>) ) // use fused moe gemm
|
||||
// kernel.. (only support
|
||||
// fp16 or bf16)
|
||||
{
|
||||
sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
|
||||
ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(reinterpret_cast<ElementType const*>(A),
|
||||
reinterpret_cast<CutlassWeightType const*>(B), reinterpret_cast<ElementType const*>(biases),
|
||||
bias_is_broadcast, reinterpret_cast<ElementType*>(C), total_tokens_including_expert, num_rows, gemm_n,
|
||||
gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels::cutlass_kernels
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename Arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases,
|
||||
bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
|
||||
#if defined(ENABLE_FP8)
|
||||
constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
#else
|
||||
constexpr bool isFp8 = false;
|
||||
#endif
|
||||
|
||||
if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80)
|
||||
&& (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>) )
|
||||
{
|
||||
kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch, EpilogueTag,
|
||||
ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW(
|
||||
"Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape>
|
||||
void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.stages)
|
||||
{
|
||||
case 2:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case 3:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case 4:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
|
||||
// This overload is only enabled when T == WeightType.
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value
|
||||
#if defined(ENABLE_FP8)
|
||||
&& !std::is_same<T, __nv_fp8_e4m3>::value && !std::is_same<T, __nv_fp8_e5m2>::value
|
||||
#endif
|
||||
&& std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
|
||||
// Tensorop GEMM overload
|
||||
// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve
|
||||
// compile time
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload will handle tensorop gemms.
|
||||
// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2
|
||||
#if defined(ENABLE_FP8)
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<(std::is_same<T, __nv_fp8_e4m3>::value || std::is_same<T, __nv_fp8_e5m2>::value)
|
||||
&& std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
|
||||
cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Unsupported config for float MoE gemm."); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
|
||||
{
|
||||
return getConfigs(sm_);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
|
||||
int sm)
|
||||
{
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getHopperConfigs(sm);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
|
||||
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
|
||||
|
||||
return candidate_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm)
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::NONE;
|
||||
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
|
||||
return ampere_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getHopperConfigs(int sm)
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::HOPPER;
|
||||
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
|
||||
return hopper_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isHopperSpecialised(
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config) const
|
||||
{
|
||||
bool config_is_sm90 = gemm_config.is_sm90;
|
||||
return supportsHopperSpecialisation() && config_is_sm90;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsHopperSpecialisation() const
|
||||
{
|
||||
return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const
|
||||
{
|
||||
return this->sm_;
|
||||
}
|
||||
|
||||
// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation(
|
||||
bool is_gated_activation, int gemm_n, int gemm_k) const
|
||||
{
|
||||
constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true;
|
||||
return is_gated_activation && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8
|
||||
&& (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation(
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const
|
||||
{
|
||||
return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::MoeGemmRunner()
|
||||
{
|
||||
int device{-1};
|
||||
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
|
||||
sm_ = tensorrt_llm::common::getSMVersion();
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
|
||||
WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast,
|
||||
void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
static_assert(std::is_same_v<ScaleBiasType, OutputType>,
|
||||
"Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
|
||||
|
||||
// For now we always cast this to output type.
|
||||
// In the future this will vary based on what fusions are applied for FP8
|
||||
auto* C = reinterpret_cast<OutputType*>(C_void);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
|
||||
|
||||
if (sm_ >= 75 && sm_ < 80)
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 80 && sm_ < 90)
|
||||
{
|
||||
if constexpr (use_fp8)
|
||||
{
|
||||
#if defined(ENABLE_FP8)
|
||||
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
|
||||
"FP8 GEMM Output not supported");
|
||||
#endif
|
||||
|
||||
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
}
|
||||
else if (sm_ >= 90)
|
||||
{
|
||||
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
|
||||
// We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
|
||||
// SM80 is faster. We check here to see which is selected
|
||||
if (gemm_config.is_sm90)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
|
||||
"Input biases and hopper input disagree if bias is enabled");
|
||||
TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");
|
||||
|
||||
// Select the appropriate fusion function
|
||||
auto select_function = [&]()
|
||||
{
|
||||
switch (hopper_input.fusion)
|
||||
{
|
||||
case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
|
||||
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
|
||||
case HopperGroupedGemmInput::EpilogueFusion::NONE:
|
||||
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion::NONE>;
|
||||
case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
|
||||
case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
|
||||
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
|
||||
};
|
||||
};
|
||||
auto selected_func = select_function();
|
||||
selected_func(
|
||||
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallthrough to SM80 impl below
|
||||
}
|
||||
|
||||
// Do Ampere case instead
|
||||
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
|
||||
"Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
|
||||
"information is not required");
|
||||
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
|
||||
"GEMM config is for SM90 configuration, but this configuration is not valid for Hppper");
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Arch unsupported for MoE GEMM");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(int num_experts) const
|
||||
{
|
||||
if (num_experts != num_experts_)
|
||||
{
|
||||
TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
|
||||
num_experts_ = num_experts;
|
||||
gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
|
||||
}
|
||||
return gemm_workspace_size_;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const
|
||||
{
|
||||
if (!supportsHopperSpecialisation())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
auto configs = getHopperConfigs(sm_);
|
||||
size_t max_size = 0;
|
||||
bool has_config = false;
|
||||
for (auto conf : configs)
|
||||
{
|
||||
#define CALC_SIZE_FUSION(FUSION) \
|
||||
do \
|
||||
{ \
|
||||
try \
|
||||
{ \
|
||||
size_t size = calcMaxWorkspaceSizeSM90<T, WeightType, OutputType, FUSION>( \
|
||||
num_experts, conf, multi_processor_count_); \
|
||||
max_size = std::max(max_size, size); \
|
||||
has_config = true; \
|
||||
} \
|
||||
catch (tensorrt_llm::common::TllmException const& e) \
|
||||
{ \
|
||||
TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE);
|
||||
CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE);
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
|
||||
return max_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::runGemm(T const* A, WeightType const* B,
|
||||
ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
|
||||
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
|
||||
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
|
||||
{
|
||||
dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array,
|
||||
stream, nullptr);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(T const* A, WeightType const* B,
|
||||
ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
|
||||
int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
|
||||
{
|
||||
switch (activation_type)
|
||||
{
|
||||
case ActivationType::Relu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::Gelu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::Silu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::Identity:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::Swiglu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::Geglu:
|
||||
runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
|
||||
alpha_scale_ptr_array, stream, chosen_conf);
|
||||
break;
|
||||
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
|
||||
default: TLLM_THROW("Invalid activation type."); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemm(T const* A, WeightType const* B,
|
||||
ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert,
|
||||
HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
|
||||
cutlass_extensions::CutlassGemmConfig chosen_conf)
|
||||
{
|
||||
runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert,
|
||||
hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
chosen_conf);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,222 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic pop
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename TileShape, typename ClusterShape>
|
||||
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
|
||||
cudaStream_t stream, int* occupancy, size_t* workspace_size)
|
||||
{
|
||||
static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
|
||||
"Invalid hopper configuration invoked, fallback to Sm80");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
|
||||
|
||||
// auto func = hopper_input.ptr_c ?
|
||||
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
|
||||
// cutlass::arch::Sm90, EpilogueTag, true>
|
||||
// :
|
||||
// kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
|
||||
// WeightType,
|
||||
// cutlass::arch::Sm90, EpilogueTag, false>;
|
||||
// TODO(dastokes) Re-enable bias when CUTLASS supports it
|
||||
auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag,
|
||||
FUSION, TileShape, ClusterShape, false>;
|
||||
func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
|
||||
}
|
||||
|
||||
/*
|
||||
1x1x1 cluster shape is are supported for any tile shape.
|
||||
|
||||
2x1x1 cluster shape is only supported for when the M tile is at least 128.
|
||||
|
||||
1x2x1 cluster shape is only supported when the N tile is at least 128.
|
||||
|
||||
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
|
||||
|
||||
We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels
|
||||
that may not be very useful in practice.
|
||||
*/
|
||||
template <typename CTAShape, typename ClusterShape>
|
||||
constexpr bool are_tile_shapes_supported()
|
||||
{
|
||||
using namespace cute;
|
||||
[[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
|
||||
[[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
|
||||
constexpr int cga_m = get<0>(ClusterShape{});
|
||||
constexpr int cga_n = get<1>(ClusterShape{});
|
||||
|
||||
if constexpr (cga_m == _1{} && cga_n == _1{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename TileShape>
|
||||
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
|
||||
size_t* workspace_size)
|
||||
{
|
||||
using namespace cute;
|
||||
switch (gemm_config.cluster_shape)
|
||||
{
|
||||
#define SHAPE_CASE(M, N, K) \
|
||||
case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \
|
||||
{ \
|
||||
using ClusterShape = Shape<_##M, _##N, _##K>; \
|
||||
if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>()) \
|
||||
{ \
|
||||
dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>( \
|
||||
hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \
|
||||
break; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TLLM_THROW("Unsupported tile and cluster shape combination"); \
|
||||
} \
|
||||
}
|
||||
|
||||
SHAPE_CASE(1, 1, 1)
|
||||
SHAPE_CASE(1, 2, 1)
|
||||
|
||||
SHAPE_CASE(2, 1, 1)
|
||||
SHAPE_CASE(2, 2, 1)
|
||||
|
||||
#undef SHAPE_CASE
|
||||
default: TLLM_THROW("Unsupported config for MoE gemm.");
|
||||
}
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
|
||||
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
|
||||
size_t* workspace_size)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
switch (gemm_config.tile_config_sm90)
|
||||
{
|
||||
#define SHAPE_CASE(M, N, K) \
|
||||
case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \
|
||||
{ \
|
||||
constexpr int KtileBytes = K / sizeof(T); \
|
||||
using KTileDim = Int<KtileBytes>; \
|
||||
using TileShape = Shape<_##M, _##N, KTileDim>; \
|
||||
dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>( \
|
||||
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
SHAPE_CASE(128, 16, 128)
|
||||
SHAPE_CASE(128, 32, 128)
|
||||
SHAPE_CASE(128, 64, 128)
|
||||
SHAPE_CASE(128, 128, 128)
|
||||
SHAPE_CASE(128, 256, 128)
|
||||
SHAPE_CASE(256, 128, 128)
|
||||
|
||||
#undef SHAPE_CASE
|
||||
case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Unsupported config for MoE gemm."); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
|
||||
size_t calcMaxWorkspaceSizeSM90(
|
||||
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
|
||||
{
|
||||
size_t count;
|
||||
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
|
||||
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
|
||||
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
|
||||
return count;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
@@ -1,44 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// Hopper arch
|
||||
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
|
||||
constexpr bool isValidHopperMOESpecialisation()
|
||||
{
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
return cutlass::platform::is_same<T, WeightType>::value
|
||||
&& cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
|
||||
#else
|
||||
return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
|
||||
#endif
|
||||
}
|
||||
|
||||
// Hopper arch
|
||||
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
|
||||
constexpr bool isValidAmpereMOESpecialisation()
|
||||
{
|
||||
return true; // Default to true
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
Reference in New Issue
Block a user