add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)
Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
1
.clang-format-ignore
Normal file
1
.clang-format-ignore
Normal file
@@ -0,0 +1 @@
|
||||
sgl-kernel/3rdparty/tensorrt_llm/*
|
||||
22
sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt
vendored
Normal file
22
sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
#
|
||||
file(GLOB SRCS *.cpp)
|
||||
file(GLOB CU_SRCS *.cu)
|
||||
|
||||
add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
|
||||
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
34
sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp
vendored
Executable file
34
sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp
vendored
Executable file
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
bool initCheckDebug()
|
||||
{
|
||||
auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE";
|
||||
auto const debugEnabled = std::getenv(kDebugEnabled);
|
||||
return debugEnabled && debugEnabled[0] == '1';
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool DebugConfig::isCheckDebugEnabled()
|
||||
{
|
||||
static bool const debugEnabled = initCheckDebug();
|
||||
return debugEnabled;
|
||||
}
|
||||
360
sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp
vendored
Normal file
360
sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp
vendored
Normal file
@@ -0,0 +1,360 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cublasVersionCheck.h"
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef CUDART_VERSION
|
||||
#error CUDART_VERSION Undefined!
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
|
||||
: mCublasHandle(cublasHandle)
|
||||
, mCublasLtHandle(cublasltHandle)
|
||||
, mStream(stream)
|
||||
, mCublasWorkspace(workspace)
|
||||
{
|
||||
}
|
||||
|
||||
CublasMMWrapper::~CublasMMWrapper() {}
|
||||
|
||||
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
|
||||
: mCublasHandle(wrapper.mCublasHandle)
|
||||
, mCublasLtHandle(wrapper.mCublasLtHandle)
|
||||
, mStream(wrapper.mStream)
|
||||
{
|
||||
}
|
||||
|
||||
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
|
||||
{
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
|
||||
check_cuda_error(
|
||||
cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
|
||||
check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
|
||||
check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(cublasLtMatmulDescSetAttribute(
|
||||
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
|
||||
{
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
|
||||
check_cuda_error(
|
||||
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::destroyDescriptors()
|
||||
{
|
||||
check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
|
||||
check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
|
||||
mOperationDesc = NULL;
|
||||
mADesc = NULL;
|
||||
mBDesc = NULL;
|
||||
mCDesc = NULL;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo,
|
||||
(*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
else
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ true);
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
|
||||
{
|
||||
bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
|
||||
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
|
||||
/* usingCublasLt */ usingCublasLt);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
|
||||
{
|
||||
half h_alpha = (half) (f_alpha);
|
||||
half h_beta = (half) (f_beta);
|
||||
|
||||
// TODO: default cublas libs
|
||||
usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
|
||||
int batch_count = 1;
|
||||
// fp32 use cublas as default
|
||||
// fp16 use cublasLt as default
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
if (usingCublasLt)
|
||||
{
|
||||
if (hasAlgo)
|
||||
{
|
||||
hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
|
||||
}
|
||||
|
||||
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
|
||||
mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
|
||||
check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
|
||||
// Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+
|
||||
cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
|
||||
check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
|
||||
beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
|
||||
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
|
||||
float const f_beta)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
|
||||
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
|
||||
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
|
||||
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
|
||||
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setWorkspace(void* workspace)
|
||||
{
|
||||
mCublasWorkspace = workspace;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setFP32GemmConfig()
|
||||
{
|
||||
setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
|
||||
{
|
||||
setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
|
||||
}
|
||||
#endif
|
||||
|
||||
void CublasMMWrapper::setGemmConfig(
|
||||
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
|
||||
{
|
||||
mAType = aType;
|
||||
mBType = bType;
|
||||
mCType = cType;
|
||||
bool isFp16ComputeType = computeType == CUDA_R_16F;
|
||||
if (isFp16ComputeType)
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_16F;
|
||||
mScaleType = CUDA_R_16F;
|
||||
}
|
||||
else
|
||||
{
|
||||
mComputeType = CUBLAS_COMPUTE_32F;
|
||||
mScaleType = CUDA_R_32F;
|
||||
}
|
||||
}
|
||||
|
||||
CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
|
||||
{
|
||||
if (data_type == CUDA_R_16F)
|
||||
{
|
||||
return HALF_DATATYPE;
|
||||
}
|
||||
else if (data_type == CUDA_R_32F)
|
||||
{
|
||||
return FLOAT_DATATYPE;
|
||||
}
|
||||
else if (data_type == CUDA_R_8I)
|
||||
{
|
||||
return INT8_DATATYPE;
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (data_type == CUDA_R_16BF)
|
||||
{
|
||||
return BFLOAT16_DATATYPE;
|
||||
}
|
||||
#endif
|
||||
return FLOAT_DATATYPE;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::setStream(cudaStream_t stream)
|
||||
{
|
||||
mStream = stream;
|
||||
}
|
||||
|
||||
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heurResult;
|
||||
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
|
||||
getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
|
||||
|
||||
if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
|
||||
|| heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
|
||||
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
return heuristics;
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
|
||||
{
|
||||
#if TLLM_CUBLAS_VER_LE(11, 4, 2)
|
||||
TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
|
||||
return {};
|
||||
#else
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
|
||||
cublasLtMatmulPreference_t preference;
|
||||
check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
|
||||
check_cuda_error(cublasLtMatmulPreferenceInit(preference));
|
||||
uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
|
||||
// Restrict reduction algorithms for numerical stability and better determinism
|
||||
uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
|
||||
#if TLLM_CUBLAS_VER_LT(12, 0, 0)
|
||||
uint32_t pointer_mode_mask = 0;
|
||||
check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
|
||||
preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
|
||||
#endif
|
||||
|
||||
int return_count = 0;
|
||||
check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
|
||||
heuristics.size(), heuristics.data(), &return_count));
|
||||
heuristics.resize(return_count);
|
||||
|
||||
return heuristics;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
148
sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h
vendored
Normal file
148
sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
class CublasMMWrapper
|
||||
{
|
||||
protected:
|
||||
std::shared_ptr<cublasHandle_t> mCublasHandle;
|
||||
std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
|
||||
|
||||
cudaDataType_t mAType{};
|
||||
cudaDataType_t mBType{};
|
||||
cudaDataType_t mCType{};
|
||||
cublasComputeType_t mComputeType{};
|
||||
cudaDataType_t mScaleType{};
|
||||
|
||||
cublasLtMatmulDesc_t mOperationDesc{NULL};
|
||||
cublasLtMatrixLayout_t mADesc{NULL};
|
||||
cublasLtMatrixLayout_t mBDesc{NULL};
|
||||
cublasLtMatrixLayout_t mCDesc{NULL};
|
||||
|
||||
cudaStream_t mStream;
|
||||
|
||||
void* mCublasWorkspace = nullptr;
|
||||
|
||||
private:
|
||||
bool descriptorsCreated() const
|
||||
{
|
||||
return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
|
||||
}
|
||||
|
||||
public:
|
||||
CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
|
||||
cudaStream_t stream, void* workspace);
|
||||
|
||||
~CublasMMWrapper();
|
||||
|
||||
CublasMMWrapper(CublasMMWrapper const& wrapper);
|
||||
|
||||
/********************** GEMMs **********************/
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
|
||||
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
|
||||
float const f_beta = 0.0f);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
|
||||
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
|
||||
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
|
||||
|
||||
/********************** Tactic selection helpers **********************/
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);
|
||||
|
||||
using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
|
||||
using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;
|
||||
|
||||
MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);
|
||||
|
||||
/********************** Utils **********************/
|
||||
void setWorkspace(void* workspace);
|
||||
|
||||
void setFP32GemmConfig();
|
||||
void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
|
||||
#ifdef ENABLE_BF16
|
||||
void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
|
||||
#endif
|
||||
|
||||
void setStream(cudaStream_t stream);
|
||||
|
||||
void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);
|
||||
|
||||
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
||||
|
||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
|
||||
void setScaleDescriptors(void* scale_a, void* scale_b);
|
||||
void destroyDescriptors();
|
||||
|
||||
cublasHandle_t getCublasHandle()
|
||||
{
|
||||
return *(this->mCublasHandle);
|
||||
}
|
||||
|
||||
cublasLtHandle_t getCublasLtHandle() const
|
||||
{
|
||||
return *(this->mCublasLtHandle);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
35
sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h
vendored
Normal file
35
sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro
|
||||
// definition which is not sufficient to determine if we include cublas.h,
|
||||
// cublas_v2.h or cublasLt.h.
|
||||
|
||||
#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH)
|
||||
#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
<= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
< TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
>= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \
|
||||
TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \
|
||||
> TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH)
|
||||
313
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh
vendored
Normal file
313
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float2 f_val;
|
||||
f_val.x = __low2float(val);
|
||||
f_val.y = __high2float(val);
|
||||
return f_val;
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float2 f_val;
|
||||
f_val.x = max(min(__low2float(val), 127.f), -128.f);
|
||||
f_val.y = max(min(__high2float(val), 127.f), -128.f);
|
||||
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
|
||||
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
|
||||
return int16;
|
||||
#else
|
||||
val = __hmin2(val, make_bfloat162(127., 127.));
|
||||
val = __hmax2(val, make_bfloat162(-128., -128.));
|
||||
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
|
||||
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
|
||||
return int16;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __floats2bfloat162_rn(val.x, val.y);
|
||||
#else
|
||||
return __float22bfloat162_rn(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
__nv_bfloat162 val2;
|
||||
val2.x = val;
|
||||
val2.y = val;
|
||||
return val2;
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
|
||||
#else
|
||||
return __hadd2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
|
||||
#else
|
||||
return __hadd(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
|
||||
#else
|
||||
return __hsub2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
|
||||
#else
|
||||
return __hsub(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
|
||||
#else
|
||||
return __hmul2(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
|
||||
#else
|
||||
return __hmul(x, y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh, fyl, fyh, fzl, fzh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
fyl = __low2float(y);
|
||||
fyh = __high2float(y);
|
||||
fzl = __low2float(z);
|
||||
fzh = __high2float(z);
|
||||
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
|
||||
#else
|
||||
return __hfma2(x, y, z);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
|
||||
#else
|
||||
return __hfma(x, y, z);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fxl, fxh;
|
||||
fxl = __low2float(x);
|
||||
fxh = __high2float(x);
|
||||
;
|
||||
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
|
||||
#else
|
||||
return h2exp(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
|
||||
|
||||
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
||||
{
|
||||
__nv_bfloat162 t;
|
||||
t.x = x;
|
||||
t.y = y;
|
||||
return t;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
|
||||
#else
|
||||
return a + b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
|
||||
#else
|
||||
return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
|
||||
#else
|
||||
return a + b + c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
|
||||
#else
|
||||
return a * b * c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
|
||||
#else
|
||||
return a * b * c;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
|
||||
fal = __low2float(a);
|
||||
fah = __high2float(a);
|
||||
fbl = __low2float(b);
|
||||
fbh = __high2float(b);
|
||||
fcl = __low2float(c);
|
||||
fch = __high2float(c);
|
||||
fdl = __low2float(d);
|
||||
fdh = __high2float(d);
|
||||
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
|
||||
#else
|
||||
return a * b * c + d;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
// Operator definitions intentionally in global namespace
|
||||
namespace
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
|
||||
|
||||
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
return tensorrt_llm::common::bf16hmul2(x, y);
|
||||
};
|
||||
|
||||
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
|
||||
{
|
||||
return tensorrt_llm::common::bf16hadd2(x, y);
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
} // namespace
|
||||
187
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp
vendored
Normal file
187
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp
vendored
Normal file
@@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define CUDA_LIB_NAME "cuda"
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#define dllOpen(name) LoadLibrary("nv" name ".dll")
|
||||
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
|
||||
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
|
||||
#else // For non-Windows platforms
|
||||
#include <dlfcn.h>
|
||||
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
|
||||
#define dllClose(handle) dlclose(handle)
|
||||
#define dllGetSym(handle, name) dlsym(handle, name)
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include "cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
{
|
||||
static std::mutex mutex;
|
||||
static std::weak_ptr<CUDADriverWrapper> instance;
|
||||
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
|
||||
if (result)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUDADriverWrapper::CUDADriverWrapper()
|
||||
: handle(dllOpen(CUDA_LIB_NAME))
|
||||
{
|
||||
|
||||
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
|
||||
|
||||
auto load_sym = [](void* handle, char const* name)
|
||||
{
|
||||
void* ret = dllGetSym(handle, name);
|
||||
return ret;
|
||||
};
|
||||
|
||||
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
|
||||
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
|
||||
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
|
||||
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
|
||||
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
|
||||
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
|
||||
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
{
|
||||
dllClose(handle);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorName)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorMessage)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
|
||||
{
|
||||
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
|
||||
{
|
||||
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
|
||||
{
|
||||
return (*_cuModuleUnload)(hmod);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
|
||||
{
|
||||
return (*_cuLinkDestroy)(state);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
|
||||
{
|
||||
return (*_cuModuleLoadData)(module, image);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
|
||||
{
|
||||
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetFunction)(hfunc, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
|
||||
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
|
||||
{
|
||||
return (*_cuLaunchCooperativeKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
|
||||
{
|
||||
return (*_cuLaunchKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
|
||||
{
|
||||
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
|
||||
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
|
||||
{
|
||||
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
138
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h
vendored
Normal file
138
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CUDA_DRIVER_WRAPPER_H
|
||||
#define CUDA_DRIVER_WRAPPER_H
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class CUDADriverWrapper
|
||||
{
|
||||
public:
|
||||
static std::shared_ptr<CUDADriverWrapper> getInstance();
|
||||
|
||||
~CUDADriverWrapper();
|
||||
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
|
||||
|
||||
CUresult cuModuleUnload(CUmodule hmod) const;
|
||||
|
||||
CUresult cuLinkDestroy(CUlinkState state) const;
|
||||
|
||||
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
|
||||
|
||||
CUresult cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
|
||||
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
|
||||
CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
|
||||
|
||||
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra) const;
|
||||
|
||||
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
|
||||
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
|
||||
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
|
||||
|
||||
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUDADriverWrapper();
|
||||
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
CUresult (*_cuLinkDestroy)(CUlinkState);
|
||||
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLinkAddData)(
|
||||
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, CUstream, void**);
|
||||
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra);
|
||||
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
|
||||
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void checkDriver(
|
||||
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
char const* errorName = nullptr;
|
||||
char const* errorMsg = nullptr;
|
||||
wrap.cuGetErrorName(result, &errorName);
|
||||
wrap.cuGetErrorMessage(result, &errorMsg);
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
/*
|
||||
* Macros compliant with TensorRT coding conventions
|
||||
*/
|
||||
#define TLLM_CU_CHECK(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::checkDriver( \
|
||||
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
#endif // CUDA_DRIVER_WRAPPER_H
|
||||
436
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu
vendored
Normal file
436
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu
vendored
Normal file
@@ -0,0 +1,436 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cuda_fp16.h>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
constexpr int CTA_SIZE = 256;
|
||||
|
||||
template <bool QUANTIZE>
|
||||
__inline__ __device__ float scale(float a, float b)
|
||||
{
|
||||
return QUANTIZE ? a / b : a * b;
|
||||
}
|
||||
|
||||
template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_S, typename T_IN>
|
||||
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
|
||||
{
|
||||
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
|
||||
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i % lda])));
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[i / lda])));
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
output[i] = T_OUT(scale<QUANTIZE>(static_cast<float>(input[i]), static_cast<float>(input_scale[0])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(1024);
|
||||
dim3 block(CTA_SIZE);
|
||||
if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_CHANNEL, true>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TOKEN, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TENSOR, true><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(1024);
|
||||
dim3 block(CTA_SIZE);
|
||||
if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_CHANNEL, false>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TOKEN, false><<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
scaleMatrix<QuantizeMode::PER_TENSOR, false>
|
||||
<<<grid, block, 0, stream>>>(output, input_scale, input, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template <typename T_FAKE, typename T_OUT, typename T_IN>
|
||||
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel)
|
||||
{
|
||||
for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
T_FAKE tmp = (T_FAKE) (static_cast<float>(src[tid]));
|
||||
dst[tid] = (T_OUT) (static_cast<float>(tmp));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_FAKE, typename T_OUT, typename T_IN>
|
||||
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream)
|
||||
{
|
||||
fakeQuantize<T_FAKE><<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
|
||||
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
|
||||
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
|
||||
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
|
||||
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
template void invokeFakeQuantize<float, half, float>(
|
||||
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
__device__ float atomicMaxExtd(float* address, float val)
|
||||
{
|
||||
assert(val >= 0);
|
||||
unsigned int* address_as_u = reinterpret_cast<unsigned int*>(address);
|
||||
unsigned int old = atomicMax(address_as_u, __float_as_uint(val));
|
||||
return __uint_as_float(old);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T atomicMaxExtdV2(T* address, T val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
|
||||
// The address in 64 bits.
|
||||
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
|
||||
|
||||
// Pack the input value into 32 bits.
|
||||
union
|
||||
{
|
||||
T v[2];
|
||||
uint16_t u[2];
|
||||
} old, tmp = {};
|
||||
|
||||
int const loc = (address_u64 & 0x2) >> 1;
|
||||
tmp.v[loc] = val;
|
||||
|
||||
// 4B aligned pointer.
|
||||
auto aligned_address = reinterpret_cast<T*>(address_u64 & ~0x3ull);
|
||||
|
||||
if constexpr (std::is_same_v<T, half>)
|
||||
{
|
||||
asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};"
|
||||
: "=h"(old.u[0]), "=h"(old.u[1])
|
||||
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
|
||||
}
|
||||
if constexpr (std::is_same_v<T, __nv_bfloat16>)
|
||||
{
|
||||
asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};"
|
||||
: "=h"(old.u[0]), "=h"(old.u[1])
|
||||
: "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1]));
|
||||
}
|
||||
|
||||
// Return the correct half.
|
||||
return old.v[loc];
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ half atomicMaxExtd(half* address, half val)
|
||||
{
|
||||
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
|
||||
unsigned short int old = *address_as_u, assumed;
|
||||
|
||||
while (val > __ushort_as_half(old))
|
||||
{
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_u, assumed, __half_as_ushort(val));
|
||||
}
|
||||
|
||||
return __ushort_as_half(old);
|
||||
}
|
||||
|
||||
__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
unsigned short int* address_as_u = reinterpret_cast<unsigned short int*>(address);
|
||||
unsigned short int old = *address_as_u, assumed;
|
||||
|
||||
while (val > __ushort_as_bfloat16(old))
|
||||
{
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val));
|
||||
}
|
||||
|
||||
return __ushort_as_bfloat16(old);
|
||||
#else
|
||||
assert(0);
|
||||
asm volatile("brkpt;\n" ::);
|
||||
return __nv_bfloat16(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <QuantizeMode QUANTIZE_MODE, typename T_S, typename T_W>
|
||||
__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n)
|
||||
{
|
||||
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
for (int64_t col = threadIdx.x; col < n; col += blockDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
if constexpr (std::is_same_v<T_S, float>)
|
||||
{
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
|
||||
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
else
|
||||
atomicMaxExtdV2(quant_ptr + col, scale);
|
||||
}
|
||||
#else // Vector atomics require __CUDA_ARCH__ >= 900
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
auto const nrows = size / n;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x; i < n; i += blockDim.x)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[row * n + i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
quant_ptr[row] = scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x)
|
||||
{
|
||||
auto val = fabs(static_cast<float>(weights[i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
atomicMaxExtd(quant_ptr, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_S, typename T_W>
|
||||
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda,
|
||||
QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid(numel / lda);
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
|
||||
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(1024);
|
||||
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, weights, numel, lda);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \
|
||||
template void invokeComputeFP8QuantizeScale<type_scale, type_in>(type_scale * input_scale, type_in const* weights, \
|
||||
int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);
|
||||
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float);
|
||||
#ifdef ENABLE_BF16
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16);
|
||||
DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16);
|
||||
#endif
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
__global__ void dynamicQuantizeMatrixPerToken(
|
||||
T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda)
|
||||
{
|
||||
extern __shared__ __align__(sizeof(float)) char _shmem[];
|
||||
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
|
||||
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
auto const nrows = numel / lda;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
auto const in = input[row * lda + i];
|
||||
shmem[i] = in;
|
||||
auto val = fabs(static_cast<float>(in));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
|
||||
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
// true means we are quantizing
|
||||
output[row * lda + i] = (T_OUT) scale<true>(static_cast<float>(shmem[i]), static_cast<float>(s));
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
quant_ptr[row] = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_S, typename T_IN>
|
||||
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel,
|
||||
const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream)
|
||||
{
|
||||
if (quantize_mode == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
dim3 grid(numel / lda);
|
||||
bool use_shmem = true;
|
||||
auto const shmem_size = lda * sizeof(T_IN);
|
||||
if (shmem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
|
||||
use_shmem = ret == cudaSuccess;
|
||||
}
|
||||
if (use_shmem)
|
||||
{
|
||||
// ensure the threadblock is as large as possible to increase occupancy
|
||||
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
|
||||
dynamicQuantizeMatrixPerToken<<<grid, block, shmem_size, stream>>>(output, quant_ptr, input, numel, lda);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TOKEN><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_CHANNEL)
|
||||
{
|
||||
dim3 block(CTA_SIZE);
|
||||
dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE);
|
||||
cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_CHANNEL><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
else if (quantize_mode == QuantizeMode::PER_TENSOR)
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(1024);
|
||||
cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream);
|
||||
sync_check_cuda_error();
|
||||
computeFP8QuantizeScale<QuantizeMode::PER_TENSOR><<<grid, block, 0, stream>>>(quant_ptr, input, numel, lda);
|
||||
sync_check_cuda_error();
|
||||
invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \
|
||||
template void invokeQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream); \
|
||||
template void invokeDequantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream); \
|
||||
template void invokeComputeScalesAndQuantizeMatrix<type_out, type_scale, type_in>(type_out * output, \
|
||||
type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \
|
||||
cudaStream_t stream);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3);
|
||||
#ifdef ENABLE_BF16
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16);
|
||||
DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
84
sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp
vendored
Normal file
84
sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaProfilerUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
|
||||
std::string const& envVarName)
|
||||
{
|
||||
auto envVarVal = std::getenv(envVarName.c_str());
|
||||
auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
|
||||
auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
|
||||
std::unordered_set<int32_t> startSet;
|
||||
std::unordered_set<int32_t> endSet;
|
||||
for (std::string const& value : values)
|
||||
{
|
||||
size_t dashIdx = value.find("-");
|
||||
if (dashIdx != std::string::npos)
|
||||
{
|
||||
int32_t start = std::stoi(value.substr(0, dashIdx));
|
||||
startSet.insert(start);
|
||||
int32_t end = std::stoi(value.substr(dashIdx + 1));
|
||||
endSet.insert(end);
|
||||
}
|
||||
else
|
||||
{
|
||||
int32_t start_end = std::stoi(value);
|
||||
startSet.insert(start_end);
|
||||
endSet.insert(start_end);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(startSet, endSet);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
|
||||
std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
|
||||
{
|
||||
auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);
|
||||
|
||||
// If empty, try to use legacy env var name
|
||||
if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
|
||||
{
|
||||
std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());
|
||||
|
||||
if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
|
||||
"Please "
|
||||
"use %s "
|
||||
"instead.",
|
||||
legacyEnvVarName.value().c_str(), envVarName.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(profileIterIdxs, stopIterIdxs);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
752
sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh
vendored
Normal file
752
sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh
vendored
Normal file
@@ -0,0 +1,752 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include <assert.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#if ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T ldg(T const* val)
|
||||
{
|
||||
return __ldg(val);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
#else
|
||||
return __ldg(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
#else
|
||||
return __ldg(val);
|
||||
#endif
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter
|
||||
{
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2>
|
||||
{
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half>
|
||||
{
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162>
|
||||
{
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16>
|
||||
{
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Defined math operations (bfloat16 fallback to fp32 when it is not supported)
|
||||
template <typename T>
|
||||
inline __device__ T hadd2(T a, T b)
|
||||
{
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hadd2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b)
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half2 add(half2 a, half2 b)
|
||||
{
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half add(half a, half b)
|
||||
{
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
|
||||
{
|
||||
return bf16hadd(a, b);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)
|
||||
{
|
||||
return bf16hadd(a, __float2bfloat16(b));
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// applies to all 4 values addition
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b, T c)
|
||||
{
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hadd(a, b, c);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hadd2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// applies to all 4 values addition
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b, T c, T d)
|
||||
{
|
||||
return (T) ((float) a + (float) b + (float) c + (float) d);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
|
||||
{
|
||||
return bf16hadd(a, b, c, d);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hsub2(T a, T b)
|
||||
{
|
||||
return __hsub2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hsub2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hmul2(T a, T b)
|
||||
{
|
||||
return __hmul2(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return bf16hmul2(a, b);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hmul2(T a, T b, T c)
|
||||
{
|
||||
return a * b * c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hmul2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T mul(T a, T b, T c)
|
||||
{
|
||||
return a * b * c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hmul(a, b, c);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hmul2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T fma(T a, T b, T c, T d)
|
||||
{
|
||||
return a * b * c + d;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
|
||||
{
|
||||
return bf16hfma2(a, b, c, d);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T fma(T a, T b, T c)
|
||||
{
|
||||
return a * b + c;
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
||||
{
|
||||
return bf16hfma2(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
|
||||
{
|
||||
return bf16hfma(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T hexp2(T a)
|
||||
{
|
||||
return h2exp(a);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)
|
||||
{
|
||||
return bf16exp2(a);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val)
|
||||
{
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
|
||||
{
|
||||
return make_float2(val.x, val.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, float>(float val)
|
||||
{
|
||||
return make_float2(val, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
|
||||
{
|
||||
return __half22float2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
|
||||
{
|
||||
return __float22half2_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float>(float val)
|
||||
{
|
||||
return __float2half2_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, half>(half val)
|
||||
{
|
||||
return __half2half2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
union
|
||||
{
|
||||
half fp16;
|
||||
int16_t int16_in;
|
||||
};
|
||||
|
||||
fp16 = val;
|
||||
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
|
||||
return int8[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = cuda_cast<int8_t>(val.x);
|
||||
int8[1] = cuda_cast<int8_t>(val.y);
|
||||
return int16;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
|
||||
return int8[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int8[0] = cuda_cast<int8_t>(val.x);
|
||||
int8[1] = cuda_cast<int8_t>(val.y);
|
||||
return int16;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
return make_half2(int8[0], int8[1]);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
return make_float2(int8[0], int8[1]);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
|
||||
{
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return bf1622float2(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __float2half(__bfloat162float(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return bf1622int16(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
|
||||
{
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
|
||||
{
|
||||
return __float2bfloat16(__half2float(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return bf162bf162(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
|
||||
{
|
||||
return __float2bfloat162_rn(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
|
||||
{
|
||||
return float22bf162(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
|
||||
{
|
||||
union
|
||||
{
|
||||
int8_t int8[2];
|
||||
int16_t int16;
|
||||
};
|
||||
|
||||
int16 = val;
|
||||
__nv_bfloat162 res;
|
||||
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
|
||||
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
|
||||
{
|
||||
return float22bf162(__half22float2(val));
|
||||
}
|
||||
|
||||
#endif // ENABLE BF16
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T cuda_abs(T val)
|
||||
{
|
||||
assert(false);
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_abs(float val)
|
||||
{
|
||||
return fabs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_abs(float2 val)
|
||||
{
|
||||
return make_float2(fabs(val.x), fabs(val.y));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_abs(half val)
|
||||
{
|
||||
return __habs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_abs(half2 val)
|
||||
{
|
||||
return __habs2(val);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
|
||||
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
|
||||
{
|
||||
return __habs(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
|
||||
{
|
||||
return __habs2(val);
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // ENABLE_FP16
|
||||
|
||||
template <typename To, typename Ti>
|
||||
__device__ inline To cuda_sum(Ti val)
|
||||
{
|
||||
return cuda_cast<To>(val);
|
||||
};
|
||||
|
||||
template <typename To>
|
||||
__device__ inline To cuda_sum(float2 val)
|
||||
{
|
||||
return cuda_cast<To>(val.x + val.y);
|
||||
};
|
||||
|
||||
// Unary maximum: compute the max of a vector type
|
||||
template <typename To, typename Ti>
|
||||
__device__ inline To cuda_max(Ti val)
|
||||
{
|
||||
return cuda_cast<To>(val);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_max(float2 val)
|
||||
{
|
||||
return fmaxf(val.x, val.y);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half cuda_max(half2 val)
|
||||
{
|
||||
return __hmax(val.x, val.y);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
return __hmax(val.x, val.y);
|
||||
#else
|
||||
assert(0);
|
||||
asm volatile("brkpt;\n" ::);
|
||||
return __nv_bfloat16(0);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
// Binary maximum: compute the max of two values.
|
||||
template <typename T>
|
||||
__device__ inline T cuda_max(T val1, T val2)
|
||||
{
|
||||
return (val1 > val2) ? val1 : val2;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_max(float2 val1, float2 val2)
|
||||
{
|
||||
float2 out;
|
||||
out.x = fmaxf(val1.x, val2.x);
|
||||
out.y = fmaxf(val1.y, val2.y);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_max(half2 val1, half2 val2)
|
||||
{
|
||||
return __hmax2(val1, val2);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2)
|
||||
{
|
||||
return __hmax2(val1, val2);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Binary maximum: compute the min of two values.
|
||||
template <typename T>
|
||||
__device__ inline T cuda_min(T val1, T val2)
|
||||
{
|
||||
return (val1 < val2) ? val1 : val2;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_min(float2 val1, float2 val2)
|
||||
{
|
||||
float2 out;
|
||||
out.x = fminf(val1.x, val2.x);
|
||||
out.y = fminf(val1.y, val2.y);
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_min(half2 val1, half2 val2)
|
||||
{
|
||||
return __hmin2(val1, val2);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2)
|
||||
{
|
||||
return __hmin2(val1, val2);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
// Helper function of clamping the val into the given range.
|
||||
template <typename T>
|
||||
inline __device__ T cuda_clamp(T val, T minVal, T maxVal)
|
||||
{
|
||||
return cuda_min(cuda_max(val, minVal), maxVal);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return fp8x2_e4m3_to_half2(&val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val)
|
||||
{
|
||||
return __nv_fp8x2_e4m3(cuda_cast<float2>(val));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)
|
||||
{
|
||||
return __nv_fp8_e4m3(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
|
||||
{
|
||||
return (float) val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)
|
||||
{
|
||||
return fp8x2_e4m3_to_bfloat2(&val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)
|
||||
{
|
||||
// no impl
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)
|
||||
{
|
||||
return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));
|
||||
}
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
36
sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h
vendored
Normal file
36
sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
{
|
||||
|
||||
constexpr size_t NUM_POINTERS_PER_RANK = 7;
|
||||
|
||||
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
|
||||
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
{
|
||||
if (worldSize <= 2)
|
||||
{
|
||||
return 16 * 1000 * 1000;
|
||||
}
|
||||
return 8 * 1000 * 1000;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
214
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp
vendored
Normal file
214
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "envUtils.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include <cstdlib>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::optional<int32_t> getIntEnv(char const* name)
|
||||
{
|
||||
char const* const env = std::getenv(name);
|
||||
if (env == nullptr)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
int32_t const val = std::stoi(env);
|
||||
if (val <= 0)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
return {val};
|
||||
};
|
||||
|
||||
// Returns true if the env variable exists and is set to "1"
|
||||
static bool getBoolEnv(char const* name)
|
||||
{
|
||||
char const* env = std::getenv(name);
|
||||
return env && env[0] == '1' && env[1] == '\0';
|
||||
}
|
||||
|
||||
// XQA kernels (optimized kernels for generation phase).
|
||||
bool forceXQAKernels()
|
||||
{
|
||||
static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
|
||||
return forceXQA;
|
||||
}
|
||||
|
||||
std::optional<bool> getEnvEnableXQAJIT()
|
||||
{
|
||||
static bool init = false;
|
||||
static bool exists = false;
|
||||
static bool enableXQAJIT = false;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
|
||||
if (enable_xqa_jit_var)
|
||||
{
|
||||
exists = true;
|
||||
if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
|
||||
{
|
||||
enableXQAJIT = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (exists)
|
||||
{
|
||||
return enableXQAJIT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug()
|
||||
{
|
||||
static bool init = false;
|
||||
static bool forceMmhaMaxSeqLenTile = false;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
|
||||
if (enable_mmha_debug_var)
|
||||
{
|
||||
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
|
||||
{
|
||||
forceMmhaMaxSeqLenTile = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return forceMmhaMaxSeqLenTile;
|
||||
}
|
||||
|
||||
int getEnvMmhaBlocksPerSequence()
|
||||
{
|
||||
static bool init = false;
|
||||
static int mmhaBlocksPerSequence = 0;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
|
||||
if (mmhaBlocksPerSequenceEnv)
|
||||
{
|
||||
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
|
||||
if (mmhaBlocksPerSequence <= 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return mmhaBlocksPerSequence;
|
||||
}
|
||||
|
||||
int getEnvMmhaKernelBlockSize()
|
||||
{
|
||||
static bool init = false;
|
||||
static int mmhaKernelBlockSize = 0;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
|
||||
if (mmhaKernelBlockSizeEnv)
|
||||
{
|
||||
mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
|
||||
if (mmhaKernelBlockSize <= 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return mmhaKernelBlockSize;
|
||||
}
|
||||
|
||||
bool getEnvEnablePDL()
|
||||
{
|
||||
static bool init = false;
|
||||
static bool enablePDL = false;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
// PDL only available when arch >= 90
|
||||
if (getSMVersion() >= 90)
|
||||
{
|
||||
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
|
||||
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
|
||||
}
|
||||
}
|
||||
return enablePDL;
|
||||
}
|
||||
|
||||
bool getEnvUseUCXKvCache()
|
||||
{
|
||||
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
|
||||
return useUCXKVCache;
|
||||
}
|
||||
|
||||
std::string getEnvUCXInterface()
|
||||
{
|
||||
static bool init = false;
|
||||
static std::string ucxInterface;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
{
|
||||
char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE");
|
||||
if (ucx_interface)
|
||||
{
|
||||
ucxInterface = ucx_interface;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ucxInterface;
|
||||
}
|
||||
|
||||
bool getEnvDisaggLayerwise()
|
||||
{
|
||||
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
|
||||
return disaggLayerwise;
|
||||
}
|
||||
|
||||
bool getEnvParallelCacheSend()
|
||||
{
|
||||
static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
|
||||
return parallelCacheSend;
|
||||
}
|
||||
|
||||
bool getEnvRequestKVCacheSerial()
|
||||
{
|
||||
static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL");
|
||||
return requestKVCacheSerial;
|
||||
}
|
||||
|
||||
bool getEnvDisableKVCacheTransferOverlap()
|
||||
{
|
||||
static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP");
|
||||
return disableKVCacheTransferOverlap;
|
||||
}
|
||||
|
||||
bool getEnvDisableReceiveKVCacheParallel()
|
||||
{
|
||||
static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL");
|
||||
return disableReceiveParallel;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
60
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h
vendored
Normal file
60
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
// Useful when you want to inject some debug code controllable with env var.
|
||||
std::optional<int32_t> getIntEnv(char const* name);
|
||||
|
||||
// XQA kernels (optimized kernels for generation phase).
|
||||
bool forceXQAKernels();
|
||||
|
||||
// Whether XQA JIT is enabled.
|
||||
//
|
||||
// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned.
|
||||
std::optional<bool> getEnvEnableXQAJIT();
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug();
|
||||
|
||||
int getEnvMmhaBlocksPerSequence();
|
||||
|
||||
int getEnvMmhaKernelBlockSize();
|
||||
|
||||
// Whether PDL is enabled.
|
||||
bool getEnvEnablePDL();
|
||||
|
||||
bool getEnvUseUCXKvCache();
|
||||
|
||||
std::string getEnvUCXInterface();
|
||||
|
||||
bool getEnvDisaggLayerwise();
|
||||
|
||||
bool getEnvParallelCacheSend();
|
||||
|
||||
bool getEnvRequestKVCacheSerial();
|
||||
|
||||
bool getEnvDisableKVCacheTransferOverlap();
|
||||
|
||||
bool getEnvDisableReceiveKVCacheParallel();
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
70
sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp
vendored
Normal file
70
sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
Logger::Logger()
|
||||
{
|
||||
char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY");
|
||||
bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON");
|
||||
|
||||
auto const* levelName = std::getenv("TLLM_LOG_LEVEL");
|
||||
if (levelName != nullptr)
|
||||
{
|
||||
auto level = [levelName = std::string(levelName)]()
|
||||
{
|
||||
if (levelName == "TRACE")
|
||||
return TRACE;
|
||||
if (levelName == "DEBUG")
|
||||
return DEBUG;
|
||||
if (levelName == "INFO")
|
||||
return INFO;
|
||||
if (levelName == "WARNING")
|
||||
return WARNING;
|
||||
if (levelName == "ERROR")
|
||||
return ERROR;
|
||||
TLLM_THROW("Invalid log level: %s", levelName.c_str());
|
||||
}();
|
||||
// If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
|
||||
if (isFirstRankOnly)
|
||||
{
|
||||
auto const deviceId = getDevice();
|
||||
if (deviceId != 1)
|
||||
{
|
||||
level = ERROR;
|
||||
}
|
||||
}
|
||||
setLevel(level);
|
||||
}
|
||||
}
|
||||
|
||||
void Logger::log(std::exception const& ex, Logger::Level level)
|
||||
{
|
||||
log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what());
|
||||
}
|
||||
|
||||
Logger* Logger::getLogger()
|
||||
{
|
||||
thread_local Logger instance;
|
||||
return &instance;
|
||||
}
|
||||
} // namespace tensorrt_llm::common
|
||||
37
sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h
vendored
Normal file
37
sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
inline __device__ __host__ T divUp(T m, T n)
|
||||
{
|
||||
return (m + n - 1) / n;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
906
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu
vendored
Normal file
906
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu
vendored
Normal file
@@ -0,0 +1,906 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize)
|
||||
{
|
||||
check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size));
|
||||
if (is_random_initialize)
|
||||
{
|
||||
cudaRandomUniform(*ptr, size);
|
||||
}
|
||||
}
|
||||
|
||||
template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize);
|
||||
template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize);
|
||||
#ifdef ENABLE_BF16
|
||||
template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize);
|
||||
#endif
|
||||
template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize);
|
||||
template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize);
|
||||
template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize);
|
||||
template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize);
|
||||
template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize);
|
||||
#ifdef ENABLE_FP8
|
||||
template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void deviceMemSetZero(T* ptr, size_t size)
|
||||
{
|
||||
check_cuda_error(cudaMemset(static_cast<void*>(ptr), 0, sizeof(T) * size));
|
||||
}
|
||||
|
||||
template void deviceMemSetZero(float* ptr, size_t size);
|
||||
template void deviceMemSetZero(half* ptr, size_t size);
|
||||
template void deviceMemSetZero(int* ptr, size_t size);
|
||||
template void deviceMemSetZero(uint32_t* ptr, size_t size);
|
||||
template void deviceMemSetZero(bool* ptr, size_t size);
|
||||
#ifdef ENABLE_FP8
|
||||
template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size);
|
||||
#endif
|
||||
#ifdef ENABLE_BF16
|
||||
template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void deviceFree(T*& ptr)
|
||||
{
|
||||
if (ptr != NULL)
|
||||
{
|
||||
check_cuda_error(cudaFree(ptr));
|
||||
ptr = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
template void deviceFree(float*& ptr);
|
||||
template void deviceFree(half*& ptr);
|
||||
#ifdef ENABLE_BF16
|
||||
template void deviceFree(__nv_bfloat16*& ptr);
|
||||
#endif
|
||||
template void deviceFree(unsigned short*& ptr);
|
||||
template void deviceFree(int*& ptr);
|
||||
template void deviceFree(bool*& ptr);
|
||||
template void deviceFree(char*& ptr);
|
||||
template void deviceFree(int8_t*& ptr);
|
||||
#ifdef ENABLE_FP8
|
||||
template void deviceFree(__nv_fp8_e4m3*& ptr);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream)
|
||||
{
|
||||
T* arr = new T[size];
|
||||
std::fill(arr, arr + size, value);
|
||||
check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream));
|
||||
delete[] arr;
|
||||
}
|
||||
|
||||
template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream);
|
||||
template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream);
|
||||
#endif
|
||||
template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream);
|
||||
template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Hcpy(T* tgt, T const* src, const size_t size)
|
||||
{
|
||||
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
template void cudaD2Hcpy(float* tgt, float const* src, size_t size);
|
||||
template void cudaD2Hcpy(half* tgt, half const* src, size_t size);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaD2Hcpy(int* tgt, int const* src, size_t size);
|
||||
template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
|
||||
template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size);
|
||||
template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size);
|
||||
|
||||
template <typename T>
|
||||
void cudaH2Dcpy(T* tgt, T const* src, const size_t size)
|
||||
{
|
||||
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template void cudaH2Dcpy(float* tgt, float const* src, size_t size);
|
||||
template void cudaH2Dcpy(half* tgt, half const* src, size_t size);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaH2Dcpy(int* tgt, int const* src, size_t size);
|
||||
template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
|
||||
template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size);
|
||||
template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (T_OUT) ((float) (src[tid]));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaCast<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(half* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(
|
||||
__nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(
|
||||
__nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
if (stream != NULL)
|
||||
{
|
||||
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream));
|
||||
}
|
||||
else
|
||||
{
|
||||
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault));
|
||||
}
|
||||
}
|
||||
|
||||
template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(
|
||||
unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state);
|
||||
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
|
||||
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
buffer[index] = curand(&local_state);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
|
||||
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
buffer[index] = (curand(&local_state) % 2 == 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
curand_init((float) 1337.f, idx + seq_offset, 0, &local_state);
|
||||
for (size_t index = idx; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
buffer[index] = curand(&local_state) % 0xFF;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void cudaRandomUniform(T* buffer, const size_t size)
|
||||
{
|
||||
static int seq_offset = 0;
|
||||
cuda_random_uniform_kernel<T><<<256, 256>>>(buffer, size, seq_offset);
|
||||
seq_offset += 256 * 256;
|
||||
}
|
||||
|
||||
template void cudaRandomUniform(float* buffer, const size_t size);
|
||||
template void cudaRandomUniform(half* buffer, const size_t size);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size);
|
||||
#endif
|
||||
template void cudaRandomUniform(int* buffer, const size_t size);
|
||||
template void cudaRandomUniform(bool* buffer, const size_t size);
|
||||
template void cudaRandomUniform(char* buffer, const size_t size);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size);
|
||||
#endif
|
||||
|
||||
// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or
|
||||
// the product of the elements in shape is 0, this function will return an empty vector.
|
||||
template <typename T>
|
||||
std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename)
|
||||
{
|
||||
if (shape.size() > 2)
|
||||
{
|
||||
printf("[ERROR] shape should have less than two dims \n");
|
||||
return std::vector<T>();
|
||||
}
|
||||
size_t dim0 = shape[0], dim1 = 1;
|
||||
if (shape.size() == 2)
|
||||
{
|
||||
dim1 = shape[1];
|
||||
}
|
||||
size_t size = dim0 * dim1;
|
||||
if (size == 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
|
||||
return std::vector<T>();
|
||||
}
|
||||
|
||||
std::vector<T> host_array(size);
|
||||
std::ifstream in(filename, std::ios::in | std::ios::binary);
|
||||
if (!in.is_open())
|
||||
{
|
||||
TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
|
||||
return std::vector<T>();
|
||||
}
|
||||
|
||||
size_t loaded_data_size = sizeof(T) * size;
|
||||
in.seekg(0, in.end);
|
||||
in.seekg(0, in.beg);
|
||||
|
||||
TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
|
||||
in.read((char*) host_array.data(), loaded_data_size);
|
||||
|
||||
size_t in_get_size = in.gcount();
|
||||
if (in_get_size != loaded_data_size)
|
||||
{
|
||||
TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(),
|
||||
in_get_size, loaded_data_size);
|
||||
return std::vector<T>();
|
||||
}
|
||||
in.close();
|
||||
// If we succeed, return an array with values.
|
||||
return host_array;
|
||||
}
|
||||
|
||||
template <typename T, typename T_IN>
|
||||
int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
|
||||
{
|
||||
std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename);
|
||||
|
||||
if (host_array.empty())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (std::is_same<T, T_IN>::value == true)
|
||||
{
|
||||
cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size());
|
||||
}
|
||||
else
|
||||
{
|
||||
T_IN* ptr_2 = nullptr;
|
||||
deviceMalloc(&ptr_2, host_array.size(), false);
|
||||
cudaH2Dcpy(ptr_2, host_array.data(), host_array.size());
|
||||
invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size());
|
||||
deviceFree(ptr_2);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template int loadWeightFromBinFunc<float, float>(float* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<half, float>(half* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<float, half>(float* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<half, half>(half* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr, std::vector<size_t> shape, std::string filename);
|
||||
#ifdef ENABLE_BF16
|
||||
template int loadWeightFromBinFunc<__nv_bfloat16, float>(
|
||||
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<__nv_bfloat16, half>(
|
||||
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr, std::vector<size_t> shape, std::string filename);
|
||||
template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(
|
||||
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
|
||||
#endif // ENABLE_BF16
|
||||
template int loadWeightFromBinFunc<int, int>(int* ptr, std::vector<size_t> shape, std::string filename);
|
||||
#ifdef ENABLE_FP8
|
||||
template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(
|
||||
__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename);
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename T>
|
||||
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type)
|
||||
{
|
||||
switch (model_file_type)
|
||||
{
|
||||
case TRTLLMCudaDataType::FP32: loadWeightFromBinFunc<T, float>(ptr, shape, filename); break;
|
||||
case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc<T, half>(ptr, shape, filename); break;
|
||||
case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename); break;
|
||||
#ifdef ENABLE_BF16
|
||||
case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename); break;
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc<T, float>(ptr, shape, filename); break;
|
||||
#endif
|
||||
default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
int loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type)
|
||||
{
|
||||
loadWeightFromBinFunc<int, int>(ptr, shape, filename);
|
||||
return 0;
|
||||
}
|
||||
|
||||
template int loadWeightFromBin(
|
||||
float* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
template int loadWeightFromBin(
|
||||
half* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
template int loadWeightFromBin(
|
||||
int8_t* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
#ifdef ENABLE_BF16
|
||||
template int loadWeightFromBin(
|
||||
__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template int loadWeightFromBin(
|
||||
__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
#endif
|
||||
template int loadWeightFromBin(
|
||||
int* ptr, std::vector<size_t> shape, std::string filename, TRTLLMCudaDataType model_file_type);
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = cuda_cast<T_OUT>(src[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size);
|
||||
}
|
||||
|
||||
template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void cudaD2DScaleCpyConvert(
|
||||
T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size)
|
||||
{
|
||||
float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DScaleCpyConvert(
|
||||
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream);
|
||||
#endif // ENABLE_FP8
|
||||
// clang-format on
|
||||
|
||||
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
invokeCudaD2DcpyConvert(dst, src, size, stream);
|
||||
}
|
||||
|
||||
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
invokeCudaD2DcpyConvert(dst, src, size, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void saveToBinary(T const* ptr, const size_t size, std::string filename)
|
||||
{
|
||||
|
||||
std::vector<T> h_ptr(size);
|
||||
cudaD2Hcpy(h_ptr.data(), ptr, size);
|
||||
std::vector<float> float_ptr(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
{
|
||||
float_ptr[i] = (float) h_ptr[i];
|
||||
}
|
||||
|
||||
std::ofstream out(filename, std::ios::out | std::ios::binary);
|
||||
TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename);
|
||||
|
||||
out.write((char*) float_ptr.data(), size * sizeof(float));
|
||||
}
|
||||
|
||||
template void saveToBinary(float const* ptr, const size_t size, std::string filename);
|
||||
template void saveToBinary(half const* ptr, const size_t size, std::string filename);
|
||||
#ifdef ENABLE_BF16
|
||||
template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <>
|
||||
void saveToBinary(int const* ptr, const size_t size, std::string filename)
|
||||
{
|
||||
std::vector<int> h_ptr(size);
|
||||
cudaD2Hcpy(h_ptr.data(), ptr, size);
|
||||
std::ofstream out(filename, std::ios::out | std::ios::binary);
|
||||
TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename);
|
||||
out.write((char*) h_ptr.data(), size * sizeof(int));
|
||||
}
|
||||
|
||||
template <typename T_IN, typename T_fake_type>
|
||||
__global__ void fakeCast(T_IN* input_ptr, const size_t size)
|
||||
{
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]);
|
||||
input_ptr[i] = (T_IN) ((float) tmp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_IN, typename T_fake_type>
|
||||
void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
dim3 block(256);
|
||||
dim3 grid((size + 255) / 256);
|
||||
fakeCast<T_IN, T_fake_type><<<grid, block, 0, stream>>>(input_ptr, size);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (float) (src[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (half) ((float) (src[tid]));
|
||||
}
|
||||
}
|
||||
|
||||
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (__nv_fp8_e4m3) src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (__nv_fp8_e4m3) src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = (__nv_fp8_e4m3) src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
|
||||
}
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1)
|
||||
{
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
const size_t src_col_id = tid % dim1;
|
||||
const size_t src_row_id = tid / dim1;
|
||||
dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1)
|
||||
{
|
||||
// copy data to workspace, and then transpose from workspace to data
|
||||
cudaD2Dcpy(workspace, data, dim0 * dim1);
|
||||
transpose<<<256, 256>>>(data, workspace, dim0, dim1);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeInPlaceTranspose(
|
||||
__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1);
|
||||
#endif // ENABLE_FP8
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeInPlaceTranspose(
|
||||
__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1);
|
||||
#endif // ENABLE_BF16
|
||||
template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__global__ void transpose0213(
|
||||
T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3)
|
||||
{
|
||||
// src permutation: [0, 1, 2, 3]
|
||||
// dst permutation: [0, 2, 1, 3]
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3;
|
||||
tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
size_t tmp_idx = tid;
|
||||
const size_t dim_3_idx = tmp_idx % dim3;
|
||||
tmp_idx = (tmp_idx - dim_3_idx) / dim3;
|
||||
const size_t dim_2_idx = tmp_idx % dim2;
|
||||
tmp_idx = (tmp_idx - dim_2_idx) / dim2;
|
||||
const size_t dim_1_idx = tmp_idx % dim1;
|
||||
tmp_idx = (tmp_idx - dim_1_idx) / dim1;
|
||||
const size_t dim_0_idx = tmp_idx % dim0;
|
||||
dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose0213(
|
||||
T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3)
|
||||
{
|
||||
// copy data to workspace, and then transpose from workspace to data
|
||||
// Note that this kernel is used for pre-processing and not very efficient.
|
||||
cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3);
|
||||
transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0,
|
||||
const size_t dim1, const size_t dim2, const size_t dim3);
|
||||
#endif // ENABLE_FP8
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0,
|
||||
const size_t dim1, const size_t dim2, const size_t dim3);
|
||||
#endif // ENABLE_BF16
|
||||
template void invokeInPlaceTranspose0213(
|
||||
float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2)
|
||||
{
|
||||
// src permutation: [0, 1, 2]
|
||||
// dst permutation: [1, 0, 2]
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
size_t tmp_idx = tid;
|
||||
const size_t dim_2_idx = tmp_idx % dim2;
|
||||
tmp_idx = (tmp_idx - dim_2_idx) / dim2;
|
||||
const size_t dim_1_idx = tmp_idx % dim1;
|
||||
tmp_idx = (tmp_idx - dim_1_idx) / dim1;
|
||||
const size_t dim_0_idx = tmp_idx % dim0;
|
||||
dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2)
|
||||
{
|
||||
// copy data to workspace, and then transpose from workspace to data
|
||||
// Note that this kernel is used for pre-processing and not very efficient.
|
||||
cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2);
|
||||
transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeInPlaceTranspose102(
|
||||
__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
|
||||
#endif // ENABLE_FP8
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeInPlaceTranspose102(
|
||||
__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
|
||||
#endif // ENABLE_BF16
|
||||
template void invokeInPlaceTranspose102(
|
||||
float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2);
|
||||
|
||||
template <typename T>
|
||||
void __global__ multiplyScale(T* tensor, float scale, const size_t size)
|
||||
{
|
||||
for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
tensor[index] = (T) (((float) tensor[index]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
int block = 256;
|
||||
int grid = (size + 255) / 256;
|
||||
multiplyScale<<<grid, block, 0, stream>>>(tensor, scale, size);
|
||||
}
|
||||
|
||||
template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void __global__ divideScale(T* tensor, float scale, const size_t size)
|
||||
{
|
||||
for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x)
|
||||
{
|
||||
tensor[index] = (T) (((float) tensor[index]) / scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
int block = 256;
|
||||
int grid = (size + 255) / 256;
|
||||
divideScale<<<grid, block, 0, stream>>>(tensor, scale, size);
|
||||
}
|
||||
|
||||
template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeFakeCast<float, __nv_bfloat16>(float* input_ptr, const size_t size, cudaStream_t stream);
|
||||
template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>(
|
||||
__nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream);
|
||||
template void invokeFakeCast<half, __nv_bfloat16>(half* input_ptr, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void invokeFakeCast<float, half>(float* input_ptr, const size_t size, cudaStream_t stream);
|
||||
template void invokeFakeCast<float, float>(float* input_ptr, const size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_FP8
|
||||
template void invokeFakeCast<float, __nv_fp8_e4m3>(float* input_ptr, const size_t size, cudaStream_t stream);
|
||||
template void invokeFakeCast<half, __nv_fp8_e4m3>(half* input_ptr, const size_t size, cudaStream_t stream);
|
||||
template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>(
|
||||
__nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
size_t cuda_datatype_size(TRTLLMCudaDataType dt)
|
||||
{
|
||||
static const std::unordered_map<TRTLLMCudaDataType, size_t> sizes{
|
||||
{TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, sizeof(half)}
|
||||
#ifdef ENABLE_BF16
|
||||
,
|
||||
{TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)}
|
||||
#endif
|
||||
};
|
||||
|
||||
return sizes.at(dt);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range)
|
||||
{
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
const T val = buffer[i];
|
||||
if (val < min || val > max)
|
||||
{
|
||||
*d_within_range = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
|
||||
{
|
||||
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
|
||||
|
||||
dim3 block(256);
|
||||
dim3 grid((size + 255) / 256);
|
||||
check_range<T><<<grid, block, 0, stream>>>(buffer, size, min, max, d_within_range);
|
||||
|
||||
bool result;
|
||||
cudaD2Hcpy(&result, d_within_range, 1);
|
||||
return result;
|
||||
}
|
||||
|
||||
template bool invokeCheckRange<int>(
|
||||
int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
|
||||
|
||||
/*
|
||||
* Determine the total workspace size based on a vector containing multiple variable sizes.
|
||||
*/
|
||||
size_t calcAlignedSize(std::vector<size_t> const& sizes, const size_t ALIGN_BYTES)
|
||||
{
|
||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||
// Check ALIGN_BYTES is a power of 2
|
||||
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
|
||||
|
||||
size_t total = 0;
|
||||
for (auto sz : sizes)
|
||||
{
|
||||
total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
|
||||
}
|
||||
|
||||
// We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is
|
||||
// not aligned.
|
||||
return total + ALIGN_BYTES - 1;
|
||||
}
|
||||
|
||||
/*
|
||||
* Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses
|
||||
* of each variable.
|
||||
*/
|
||||
void calcAlignedPointers(
|
||||
std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES)
|
||||
{
|
||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||
// Check ALIGN_BYTES is a power of 2
|
||||
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
|
||||
|
||||
// In case the start address is not aligned
|
||||
char* ptr = reinterpret_cast<char*>((reinterpret_cast<size_t>(p) + ALIGN_BYTES - 1) & ALIGN_MASK);
|
||||
|
||||
outPtrs.reserve(sizes.size());
|
||||
for (auto sz : sizes)
|
||||
{
|
||||
outPtrs.push_back(ptr);
|
||||
ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
292
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h
vendored
Normal file
292
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h
vendored
Normal file
@@ -0,0 +1,292 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true);
|
||||
|
||||
template <typename T>
|
||||
void deviceMemSetZero(T* ptr, size_t size);
|
||||
|
||||
template <typename T>
|
||||
|
||||
void deviceFree(T*& ptr);
|
||||
|
||||
template <typename T>
|
||||
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Hcpy(T* tgt, T const* src, size_t const size);
|
||||
|
||||
template <typename T>
|
||||
void cudaH2Dcpy(T* tgt, T const* src, size_t const size);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
|
||||
|
||||
template <typename T>
|
||||
void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);
|
||||
|
||||
template <typename T>
|
||||
void cudaRandomUniform(T* buffer, size_t const size);
|
||||
|
||||
template <typename T>
|
||||
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename,
|
||||
TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
|
||||
|
||||
// template<typename T>
|
||||
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
|
||||
// T* scale_ptr,
|
||||
// std::vector<size_t> shape,
|
||||
// std::string filename,
|
||||
// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);
|
||||
|
||||
void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream);
|
||||
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream);
|
||||
#ifdef ENABLE_FP8
|
||||
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
|
||||
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
|
||||
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream);
|
||||
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream);
|
||||
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
|
||||
#endif // ENABLE_FP8
|
||||
#ifdef ENABLE_BF16
|
||||
void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The following functions implement conversion of multi-dimensional indices to an index in a flat array.
|
||||
// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments.
|
||||
// For examples on how to use these functions, see their tests `test_memory_utils.cu`.
|
||||
// All of these functions can be evaluated at compile time by recursive template expansion.
|
||||
|
||||
template <typename TDim, typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
|
||||
T const& acc, TDim dims, TIndex const& index)
|
||||
{
|
||||
assert(index < dims[0]);
|
||||
return acc * dims[0] + index;
|
||||
}
|
||||
|
||||
template <typename TDim, typename T, typename TIndex, typename... TIndices>
|
||||
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
|
||||
T const& acc, TDim dims, TIndex const& index, TIndices... indices)
|
||||
{
|
||||
assert(index < dims[0]);
|
||||
return flat_index(acc * dims[0] + index, dims + 1, indices...);
|
||||
}
|
||||
|
||||
template <typename TDim, typename T>
|
||||
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
|
||||
[[maybe_unused]] TDim dims, T const& index)
|
||||
{
|
||||
assert(index < dims[0]);
|
||||
return index;
|
||||
}
|
||||
|
||||
template <typename TDim, typename TIndex, typename... TIndices>
|
||||
__inline__ __host__ __device__
|
||||
std::enable_if_t<std::is_pointer<TDim>::value, typename std::remove_pointer<TDim>::type> constexpr flat_index(
|
||||
TDim dims, TIndex const& index, TIndices... indices)
|
||||
{
|
||||
assert(index < dims[0]);
|
||||
return flat_index(static_cast<typename std::remove_pointer<TDim>::type>(index), dims + 1, indices...);
|
||||
}
|
||||
|
||||
template <unsigned skip = 0, typename T, std::size_t N, typename TIndex, typename... TIndices>
|
||||
__inline__ __host__ __device__ T constexpr flat_index(
|
||||
std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
|
||||
{
|
||||
static_assert(skip < N);
|
||||
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
|
||||
return flat_index(&dims[skip], index, indices...);
|
||||
}
|
||||
|
||||
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
|
||||
__inline__ __host__ __device__ T constexpr flat_index(
|
||||
T const& acc, std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
|
||||
{
|
||||
static_assert(skip < N);
|
||||
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
|
||||
return flat_index(acc, &dims[skip], index, indices...);
|
||||
}
|
||||
|
||||
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
|
||||
__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices)
|
||||
{
|
||||
static_assert(skip < N);
|
||||
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
|
||||
return flat_index(static_cast<T const*>(dims) + skip, index, indices...);
|
||||
}
|
||||
|
||||
template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
|
||||
__inline__ __host__ __device__ T constexpr flat_index(
|
||||
T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices)
|
||||
{
|
||||
static_assert(skip < N);
|
||||
static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
|
||||
return flat_index(acc, static_cast<T const*>(dims) + skip, index, indices...);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual
|
||||
// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions
|
||||
// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be
|
||||
// evaluated at compile time.
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1)
|
||||
{
|
||||
assert(index_1 < dim_1);
|
||||
return index_0 * dim_1 + index_1;
|
||||
}
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index3(
|
||||
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2)
|
||||
{
|
||||
assert(index_2 < dim_2);
|
||||
return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2;
|
||||
}
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1,
|
||||
TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3)
|
||||
{
|
||||
assert(index_3 < dim_3);
|
||||
return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3;
|
||||
}
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1,
|
||||
TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3,
|
||||
T const& dim_4)
|
||||
{
|
||||
assert(index_4 < dim_4);
|
||||
return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4;
|
||||
}
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index_strided3(
|
||||
TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2)
|
||||
{
|
||||
assert(index_1 < stride_1 / stride_2);
|
||||
assert(index_2 < stride_2);
|
||||
return index_0 * stride_1 + index_1 * stride_2 + index_2;
|
||||
}
|
||||
|
||||
template <typename T, typename TIndex>
|
||||
__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1,
|
||||
TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3)
|
||||
{
|
||||
assert(index_1 < stride_1 / stride_2);
|
||||
assert(index_2 < stride_2 / stride_3);
|
||||
assert(index_3 < stride_3);
|
||||
return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1);
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose0213(
|
||||
T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3);
|
||||
|
||||
template <typename T>
|
||||
void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2);
|
||||
|
||||
template <typename T>
|
||||
void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream);
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0);
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DScaleCpyConvert(
|
||||
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0);
|
||||
|
||||
inline bool checkIfFileExist(std::string const& file_path)
|
||||
{
|
||||
std::ifstream in(file_path, std::ios::in | std::ios::binary);
|
||||
if (in.is_open())
|
||||
{
|
||||
in.close();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void saveToBinary(T const* ptr, size_t const size, std::string filename);
|
||||
|
||||
template <typename T_IN, typename T_fake_type>
|
||||
void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream);
|
||||
|
||||
size_t cuda_datatype_size(TRTLLMCudaDataType dt);
|
||||
|
||||
template <typename T>
|
||||
bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream);
|
||||
|
||||
constexpr size_t DEFAULT_ALIGN_BYTES = 256;
|
||||
|
||||
size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
|
||||
void calcAlignedPointers(std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes,
|
||||
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
|
||||
|
||||
struct AlignedPointersUnpacker
|
||||
{
|
||||
template <typename... T>
|
||||
void operator()(T*&... outPtrs)
|
||||
{
|
||||
assert(sizeof...(T) == alignedPointers.size());
|
||||
auto it = alignedPointers.begin();
|
||||
((outPtrs = static_cast<T*>(*it++)), ...);
|
||||
}
|
||||
|
||||
std::vector<void*> alignedPointers;
|
||||
};
|
||||
|
||||
AlignedPointersUnpacker inline calcAlignedPointers(
|
||||
void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES)
|
||||
{
|
||||
AlignedPointersUnpacker unpacker{};
|
||||
calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES);
|
||||
return unpacker;
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
588
sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp
vendored
Normal file
588
sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp
vendored
Normal file
@@ -0,0 +1,588 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <numeric>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <csignal>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
#ifndef _WIN32
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
// We rely on SizeType32 being int32_t in some places with weak type checking,
|
||||
// i.e. we're passing void ptr to some function. To prevent mysterious errors
|
||||
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
|
||||
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);
|
||||
|
||||
namespace tensorrt_llm::mpi
|
||||
{
|
||||
|
||||
MPI_Datatype getMpiDtype(MpiType dtype)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
|
||||
{MpiType::kBYTE, MPI_BYTE},
|
||||
{MpiType::kHALF, MPI_UINT16_T},
|
||||
{MpiType::kFLOAT, MPI_FLOAT},
|
||||
{MpiType::kDOUBLE, MPI_DOUBLE},
|
||||
{MpiType::kBOOL, MPI_C_BOOL},
|
||||
{MpiType::kINT8, MPI_INT8_T},
|
||||
{MpiType::kUINT8, MPI_UINT8_T},
|
||||
{MpiType::kINT32, MPI_INT32_T},
|
||||
{MpiType::kUINT32, MPI_UINT32_T},
|
||||
{MpiType::kINT64, MPI_INT64_T},
|
||||
{MpiType::kUINT64, MPI_UINT64_T},
|
||||
{MpiType::kFP8, MPI_UINT8_T},
|
||||
{MpiType::kBF16, MPI_UINT16_T},
|
||||
{MpiType::kCHAR, MPI_CHAR},
|
||||
};
|
||||
return dtype_map.at(dtype);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif
|
||||
}
|
||||
|
||||
MPI_Op getMpiOp(MpiOp op)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
static std::unordered_map<MpiOp, MPI_Op> const op_map{
|
||||
{MpiOp::NULLOP, MPI_OP_NULL},
|
||||
{MpiOp::MAX, MPI_MAX},
|
||||
{MpiOp::MIN, MPI_MIN},
|
||||
{MpiOp::SUM, MPI_SUM},
|
||||
{MpiOp::PROD, MPI_PROD},
|
||||
{MpiOp::LAND, MPI_LAND},
|
||||
{MpiOp::BAND, MPI_BAND},
|
||||
{MpiOp::LOR, MPI_LOR},
|
||||
{MpiOp::BOR, MPI_BOR},
|
||||
{MpiOp::LXOR, MPI_LXOR},
|
||||
{MpiOp::BXOR, MPI_BXOR},
|
||||
{MpiOp::MINLOC, MPI_MINLOC},
|
||||
{MpiOp::MAXLOC, MPI_MAXLOC},
|
||||
{MpiOp::REPLACE, MPI_REPLACE},
|
||||
};
|
||||
return op_map.at(op);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
bool mpiInitialized = false;
|
||||
std::recursive_mutex mpiMutex;
|
||||
|
||||
MpiComm initLocalSession()
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPI_Comm localComm = nullptr;
|
||||
MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
|
||||
MpiComm localSession{localComm, false};
|
||||
#else
|
||||
MpiComm localSession{COMM_SESSION, false};
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
return localSession;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<int> getWorldRanks(MpiComm const& comm)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPI_Group group = nullptr;
|
||||
MPI_Group worldGroup = nullptr;
|
||||
|
||||
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
|
||||
MPICHECK(MPI_Comm_group(comm, &group));
|
||||
|
||||
int groupSize = 0;
|
||||
MPICHECK(MPI_Group_size(group, &groupSize));
|
||||
std::vector<int> ranks(groupSize);
|
||||
std::vector<int> worldRanks(groupSize);
|
||||
std::iota(ranks.begin(), ranks.end(), 0);
|
||||
|
||||
MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
|
||||
MPICHECK(MPI_Group_free(&group));
|
||||
MPICHECK(MPI_Group_free(&worldGroup));
|
||||
#else
|
||||
std::vector<int> worldRanks{0};
|
||||
#endif
|
||||
return worldRanks;
|
||||
}
|
||||
|
||||
void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
|
||||
{
|
||||
// double-checked locking
|
||||
if (mpiInitialized)
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::recursive_mutex> lk(mpiMutex);
|
||||
if (mpiInitialized)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int initialized = 0;
|
||||
TLLM_MPI_CHECK(MPI_Initialized(&initialized));
|
||||
if (!initialized)
|
||||
{
|
||||
TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
|
||||
int providedMode = 0;
|
||||
auto requiredMode = static_cast<int>(threadMode);
|
||||
MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
|
||||
TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
|
||||
std::atexit([]() { MPI_Finalize(); });
|
||||
|
||||
/*
|
||||
* We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2
|
||||
* signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
|
||||
* correctly.
|
||||
*/
|
||||
for (int sig : {SIGABRT, SIGSEGV})
|
||||
{
|
||||
__sighandler_t previousHandler = nullptr;
|
||||
if (forwardAbortToParent)
|
||||
{
|
||||
previousHandler = std::signal(sig,
|
||||
[](int signal)
|
||||
{
|
||||
#ifndef _WIN32
|
||||
pid_t parentProcessId = getppid();
|
||||
kill(parentProcessId, SIGKILL);
|
||||
#endif
|
||||
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
|
||||
}
|
||||
|
||||
// ensure local MPI communicator is initialized
|
||||
MpiComm::localSession();
|
||||
TLLM_LOG_INFO("Initialized MPI");
|
||||
}
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
mpiInitialized = true;
|
||||
}
|
||||
|
||||
void MpiComm::barrier() const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Barrier(mComm));
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
template <typename TMpiFunc, typename TBase, typename... TArgs,
|
||||
typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
|
||||
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
|
||||
{
|
||||
constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
|
||||
if (TLLM_LIKELY(size < maxP1))
|
||||
{
|
||||
MPICHECK(func(buffer, size, dtype, args...));
|
||||
return 1;
|
||||
}
|
||||
|
||||
constexpr size_t alignment = 256;
|
||||
int elementSize = 1;
|
||||
MPICHECK(MPI_Type_size(dtype, &elementSize));
|
||||
elementSize = std::min<int>(elementSize, alignment);
|
||||
|
||||
// We cap at max alignment-bytes chunks that can be sent at once.
|
||||
auto const step = maxP1 - (alignment / elementSize);
|
||||
|
||||
using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
|
||||
size_t count = 0;
|
||||
while (size != 0)
|
||||
{
|
||||
auto currentStep = static_cast<int>(std::min(size, step));
|
||||
MPICHECK(func(buffer, currentStep, dtype, args...));
|
||||
size -= currentStep;
|
||||
size_t diff = static_cast<size_t>(currentStep) * elementSize;
|
||||
buffer = static_cast<TCast*>(buffer) + diff;
|
||||
++count;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
{
|
||||
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
return r;
|
||||
}
|
||||
|
||||
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
|
||||
return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
|
||||
{
|
||||
bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
|
||||
}
|
||||
|
||||
std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("start MPI_Isend with size %d", size);
|
||||
std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif
|
||||
TLLM_LOG_DEBUG("end MPI_Isend with size %d", size);
|
||||
return r;
|
||||
}
|
||||
|
||||
std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
|
||||
{
|
||||
return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
|
||||
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("start MPI_Send with size %d", size);
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
TLLM_LOG_DEBUG("end MPI_Send with size %d", size);
|
||||
}
|
||||
|
||||
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
|
||||
{
|
||||
send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("start MPI_Recv with size %d", size);
|
||||
MPI_Status status{};
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
TLLM_LOG_DEBUG("end MPI_Recv with size %d", size);
|
||||
return status;
|
||||
}
|
||||
|
||||
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
|
||||
{
|
||||
return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
|
||||
}
|
||||
|
||||
MpiComm MpiComm::split(int color, int key) const
|
||||
{
|
||||
MPI_Comm splitComm = nullptr;
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
return MpiComm{splitComm, true};
|
||||
}
|
||||
|
||||
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
|
||||
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
|
||||
getMpiDtype(recvtype), mComm));
|
||||
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int flag{0};
|
||||
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
|
||||
return flag != 0;
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int flag{0};
|
||||
MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status));
|
||||
return flag != 0;
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void MpiComm::recvPoll(int source, int tag, int periodMs) const
|
||||
{
|
||||
MPI_Status status;
|
||||
while (!iprobe(source, tag, &status))
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
|
||||
}
|
||||
}
|
||||
|
||||
int MpiComm::getRank() const
|
||||
{
|
||||
int rank = 0;
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Comm_rank(mComm, &rank));
|
||||
#endif
|
||||
return rank;
|
||||
}
|
||||
|
||||
int MpiComm::getSize() const
|
||||
{
|
||||
int world_size = 1;
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Comm_size(mComm, &world_size));
|
||||
#endif
|
||||
return world_size;
|
||||
}
|
||||
|
||||
MpiComm const& MpiComm::world()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
static MpiComm commWorld{MPI_COMM_WORLD, false};
|
||||
initialize();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return commWorld;
|
||||
}
|
||||
|
||||
MpiComm& MpiComm::mutableSession()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
static MpiComm commSession{MPI_COMM_WORLD, false};
|
||||
initialize();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return commSession;
|
||||
}
|
||||
|
||||
MpiComm& MpiComm::mutableLocalSession()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
static MpiComm localSession = initLocalSession();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return localSession;
|
||||
}
|
||||
|
||||
void MpiComm::refreshLocalSession()
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
static std::mutex mutex;
|
||||
std::unique_lock lock(mutex);
|
||||
auto initSessionRanks = getWorldRanks(MpiComm::session());
|
||||
auto localSessionRanks = getWorldRanks(MpiComm::localSession());
|
||||
|
||||
// Add to intersectionRanks in order of initSessionRanks
|
||||
std::vector<int> intersectionRanks;
|
||||
std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
|
||||
for (auto rank : initSessionRanks)
|
||||
{
|
||||
if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
|
||||
{
|
||||
intersectionRanks.push_back(rank);
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Group worldGroup = nullptr;
|
||||
MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
|
||||
MPI_Group localGroup = nullptr;
|
||||
MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
|
||||
MPI_Comm localComm = nullptr;
|
||||
MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
|
||||
MpiComm::mutableLocalSession().mFreeComm = true;
|
||||
MpiComm::mutableLocalSession() = MpiComm{localComm, false};
|
||||
TLLM_LOG_INFO("Refreshed the MPI local session");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
|
||||
: mComm{g}
|
||||
, mFreeComm{freeComm}
|
||||
{
|
||||
TLLM_CHECK(mComm != MPI_COMM_NULL);
|
||||
}
|
||||
|
||||
MpiComm::~MpiComm() noexcept
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
if (mFreeComm && mComm)
|
||||
{
|
||||
if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
|
||||
{
|
||||
TLLM_LOG_ERROR("MPI_Comm_free failed");
|
||||
}
|
||||
}
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
MpiComm::MpiComm(MpiComm&& comm) noexcept
|
||||
: mComm{comm.mComm}
|
||||
, mFreeComm{comm.mFreeComm}
|
||||
{
|
||||
comm.mFreeComm = false;
|
||||
}
|
||||
|
||||
MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
|
||||
{
|
||||
this->~MpiComm();
|
||||
mComm = comm.mComm;
|
||||
mFreeComm = comm.mFreeComm;
|
||||
comm.mFreeComm = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
|
||||
: mName{name.c_str()}
|
||||
, mFuncWait{funcWait}
|
||||
, mFuncSetup{funcSetup}
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
MpiWaitThread::~MpiWaitThread()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
waitStop();
|
||||
mShouldExit.store(true);
|
||||
notifyStart();
|
||||
mThread->join();
|
||||
mThread.reset(nullptr);
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void MpiWaitThread::sideThread()
|
||||
{
|
||||
if (mFuncSetup)
|
||||
{
|
||||
mFuncSetup();
|
||||
}
|
||||
while (!mShouldExit.load())
|
||||
{
|
||||
notifyStop();
|
||||
waitStart();
|
||||
mFuncWait();
|
||||
}
|
||||
}
|
||||
|
||||
void MpiWaitThread::waitStart()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
std::unique_lock<std::mutex> lock(mMutex);
|
||||
mCondVar.wait(lock, [this] { return mRunning; });
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void MpiWaitThread::waitStop()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
std::unique_lock<std::mutex> lock(mMutex);
|
||||
mCondVar.wait(lock, [this] { return !mRunning; });
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void MpiWaitThread::notifyStart()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
mRunning = true;
|
||||
mCondVar.notify_one();
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void MpiWaitThread::notifyStop()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
mRunning = false;
|
||||
mCondVar.notify_one();
|
||||
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::mpi
|
||||
46
sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h
vendored
Normal file
46
sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace tensorrt_llm::common::nvtx
|
||||
{
|
||||
inline nvtx3::color nextColor()
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
|
||||
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
|
||||
constexpr auto numColors = kColors.size();
|
||||
|
||||
static thread_local std::size_t colorId = 0;
|
||||
auto const color = kColors[colorId];
|
||||
colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
|
||||
return color;
|
||||
#else
|
||||
return nvtx3::color{0};
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common::nvtx
|
||||
|
||||
#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \
|
||||
::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name)
|
||||
#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range)
|
||||
323
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp
vendored
Normal file
323
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp
vendored
Normal file
@@ -0,0 +1,323 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
|
||||
#include "cuda.h"
|
||||
#include <cstdint>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define FN_NAME __FUNCTION__
|
||||
#else
|
||||
#define FN_NAME __func__
|
||||
#endif
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
|
||||
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
|
||||
{
|
||||
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32},
|
||||
{nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}};
|
||||
return &dtypeMap;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// Get NCCL unique ID for a group of ranks.
|
||||
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
|
||||
{
|
||||
auto const rank = COMM_SESSION.getRank();
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
|
||||
ncclUniqueId id;
|
||||
if (rank == *group.begin())
|
||||
{
|
||||
NCCLCHECK(ncclGetUniqueId(&id));
|
||||
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
|
||||
{
|
||||
COMM_SESSION.sendValue(id, *it, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
COMM_SESSION.recvValue(id, *group.begin(), 0);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
|
||||
return id;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
|
||||
{
|
||||
auto const rank = COMM_SESSION.getRank();
|
||||
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
|
||||
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
std::ostringstream oss;
|
||||
int index = 0;
|
||||
for (auto const& rank : group)
|
||||
{
|
||||
if (index != 0)
|
||||
{
|
||||
oss << ",";
|
||||
}
|
||||
oss << rank;
|
||||
index++;
|
||||
}
|
||||
auto groupStr = oss.str();
|
||||
auto it = commMap.find(group);
|
||||
if (it != commMap.end())
|
||||
{
|
||||
auto ncclComm = it->second;
|
||||
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
|
||||
return ncclComm;
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
|
||||
ncclUniqueId id = getUniqueId(group);
|
||||
int groupRank = 0;
|
||||
for (auto const& currentRank : group)
|
||||
{
|
||||
if (rank == currentRank)
|
||||
break;
|
||||
++groupRank;
|
||||
}
|
||||
TLLM_CHECK(groupRank < group.size());
|
||||
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
|
||||
[](ncclComm_t* comm)
|
||||
{
|
||||
ncclCommDestroy(*comm);
|
||||
delete comm;
|
||||
});
|
||||
NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
|
||||
commMap[group] = ncclComm;
|
||||
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
|
||||
return ncclComm;
|
||||
}
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
void const* tensorrt_llm::common::getCommSessionHandle()
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
return &COMM_SESSION;
|
||||
#else
|
||||
return nullptr;
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// Get current cuda context, a default context will be created if there is no context.
|
||||
inline CUcontext getCurrentCudaCtx()
|
||||
{
|
||||
CUcontext ctx{};
|
||||
CUresult err = cuCtxGetCurrent(&ctx);
|
||||
if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFree(nullptr));
|
||||
err = cuCtxGetCurrent(&ctx);
|
||||
}
|
||||
TLLM_CHECK(err == CUDA_SUCCESS);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Helper to create per-cuda-context singleton managed by std::shared_ptr.
|
||||
// Unlike conventional singletons, singleton created with this will be released
|
||||
// when not needed, instead of on process exit.
|
||||
// Objects of this class shall always be declared static / global, and shall never own CUDA
|
||||
// resources.
|
||||
template <typename T>
|
||||
class PerCudaCtxSingletonCreator
|
||||
{
|
||||
public:
|
||||
using CreatorFunc = std::function<std::unique_ptr<T>()>;
|
||||
using DeleterFunc = std::function<void(T*)>;
|
||||
|
||||
// creator returning std::unique_ptr is by design.
|
||||
// It forces separation of memory for T and memory for control blocks.
|
||||
// So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
|
||||
// creator itself must not own CUDA resources. Only the object it creates can.
|
||||
PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
|
||||
: mCreator{std::move(creator)}
|
||||
, mDeleter{std::move(deleter)}
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<T> operator()()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk{mMutex};
|
||||
CUcontext ctx{getCurrentCudaCtx()};
|
||||
std::shared_ptr<T> result = mObservers[ctx].lock();
|
||||
if (result == nullptr)
|
||||
{
|
||||
// Create the resource and register with an observer.
|
||||
result = std::shared_ptr<T>{mCreator().release(),
|
||||
[this, ctx](T* obj)
|
||||
{
|
||||
if (obj == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
mDeleter(obj);
|
||||
|
||||
// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
|
||||
// frequently.
|
||||
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
|
||||
std::lock_guard<std::mutex> lk{mMutex};
|
||||
// Must check observer again because another thread may created new instance for this ctx just
|
||||
// before we lock mMutex. We can't infer that the observer is stale from the fact that obj is
|
||||
// destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic
|
||||
// operation, and the observer may be changed to observe another instance.
|
||||
observedObjHolder = mObservers.at(ctx).lock();
|
||||
if (observedObjHolder == nullptr)
|
||||
{
|
||||
mObservers.erase(ctx);
|
||||
}
|
||||
}};
|
||||
mObservers.at(ctx) = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CreatorFunc mCreator;
|
||||
DeleterFunc mDeleter;
|
||||
mutable std::mutex mMutex;
|
||||
// CUDA resources are per-context.
|
||||
std::unordered_map<CUcontext, std::weak_ptr<T>> mObservers;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PerThreadSingletonCreator
|
||||
{
|
||||
public:
|
||||
using CreatorFunc = std::function<std::unique_ptr<T>()>;
|
||||
using DeleterFunc = std::function<void(T*)>;
|
||||
|
||||
// creator returning std::unique_ptr is by design.
|
||||
// It forces separation of memory for T and memory for control blocks.
|
||||
// So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
|
||||
// creator itself must not own CUDA resources. Only the object it creates can.
|
||||
PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
|
||||
: mCreator{std::move(creator)}
|
||||
, mDeleter{std::move(deleter)}
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<T> operator()()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk{mMutex};
|
||||
|
||||
std::thread::id thread = std::this_thread::get_id();
|
||||
std::shared_ptr<T> result = mObservers[thread].lock();
|
||||
|
||||
if (result == nullptr)
|
||||
{
|
||||
// Create the resource and register with an observer.
|
||||
result = std::shared_ptr<T>{mCreator().release(),
|
||||
[this, thread](T* obj)
|
||||
{
|
||||
if (obj == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
mDeleter(obj);
|
||||
|
||||
// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
|
||||
// frequently.
|
||||
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
|
||||
std::lock_guard<std::mutex> lk{mMutex};
|
||||
// Must check observer again because another thread may created new instance for this ctx just
|
||||
// before we lock mMutex. We can't infer that the observer is stale from the fact that obj is
|
||||
// destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic
|
||||
// operation, and the observer may be changed to observe another instance.
|
||||
observedObjHolder = mObservers.at(thread).lock();
|
||||
if (observedObjHolder == nullptr)
|
||||
{
|
||||
mObservers.erase(thread);
|
||||
}
|
||||
}};
|
||||
mObservers.at(thread) = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CreatorFunc mCreator;
|
||||
DeleterFunc mDeleter;
|
||||
mutable std::mutex mMutex;
|
||||
// CUDA resources are per-thread.
|
||||
std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<cublasHandle_t> getCublasHandle()
|
||||
{
|
||||
static PerThreadSingletonCreator<cublasHandle_t> creator(
|
||||
[]() -> auto
|
||||
{
|
||||
auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
|
||||
TLLM_CUDA_CHECK(cublasCreate(handle.get()));
|
||||
return handle;
|
||||
},
|
||||
[](cublasHandle_t* handle)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cublasDestroy(*handle));
|
||||
delete handle;
|
||||
});
|
||||
return creator();
|
||||
}
|
||||
|
||||
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
|
||||
{
|
||||
static PerThreadSingletonCreator<cublasLtHandle_t> creator(
|
||||
[]() -> auto
|
||||
{
|
||||
auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
|
||||
TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
|
||||
return handle;
|
||||
},
|
||||
[](cublasLtHandle_t* handle)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
|
||||
delete handle;
|
||||
});
|
||||
return creator();
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
|
||||
{
|
||||
static PerThreadSingletonCreator<tensorrt_llm::common::CublasMMWrapper> creator(
|
||||
[cublasHandle, cublasltHandle, stream, workspace]() -> auto
|
||||
{
|
||||
auto wrapper = std::unique_ptr<tensorrt_llm::common::CublasMMWrapper>(
|
||||
new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace));
|
||||
return wrapper;
|
||||
},
|
||||
[](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; });
|
||||
return creator();
|
||||
}
|
||||
215
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h
vendored
Normal file
215
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <nccl.h>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nvml.h>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
// Write values into buffer
|
||||
template <typename T>
|
||||
void write(char*& buffer, T const& val)
|
||||
{
|
||||
std::memcpy(buffer, &val, sizeof(T));
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
|
||||
// Read values from buffer
|
||||
template <typename T>
|
||||
void read(char const*& buffer, T& val)
|
||||
{
|
||||
std::memcpy(&val, buffer, sizeof(T));
|
||||
buffer += sizeof(T);
|
||||
}
|
||||
|
||||
// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
|
||||
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
|
||||
// your clone() implementation is responsible for initializing such data members.
|
||||
// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr.
|
||||
template <typename T, typename Del = std::default_delete<T>>
|
||||
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
|
||||
{
|
||||
public:
|
||||
using std::unique_ptr<T, Del>::unique_ptr;
|
||||
|
||||
// for compatibility with std::make_unique
|
||||
explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
|
||||
: std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
|
||||
{
|
||||
}
|
||||
|
||||
// copy constructor produces nullptr
|
||||
UniqPtrWNullCopy(UniqPtrWNullCopy const&)
|
||||
: std::unique_ptr<T, Del>::unique_ptr{}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// for testing only
|
||||
void const* getCommSessionHandle();
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
inline bool isBuilding()
|
||||
{
|
||||
auto constexpr key = "IS_BUILDING";
|
||||
auto const val = getenv(key);
|
||||
return val != nullptr && std::string(val) == "1";
|
||||
}
|
||||
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#define NCCLCHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) \
|
||||
{ \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();
|
||||
|
||||
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
|
||||
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally.
|
||||
//! Get cublas and cublasLt handle for current cuda context
|
||||
std::shared_ptr<cublasHandle_t> getCublasHandle();
|
||||
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
|
||||
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
|
||||
std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace);
|
||||
|
||||
#ifndef DEBUG
|
||||
|
||||
#define PLUGIN_CHECK(status) \
|
||||
do \
|
||||
{ \
|
||||
if (status != 0) \
|
||||
abort(); \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_PARAM(exp) \
|
||||
do \
|
||||
{ \
|
||||
if (!(exp)) \
|
||||
return STATUS_BAD_PARAM; \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_FAILURE(exp) \
|
||||
do \
|
||||
{ \
|
||||
if (!(exp)) \
|
||||
return STATUS_FAILURE; \
|
||||
} while (0)
|
||||
|
||||
#define CSC(call, err) \
|
||||
do \
|
||||
{ \
|
||||
cudaError_t cudaStatus = call; \
|
||||
if (cudaStatus != cudaSuccess) \
|
||||
{ \
|
||||
return err; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DEBUG_PRINTF(...) \
|
||||
do \
|
||||
{ \
|
||||
} while (0)
|
||||
|
||||
#else
|
||||
|
||||
#define ASSERT_PARAM(exp) \
|
||||
do \
|
||||
{ \
|
||||
if (!(exp)) \
|
||||
{ \
|
||||
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
|
||||
return STATUS_BAD_PARAM; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_FAILURE(exp) \
|
||||
do \
|
||||
{ \
|
||||
if (!(exp)) \
|
||||
{ \
|
||||
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
|
||||
return STATUS_FAILURE; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CSC(call, err) \
|
||||
do \
|
||||
{ \
|
||||
cudaError_t cudaStatus = call; \
|
||||
if (cudaStatus != cudaSuccess) \
|
||||
{ \
|
||||
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \
|
||||
return err; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define PLUGIN_CHECK(status) \
|
||||
{ \
|
||||
if (status != 0) \
|
||||
{ \
|
||||
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \
|
||||
abort(); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DEBUG_PRINTF(...) \
|
||||
do \
|
||||
{ \
|
||||
printf(__VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
#endif // DEBUG
|
||||
|
||||
#define NVML_CHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
nvmlReturn_t r = cmd; \
|
||||
if (r != NVML_SUCCESS) \
|
||||
{ \
|
||||
printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
55
sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh
vendored
Normal file
55
sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <float.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct QuantTypeStaticVals;
|
||||
|
||||
template <>
|
||||
struct QuantTypeStaticVals<int8_t>
|
||||
{
|
||||
static constexpr float MAX_VAL = 127.f;
|
||||
static constexpr float MIN_SCALING_FACTOR = 0.f;
|
||||
static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <>
|
||||
struct QuantTypeStaticVals<__nv_fp8_e4m3>
|
||||
{
|
||||
static constexpr float MAX_VAL = 448.f;
|
||||
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720
|
||||
static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f);
|
||||
static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f);
|
||||
};
|
||||
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
399
sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh
vendored
Normal file
399
sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh
vendored
Normal file
@@ -0,0 +1,399 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <array>
|
||||
#include <assert.h>
|
||||
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#else
|
||||
#include <cooperative_groups.h>
|
||||
#endif
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
{
|
||||
|
||||
template <int VPT>
|
||||
struct BytesToType;
|
||||
|
||||
template <>
|
||||
struct BytesToType<1>
|
||||
{
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<2>
|
||||
{
|
||||
using type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<4>
|
||||
{
|
||||
using type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<8>
|
||||
{
|
||||
using type = uint64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<16>
|
||||
{
|
||||
using type = float4;
|
||||
};
|
||||
|
||||
template <int Bytes>
|
||||
__device__ inline void copy(void const* local, void* data)
|
||||
{
|
||||
using T = typename BytesToType<Bytes>::type;
|
||||
|
||||
T const* in = static_cast<T const*>(local);
|
||||
T* out = static_cast<T*>(data);
|
||||
*out = *in;
|
||||
}
|
||||
|
||||
static float constexpr HALF_FLT_MAX = 65504.F;
|
||||
#define FINAL_MASK 0xffffffff
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockReduceSum(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f);
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__inline__ __device__ T warpReduceMax(T val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the maximum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockReduceMax(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
val = warpReduceMax(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
|
||||
val = warpReduceMax(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the maximum of all elements in a block */
|
||||
template <typename T>
|
||||
__inline__ __device__ T blockAllReduceMax(T val)
|
||||
{
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
val = warpReduceMax(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
|
||||
val = warpReduceMax(val);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T warpReduceSumV2(T* val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
|
||||
}
|
||||
return (T) (0.0f);
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T blockReduceSumV2(T* val)
|
||||
{
|
||||
static __shared__ T shared[NUM][33];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduceSumV2<T, NUM>(val);
|
||||
|
||||
if (lane == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
shared[i][wid] = val[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
|
||||
}
|
||||
warpReduceSumV2<T, NUM>(val);
|
||||
return (T) 0.0f;
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T warpReduceMaxV2(T* val)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
|
||||
}
|
||||
return (T) (0.0f);
|
||||
}
|
||||
|
||||
template <typename T, int NUM>
|
||||
__inline__ __device__ T blockReduceMaxV2(T* val)
|
||||
{
|
||||
static __shared__ T shared[32][NUM];
|
||||
int lane = threadIdx.x & 0x1f; // in-warp idx
|
||||
int wid = threadIdx.x >> 5; // warp idx
|
||||
|
||||
warpReduceMaxV2<T, NUM>(val); // get maxx in each warp
|
||||
|
||||
if (lane == 0) // record in-warp maxx by warp Idx
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
shared[wid][i] = val[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
val[i] = is_mask ? shared[lane][i] : (T) -1e20f;
|
||||
}
|
||||
warpReduceMaxV2<T, NUM>(val);
|
||||
|
||||
return (T) 0.0f;
|
||||
}
|
||||
|
||||
template <int NUM>
|
||||
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
|
||||
{
|
||||
cg::thread_block cta = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
|
||||
|
||||
int const tid = cta.thread_rank();
|
||||
int const blockz = blockDim.x;
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
||||
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
|
||||
#else
|
||||
// TODO Add implementation here
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
cg::sync(cta);
|
||||
if (tid == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
float beta = 0.0f;
|
||||
for (int j = 0; j < blockz; j += 32)
|
||||
{
|
||||
beta += cgBlockReduceSumElements_shm[i * blockz + j];
|
||||
}
|
||||
element_list[i] = beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
struct TopK
|
||||
{
|
||||
int p[MAX_K]; // index, being -1 at the tail if the array is not full
|
||||
T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid
|
||||
|
||||
__device__ __forceinline__ void insert(T const elem, int const elem_id)
|
||||
{
|
||||
if (elem_id < 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Condition of updating the array
|
||||
// 1. array is not full
|
||||
// 2. elem is greater than the smallest (last) element in the array
|
||||
// 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller
|
||||
bool const need_update
|
||||
= (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]);
|
||||
if (!need_update)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// Find suitable index for the new element
|
||||
int i;
|
||||
for (i = MAX_K - 2; i >= 0; --i)
|
||||
{
|
||||
bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]);
|
||||
if (!need_decrease)
|
||||
break;
|
||||
}
|
||||
// Move elements to correct positions
|
||||
for (int k = MAX_K - 2; k >= i; --k)
|
||||
{
|
||||
p[k + 1] = p[k];
|
||||
u[k + 1] = u[k];
|
||||
}
|
||||
p[i] = elem_id;
|
||||
u[i] = elem;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void init()
|
||||
{
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
for (int i = 0; i < MAX_K; i++)
|
||||
{
|
||||
p[i] = -1;
|
||||
u[i] = -MAX_T_VAL;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
|
||||
{
|
||||
TopK<T, MAX_K> res = a;
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
res.insert(b.u[i], b.p[i]);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TopK_2
|
||||
{
|
||||
int p = -1;
|
||||
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
|
||||
|
||||
__device__ __forceinline__ void insert(T elem, int elem_id)
|
||||
{
|
||||
if (elem > u)
|
||||
{
|
||||
u = elem;
|
||||
p = elem_id;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void init()
|
||||
{
|
||||
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
|
||||
p = -1;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
|
||||
{
|
||||
return a.u > b.u ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T clamp_inf_for_half(float const input)
|
||||
{
|
||||
return input;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ half clamp_inf_for_half(float const input)
|
||||
{
|
||||
// clamp inf values to enable fp16 training
|
||||
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
123
sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h
vendored
Normal file
123
sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h
vendored
Normal file
@@ -0,0 +1,123 @@
|
||||
/*
|
||||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm::common::stl_utils
|
||||
{
|
||||
|
||||
template <typename TInputIt, typename TOutputIt, typename TBinOp>
|
||||
constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op)
|
||||
{
|
||||
if (first != last)
|
||||
{
|
||||
auto val = *first;
|
||||
while (true)
|
||||
{
|
||||
*dFirst = val;
|
||||
++dFirst;
|
||||
++first;
|
||||
if (first == last)
|
||||
{
|
||||
break;
|
||||
}
|
||||
val = op(std::move(val), *first);
|
||||
}
|
||||
}
|
||||
return dFirst;
|
||||
}
|
||||
|
||||
template <typename TInputIt, typename TOutputIt>
|
||||
constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst)
|
||||
{
|
||||
#if defined(__GNUC__) && __GNUC__ <= 8
|
||||
return basicInclusiveScan(first, last, dFirst, std::plus<>{});
|
||||
#else
|
||||
return std::inclusive_scan(first, last, dFirst);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename TInputIt, typename TOutputIt, typename T, typename TBinOp>
|
||||
constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op)
|
||||
{
|
||||
if (first != last)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
T tmp{op(init, *first)};
|
||||
*dFirst = init;
|
||||
++dFirst;
|
||||
++first;
|
||||
if (first == last)
|
||||
{
|
||||
break;
|
||||
}
|
||||
init = std::move(tmp);
|
||||
}
|
||||
}
|
||||
return dFirst;
|
||||
}
|
||||
|
||||
template <typename TInputIt, typename TOutputIt, typename T>
|
||||
constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init)
|
||||
{
|
||||
#if defined(__GNUC__) && __GNUC__ <= 8
|
||||
return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{});
|
||||
#else
|
||||
return std::exclusive_scan(first, last, dFirst, std::move(init));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct HasOperatorOutput : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct HasOperatorOutput<T, std::void_t<decltype((std::declval<std::ostream&>() << std::declval<T>()))>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::string toString(T const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string toString(std::optional<T> const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
if (t)
|
||||
{
|
||||
oss << t.value();
|
||||
}
|
||||
else
|
||||
{
|
||||
oss << "None";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common::stl_utils
|
||||
76
sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp
vendored
Normal file
76
sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
#include <cerrno>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
std::string vformat(char const* fmt, va_list args)
|
||||
{
|
||||
va_list args0;
|
||||
va_copy(args0, args);
|
||||
auto const size = vsnprintf(nullptr, 0, fmt, args0);
|
||||
if (size <= 0)
|
||||
return "";
|
||||
|
||||
std::string stringBuf(size, char{});
|
||||
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
|
||||
|
||||
return stringBuf;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string fmtstr(char const* format, ...)
|
||||
{
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::string result = vformat(format, args);
|
||||
va_end(args);
|
||||
return result;
|
||||
};
|
||||
|
||||
std::unordered_set<std::string> str2set(std::string const& input, char delimiter)
|
||||
{
|
||||
std::unordered_set<std::string> values;
|
||||
if (!input.empty())
|
||||
{
|
||||
std::stringstream valStream(input);
|
||||
std::string val;
|
||||
while (std::getline(valStream, val, delimiter))
|
||||
{
|
||||
if (!val.empty())
|
||||
{
|
||||
values.insert(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
return values;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
42
sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp
vendored
Normal file
42
sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorrt_llm/common/timestampUtils.h"
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::string getCurrentTimestamp()
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto now_t = std::chrono::system_clock::to_time_t(now);
|
||||
auto tm = *std::localtime(&now_t);
|
||||
|
||||
auto epoch_to_now = now.time_since_epoch();
|
||||
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
|
||||
auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
|
||||
|
||||
std::ostringstream stream;
|
||||
stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
|
||||
stream << "." << std::setfill('0') << std::setw(6) << us.count();
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
25
sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h
vendored
Normal file
25
sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu"
|
||||
std::string getCurrentTimestamp();
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
105
sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp
vendored
Normal file
105
sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#if !defined(_MSC_VER)
|
||||
#include <cxxabi.h>
|
||||
#include <dlfcn.h>
|
||||
#include <execinfo.h>
|
||||
#endif
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
|
||||
}
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
|
||||
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
|
||||
: std::runtime_error{""}
|
||||
{
|
||||
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
|
||||
auto const trace = getTrace();
|
||||
std::runtime_error::operator=(
|
||||
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
|
||||
}
|
||||
#else
|
||||
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
|
||||
: mNbFrames{}
|
||||
, std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
TllmException::~TllmException() noexcept = default;
|
||||
|
||||
std::string TllmException::getTrace() const
|
||||
{
|
||||
#if defined(_MSC_VER)
|
||||
return "";
|
||||
#else
|
||||
auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
|
||||
std::ostringstream buf;
|
||||
for (auto i = 1; i < mNbFrames; ++i)
|
||||
{
|
||||
Dl_info info;
|
||||
if (dladdr(mCallstack[i], &info) && info.dli_sname)
|
||||
{
|
||||
auto const clearName = demangle(info.dli_sname);
|
||||
buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
|
||||
static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
|
||||
}
|
||||
else
|
||||
{
|
||||
buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
|
||||
}
|
||||
if (i < mNbFrames - 1)
|
||||
buf << std::endl;
|
||||
}
|
||||
|
||||
if (mNbFrames == MAX_FRAMES)
|
||||
buf << std::endl << "[truncated]";
|
||||
|
||||
std::free(trace);
|
||||
return buf.str();
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string TllmException::demangle(char const* name)
|
||||
{
|
||||
#if defined(_MSC_VER)
|
||||
return name;
|
||||
#else
|
||||
std::string clearName{name};
|
||||
auto status = -1;
|
||||
auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
|
||||
if (status == 0)
|
||||
{
|
||||
clearName = demangled;
|
||||
std::free(demangled);
|
||||
}
|
||||
return clearName;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
87
sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h
vendored
Normal file
87
sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::uintptr_t constexpr kCudaMemAlign = 128;
|
||||
|
||||
inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
|
||||
{
|
||||
uintptr_t addr = (uintptr_t) ptr;
|
||||
if (addr % to)
|
||||
{
|
||||
addr += to - addr % to;
|
||||
}
|
||||
return (int8_t*) addr;
|
||||
}
|
||||
|
||||
constexpr size_t alignSize(size_t size, size_t to)
|
||||
{
|
||||
if ((size % to) != 0U)
|
||||
{
|
||||
size += to - size % to;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
|
||||
{
|
||||
uintptr_t addr = (uintptr_t) ptr;
|
||||
addr += previousWorkspaceSize;
|
||||
return alignPtr((int8_t*) addr, alignment);
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
|
||||
{
|
||||
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtr(
|
||||
int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
uintptr_t curr_offset = offset;
|
||||
uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
|
||||
int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
|
||||
offset = next_offset;
|
||||
return newptr;
|
||||
}
|
||||
|
||||
inline int8_t* nextWorkspacePtrWithAlignment(
|
||||
int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
|
||||
}
|
||||
|
||||
inline size_t calculateTotalWorkspaceSize(
|
||||
size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
|
||||
{
|
||||
size_t total = 0;
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
total += workspaces[i];
|
||||
if (workspaces[i] % alignment)
|
||||
{
|
||||
total += alignment - (workspaces[i] % alignment);
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
}; // namespace tensorrt_llm::common
|
||||
@@ -0,0 +1,352 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
#include <cute/arch/util.hpp>
|
||||
#include <cute/atom/copy_traits.hpp>
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
// Config
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
|
||||
#define CUTE_ARCH_RED_F16_SM70_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
|
||||
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
//////////////////////////////////
|
||||
// Wrapper around CUDA's atomicAdd
|
||||
//////////////////////////////////
|
||||
|
||||
template <class T>
|
||||
struct TypedAtomicAdd
|
||||
{
|
||||
using SRegisters = T[1];
|
||||
using DRegisters = T[1];
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
|
||||
{
|
||||
atomicAdd(&dst, src);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Copy_Traits<TypedAtomicAdd<T>>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// F16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// BF16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
120
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h
vendored
Normal file
120
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates exposing architecture support for multiply-add operations
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace arch
|
||||
{
|
||||
|
||||
// Tag which triggers MMA which will trigger
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA;
|
||||
|
||||
/*
|
||||
Below we have extra tags to signal what kind of dequantization we want to do
|
||||
(per col, scale only fine grained, finegrained with zero). This still lets us
|
||||
the existing template infrastructure (incl. that in CUTLASS). However, we
|
||||
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
|
||||
with the quantization op before instantiating the GEMM pieces.
|
||||
|
||||
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
|
||||
code we need to duplicate.
|
||||
*/
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
|
||||
// The default just forwards the original operator
|
||||
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
|
||||
struct TagOperator
|
||||
{
|
||||
using TaggedOperator = MmaOp;
|
||||
};
|
||||
|
||||
// Specializations below attach more information to the operator
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
};
|
||||
|
||||
// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
|
||||
// operator + the extra information. If no extra info was tagged, the dequant op per column scaling
|
||||
// as a default.
|
||||
template <typename TaggedMmaOp>
|
||||
struct DetagOperator
|
||||
{
|
||||
using Operator = TaggedMmaOp;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
};
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,88 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
|
||||
template <typename GemmKernel, bool enable_cutlass_3x = false>
|
||||
inline int compute_occupancy_for_kernel()
|
||||
{
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
cudaFuncAttributes attr;
|
||||
int device = 0;
|
||||
int max_smem_per_block = 0;
|
||||
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
|
||||
}
|
||||
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
|
||||
{
|
||||
// This should mean that
|
||||
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
|
||||
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
|
||||
// configuration.
|
||||
return 0;
|
||||
}
|
||||
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
|
||||
cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
|
||||
cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
if constexpr (enable_cutlass_3x)
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
|
||||
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
|
||||
}
|
||||
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -0,0 +1,550 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing elementwise operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
#include "cute/numeric/numeric_types.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/arch/copy_red_global.hpp"
|
||||
#include "cutlass_extensions/util/gather_tensor.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias, class StrideBias,
|
||||
class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD, class CopyOpR2S, class CopyOpS2R,
|
||||
class CopyOpR2G>
|
||||
class EpilogueMoeFusedFinalize
|
||||
{
|
||||
public:
|
||||
using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
|
||||
using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
||||
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
||||
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
||||
using ElementIntermediate = typename ThreadEpilogueOp::ElementD;
|
||||
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using InternalStrideC = cute::remove_pointer_t<StrideC>;
|
||||
using ElementD = ElementD_;
|
||||
using StrideD = StrideD_;
|
||||
using InternalStrideD = cute::remove_pointer_t<StrideD>;
|
||||
|
||||
static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
|
||||
static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");
|
||||
|
||||
using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
|
||||
using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
|
||||
using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
|
||||
static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;
|
||||
|
||||
using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
|
||||
};
|
||||
|
||||
struct TensorMapStorage
|
||||
{
|
||||
};
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C{};
|
||||
StrideC dC{};
|
||||
ElementD* ptr_D{};
|
||||
StrideD dD{};
|
||||
ElementBias const* ptr_bias;
|
||||
StrideBias dBias{};
|
||||
ElementScale const* ptr_scale;
|
||||
StrideScale dScale{};
|
||||
int64_t const* group_offset{};
|
||||
int32_t const* scatter_index{};
|
||||
cutlass::FastDivmod num_rows_in_final_output;
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
|
||||
{
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
|
||||
void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_HOST_DEVICE static bool can_implement(
|
||||
[[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
bool implementable = true;
|
||||
if (problem_shape.is_host_problem_shape_available())
|
||||
{
|
||||
// Check alignment for all problem sizes
|
||||
for (int i = 0; i < problem_shape.groups(); i++)
|
||||
{
|
||||
auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
|
||||
}
|
||||
}
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
|
||||
"reduction instruction.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
EpilogueMoeFusedFinalize(Params const& params_)
|
||||
: params(params_)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_source_needed()
|
||||
{
|
||||
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
|
||||
return params.ptr_C != nullptr
|
||||
&& (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
|
||||
}
|
||||
|
||||
template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
|
||||
class TiledMma, class ResidueMNK>
|
||||
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
|
||||
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
|
||||
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
||||
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
||||
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
||||
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
|
||||
|
||||
auto synchronize = [&]()
|
||||
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
auto M = get<0>(problem_shape_mnkl);
|
||||
auto N = get<1>(problem_shape_mnkl);
|
||||
auto L = get<3>(problem_shape_mnkl);
|
||||
|
||||
auto mma_tile_m = tile_size<0>(tiled_mma);
|
||||
auto mma_tile_n = tile_size<1>(tiled_mma);
|
||||
auto epi_tile_m = size<0>(EpilogueTile{});
|
||||
auto epi_tile_n = size<1>(EpilogueTile{});
|
||||
|
||||
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
|
||||
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
|
||||
|
||||
// Batches are managed by using appropriate pointers to C and D matrices
|
||||
int32_t const mock_L = 1;
|
||||
int32_t const mock_l_coord = 0;
|
||||
|
||||
// Slice to get the tile this CTA is responsible for
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
||||
|
||||
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
|
||||
// we get the correct alpha/beta values for the current batch/group using group index.
|
||||
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
|
||||
|
||||
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
|
||||
Tensor sD = as_position_independent_swizzle_tensor(sD_);
|
||||
|
||||
// Function to scatter output rows
|
||||
auto& num_rows = params.num_rows_in_final_output;
|
||||
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
|
||||
auto get_scatter_idx = [&](auto i)
|
||||
{
|
||||
auto scatter = read_scatter_map(i);
|
||||
int quot, rem;
|
||||
num_rows(quot, rem, scatter);
|
||||
return rem;
|
||||
};
|
||||
|
||||
// Represent the full output tensor
|
||||
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
|
||||
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
|
||||
Tensor mD_mnl = make_gather_tensor(
|
||||
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
|
||||
|
||||
// Use fake shape for bias, it doesn't matter
|
||||
bool const is_bias_needed = params.ptr_bias != nullptr;
|
||||
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
|
||||
Tensor mScale_mnl = make_tensor(
|
||||
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
|
||||
|
||||
Tensor gC_mnl
|
||||
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gD_mnl
|
||||
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
Tensor gBias_mnl
|
||||
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gScale_mnl
|
||||
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
|
||||
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
|
||||
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
|
||||
|
||||
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Get the smallest tiled copy we can use to retile the accumulators
|
||||
TiledCopy tiled_copy_C_atom
|
||||
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
|
||||
|
||||
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
|
||||
|
||||
// Make a tiled copy vectorized along major direction of D
|
||||
auto tiled_s2r = [&]()
|
||||
{
|
||||
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
|
||||
Layout<Shape<_1, Int<AlignmentD>>>{});
|
||||
}
|
||||
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
|
||||
{
|
||||
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
|
||||
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
|
||||
return make_tiled_copy(CopyAtomS2R{},
|
||||
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
|
||||
Layout<Shape<Int<AlignmentD>, _1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
|
||||
}
|
||||
}();
|
||||
|
||||
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
|
||||
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// Allocate intermediate registers for a single subtile
|
||||
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
|
||||
|
||||
// Make an identity coordinate tensor for predicating our output MN tile
|
||||
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
||||
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
|
||||
|
||||
// epilogue subtile loop
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
|
||||
{
|
||||
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
|
||||
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
|
||||
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
|
||||
|
||||
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
|
||||
int r2s_v = epi_n_in_mma * size(tRS_rD);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
|
||||
{
|
||||
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
|
||||
}
|
||||
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD);
|
||||
synchronize();
|
||||
|
||||
copy(tiled_s2r, tSR_sD, tSR_rD);
|
||||
synchronize();
|
||||
|
||||
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
|
||||
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
|
||||
|
||||
if (epilogue_op.is_source_needed())
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < size<1>(tSR_rD); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<2>(tSR_rD); ++n)
|
||||
{
|
||||
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
|
||||
{
|
||||
if (is_bias_needed)
|
||||
{
|
||||
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
|
||||
}
|
||||
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size<0>(tSR_rD); ++i)
|
||||
{
|
||||
auto epi_value = epilogue_op(tSR_rD(i, m, n));
|
||||
if (is_bias_needed)
|
||||
{
|
||||
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
|
||||
}
|
||||
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
|
||||
}
|
||||
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Params params;
|
||||
};
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <class Element, class MaxVec>
|
||||
constexpr auto get_vectorized_atomic_add_op()
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
auto constexpr MaxVecSize = size(MaxVec{});
|
||||
|
||||
if constexpr (is_same_v<Element, cutlass::half_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_F16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM70_RED_ADD_NOFTZ_F16{};
|
||||
}
|
||||
}
|
||||
else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
|
||||
{
|
||||
if constexpr (MaxVecSize >= 8)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 4)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
|
||||
}
|
||||
else if constexpr (MaxVecSize >= 2)
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16x2{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return SM90_RED_ADD_NOFTZ_BF16{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// non-vectorized atomic add for all other types until supported
|
||||
return TypedAtomicAdd<Element>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
|
||||
class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
|
||||
struct EpilogueMoeFusedFinalizeBuilder
|
||||
{
|
||||
|
||||
// assuming cooperative kernel schedule
|
||||
using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
|
||||
using EpilogueTile = Shape<_128, EpiTileN>;
|
||||
|
||||
// Output of linear combination is ElementCompute instead of ElementD
|
||||
// since we will be doing more computate on it, no need to cast yet.
|
||||
using ThreadEpilogueOp
|
||||
= cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;
|
||||
|
||||
using SmemLayoutAtomD
|
||||
= decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
|
||||
using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
|
||||
using CopyAtomS2R = DefaultCopy;
|
||||
using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());
|
||||
|
||||
template <class EpilogueOp>
|
||||
struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
|
||||
{
|
||||
// We need to override this one using declaration because otherwise we double up on the smem
|
||||
using TensorMapStorage = typename EpilogueOp::TensorMapStorage;
|
||||
|
||||
using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90TmaWarpSpecializedAdapterWithSmemStorage(
|
||||
typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
|
||||
: Base(params)
|
||||
{
|
||||
}
|
||||
|
||||
// These functions depend on the type of TensorMapStorage
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] typename EpilogueOp::Params const& params,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool IsLoad>
|
||||
CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
|
||||
[[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
|
||||
EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
|
||||
StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace collective
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,105 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b)
|
||||
{
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float tanh_opt(float x)
|
||||
{
|
||||
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
#else
|
||||
return fast_tanh(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <>
|
||||
struct GELU_taylor<float>
|
||||
{
|
||||
static bool const kIsHeavy = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const
|
||||
{
|
||||
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
|
||||
return float(cutlass::constants::half<float>() * z
|
||||
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const
|
||||
{
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,352 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
|
||||
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol
|
||||
{
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments()
|
||||
: batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_, int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(batch_stride_alpha_)
|
||||
, batch_stride_C(batch_stride_C_)
|
||||
, batch_stride_D(batch_stride_D_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise)
|
||||
, batch_stride_alpha(args.batch_stride_alpha)
|
||||
, batch_stride_C(args.batch_stride_C)
|
||||
, batch_stride_D(args.batch_stride_D)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
|
||||
int column_offset_;
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
|
||||
: params_(params)
|
||||
, shared_storage_(shared_storage)
|
||||
, extent_(problem_size)
|
||||
, elementwise_(params.elementwise)
|
||||
, per_token_quant_(quant_option.hasPerTokenScaling())
|
||||
, per_channel_quant_(quant_option.hasPerChannelScaling())
|
||||
, ptr_alpha_row_(ptr_alpha_row)
|
||||
, ptr_alpha_col_(ptr_alpha_col)
|
||||
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
|
||||
, extent_real_(problem_size_real)
|
||||
{
|
||||
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
|
||||
|
||||
if (beta_ == ElementAccumulator())
|
||||
{
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
|
||||
{
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
|
||||
{
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices)
|
||||
{ ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx)
|
||||
{
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue()
|
||||
{
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx)
|
||||
{
|
||||
fragment_D_.clear();
|
||||
fragment_C_.clear();
|
||||
|
||||
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
|
||||
{
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx)
|
||||
{
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_)
|
||||
{
|
||||
int thread_offset_row
|
||||
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
|
||||
{
|
||||
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx)
|
||||
{
|
||||
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(
|
||||
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(
|
||||
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,282 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
||||
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
|
||||
ThreadMap>
|
||||
{
|
||||
using WarpTileIterator
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
using SharedLoadIterator
|
||||
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
static int const kFragmentsPerIteration = 2;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator used to load output tile from shared memory in epilogue.
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
|
||||
{
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
using Shape = typename ThreadMap::Shape;
|
||||
|
||||
using Element = int32_t;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = MatrixCoord;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
|
||||
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<Element,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
|
||||
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
|
||||
|
||||
/// Vector type used for SMEM loads
|
||||
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
|
||||
const_min(16, kAlignment)>;
|
||||
|
||||
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Byte-level pointer
|
||||
LoadType const* pointers_[kLoadsPerAccess];
|
||||
|
||||
/// Stride along adjacent rows in units of LoadType
|
||||
int stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
|
||||
: stride_((ref.stride(0) / LoadType::kElements))
|
||||
{
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
// Initialize pointers
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
|
||||
|
||||
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
|
||||
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
|
||||
|
||||
col_idx += (bank_offset + i) % kLoadsPerAccess;
|
||||
|
||||
pointers_[i] += thread_offset.row() * stride_ + col_idx;
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] += pointer_offset / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i]
|
||||
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
|
||||
{
|
||||
|
||||
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
|
||||
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
|
||||
+ pointer_offset / LoadType::kElements;
|
||||
|
||||
int frag_row_idx
|
||||
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
|
||||
{
|
||||
|
||||
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kLoadsPerAccess; ++v)
|
||||
{
|
||||
|
||||
int vector_idx
|
||||
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
|
||||
|
||||
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
|
||||
|
||||
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment& frag) const
|
||||
{
|
||||
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
141
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
vendored
Normal file
141
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
vendored
Normal file
@@ -0,0 +1,141 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* @file epilogue_helpers.h
|
||||
*
|
||||
* This file includes types for the epilogues. The empty structs exist so we can signal to template
|
||||
* code the type of epilogue we want to run, and let the underlying code specify the details such as
|
||||
* element types, accumulator type and elements per vector access.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_silu.h"
|
||||
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
|
||||
#include <cutlass/epilogue/fusion/operations.hpp>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
|
||||
struct EpilogueOpBiasSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefault
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
|
||||
struct Epilogue
|
||||
{
|
||||
static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
|
||||
};
|
||||
|
||||
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -0,0 +1,221 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
|
||||
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
|
||||
{
|
||||
// 32 bytes to account for barriers etc.
|
||||
constexpr int stage_barrier_bytes = 32;
|
||||
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
|
||||
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
|
||||
constexpr int stage_bytes = [&]() -> int
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
|
||||
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
|
||||
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
|
||||
}
|
||||
}();
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
|
||||
Activation, SwapAB,
|
||||
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &¬ detail::
|
||||
is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm
|
||||
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|
||||
|| IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
|
||||
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages
|
||||
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
cute::conditional_t<IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
|
||||
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
|
||||
Activation, SwapAB,
|
||||
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|
||||
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
|
||||
{
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(
|
||||
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm
|
||||
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK
|
||||
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|
||||
|| IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages
|
||||
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
|
||||
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
|
||||
bool SwapAB = false, class Enable = void>
|
||||
struct CollectiveBuilderGated
|
||||
{
|
||||
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,59 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/detail/dependent_false.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
|
||||
class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
|
||||
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB = false>
|
||||
struct CollectiveMmaGated
|
||||
{
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,642 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
|
||||
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>, TileShape_,
|
||||
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
|
||||
{
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
|
||||
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA = cute::conditional_t<ConvertF32toTF32A, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB = cute::conditional_t<ConvertF32toTF32B, tfloat32_t, uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
using InternalElementAux = cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments
|
||||
{
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params
|
||||
{
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
|
||||
{
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes
|
||||
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
|
||||
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
|
||||
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
|
||||
/ 8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
|
||||
/// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
|
||||
{
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
|
||||
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate)
|
||||
{
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id
|
||||
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n)
|
||||
{
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m)
|
||||
{
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
|
||||
tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
|
||||
tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
|
||||
tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate)
|
||||
{
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
|
||||
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.partition_A(sAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
|
||||
{
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,665 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
|
||||
class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
|
||||
ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
|
||||
{
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert(
|
||||
(size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(
|
||||
(size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
|
||||
&& cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments
|
||||
{
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params
|
||||
{
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, Arguments const& args, void* workspace)
|
||||
{
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux,
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
|
||||
Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux,
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
|
||||
{
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable
|
||||
&& cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA
|
||||
* instructions. */
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
|
||||
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes
|
||||
= (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
|
||||
+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
|
||||
+ (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value))
|
||||
/ 8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
/// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
|
||||
{
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl
|
||||
= local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
|
||||
CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate)
|
||||
{
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id
|
||||
= {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n)
|
||||
{
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
|
||||
{
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m)
|
||||
{
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
|
||||
tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
|
||||
tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
|
||||
tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
|
||||
{
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate)
|
||||
{
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
|
||||
FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params)
|
||||
{
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.partition_A(sAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto
|
||||
{
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation0.prepare_if_needed())
|
||||
{
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
if (accumulation0.prepare_if_needed())
|
||||
{
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
|
||||
{
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB)
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
else
|
||||
{
|
||||
cute::gemm(
|
||||
tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation0.promote_residue_if_needed();
|
||||
accumulation1.promote_residue_if_needed();
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
|
||||
{
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count)
|
||||
{
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,438 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
|
||||
batched array variants.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// #include <limits>
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace device
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
|
||||
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
|
||||
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
|
||||
|
||||
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
|
||||
that feature at the moment.
|
||||
*/
|
||||
|
||||
template <typename GemmKernel_>
|
||||
class GemmUniversalBaseCompat
|
||||
{
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
||||
|
||||
using ElementA = typename GemmKernel::ElementA;
|
||||
using LayoutA = typename GemmKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
|
||||
|
||||
using ElementB = typename GemmKernel::ElementB;
|
||||
using LayoutB = typename GemmKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
|
||||
|
||||
using ElementC = typename GemmKernel::ElementC;
|
||||
using LayoutC = typename GemmKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
||||
using Operator = typename GemmKernel::Operator;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename GemmKernel::Arguments;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename GemmKernel::Params params_;
|
||||
|
||||
protected:
|
||||
/// Private helper to obtain the grid dimensions with fix-up for split-K
|
||||
static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
|
||||
{
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
gemm_k_size = args.problem_size.k();
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
int const kAlignK
|
||||
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size)
|
||||
{
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalBaseCompat() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
|
||||
|
||||
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
|
||||
{
|
||||
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
return GemmKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
|
||||
|
||||
size_t workspace_bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
// Split-K parallel always requires a temporary workspace
|
||||
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
|
||||
}
|
||||
else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Serial split-K only requires a temporary workspace if the number of partitions along the
|
||||
// GEMM K dimension is greater than one.
|
||||
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
|
||||
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
|
||||
<< " result = {" << result << "}");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10))
|
||||
{
|
||||
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result == cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0)
|
||||
{
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
||||
|
||||
if (workspace_bytes)
|
||||
{
|
||||
|
||||
if (!workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
|
||||
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" clearing device workspace");
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
|
||||
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUDA grid shape
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int gemm_k_size = 0;
|
||||
|
||||
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t result
|
||||
= cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
params_.update(args, workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
|
||||
|
||||
//
|
||||
// Configure grid and block dimensions
|
||||
//
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(GemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
|
||||
|
||||
// Launch
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
//
|
||||
// Query for errors
|
||||
//
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr)
|
||||
{
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess)
|
||||
{
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,542 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal.h"
|
||||
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace device
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
|
||||
int64_t* splitk_buffer_offsets)
|
||||
{
|
||||
// in_tensor: [problem_idx, k_partition, hidden_size]
|
||||
// Note that different requests of in_tensor might have different hidden_size (=m*n)
|
||||
// so, we need to use splitk_buffer_offsets.
|
||||
// out_tensor: problem_idx * [hidden_size]
|
||||
|
||||
int const problem_idx = blockIdx.y;
|
||||
GemmCoord problem = problem_sizes[problem_idx];
|
||||
int const hidden_size = problem.m() * problem.n();
|
||||
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
|
||||
T_OUT* out_tensor_ = out_tensor[problem_idx];
|
||||
|
||||
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
float sum = 0.0f;
|
||||
for (int k_idx = 0; k_idx < splitk; k_idx++)
|
||||
{
|
||||
sum += (float) in_tensor_[k_idx * hidden_size + i];
|
||||
}
|
||||
out_tensor_[i] = (T_OUT) (sum);
|
||||
}
|
||||
}
|
||||
|
||||
/// GEMM Grouped
|
||||
template <typename BaseKernel_>
|
||||
class BaseSplitkGrouped
|
||||
{
|
||||
public:
|
||||
using BaseKernel = BaseKernel_;
|
||||
|
||||
using ElementA = typename BaseKernel::ElementA;
|
||||
using LayoutA = typename BaseKernel::LayoutA;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
|
||||
static int const kAlignmentA = BaseKernel::kAlignmentA;
|
||||
|
||||
using ElementB = typename BaseKernel::ElementB;
|
||||
using LayoutB = typename BaseKernel::LayoutB;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
|
||||
static int const kAlignmentB = BaseKernel::kAlignmentB;
|
||||
|
||||
using ElementC = typename BaseKernel::ElementC;
|
||||
using LayoutC = typename BaseKernel::LayoutC;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
static int const kAlignmentC = BaseKernel::kAlignmentC;
|
||||
|
||||
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
|
||||
|
||||
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
|
||||
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
|
||||
|
||||
using Operator = typename BaseKernel::Operator;
|
||||
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
|
||||
|
||||
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
||||
using MathOperator = typename WarpMmaOperator::MathOperator;
|
||||
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
||||
using ArchTag = typename WarpMmaOperator::ArchTag;
|
||||
using ThreadblockShape = typename BaseKernel::Mma::Shape;
|
||||
using WarpShape = typename BaseKernel::WarpShape;
|
||||
using InstructionShape = typename BaseKernel::InstructionShape;
|
||||
static int const kStages = BaseKernel::Mma::kStages;
|
||||
|
||||
/// Argument structure
|
||||
using Arguments = typename BaseKernel::Arguments;
|
||||
|
||||
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
|
||||
|
||||
protected:
|
||||
/// Kernel parameters object
|
||||
typename BaseKernel::Params gemm_params_;
|
||||
|
||||
private:
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
|
||||
{
|
||||
int32_t tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
|
||||
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
|
||||
tiles += problem_tile_count(problem);
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
/// Copy from `data` to `workspace`
|
||||
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
|
||||
{
|
||||
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
|
||||
if (cuda_error != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
cuda_error = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Precomputes scheduling information for the grouped GEMM
|
||||
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
|
||||
{
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
std::vector<uint8_t> host_workspace(workspace_bytes);
|
||||
BaseKernel::ProblemVisitor::host_precompute(
|
||||
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
|
||||
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
|
||||
}
|
||||
|
||||
/// Reorder `data` according to `indices`
|
||||
template <typename T>
|
||||
static void reorder_array(T* data, std::vector<size_t> const& indices)
|
||||
{
|
||||
// For now, simply create a copy of the data and then copy over to the original.
|
||||
std::vector<T> copy(indices.size());
|
||||
for (size_t i = 0; i < indices.size(); ++i)
|
||||
{
|
||||
copy.at(i) = data[indices[i]];
|
||||
}
|
||||
|
||||
memcpy(data, copy.data(), indices.size() * sizeof(T));
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructs the GEMM.
|
||||
BaseSplitkGrouped() {}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
|
||||
return BaseKernel::can_implement(args);
|
||||
}
|
||||
|
||||
/// Get the number of tiles in a problem
|
||||
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
|
||||
{
|
||||
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
|
||||
return BaseKernel::ProblemVisitor::tile_count(grid);
|
||||
}
|
||||
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(Arguments const& args)
|
||||
{
|
||||
if (args.host_problem_sizes == nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return group_tile_count(args.host_problem_sizes, args.problem_count);
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t total_mn = 0;
|
||||
for (int i = 0; i < args.problem_count; i++)
|
||||
{
|
||||
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
|
||||
}
|
||||
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
|
||||
args.host_problem_sizes, args.problem_count, args.threadblock_count);
|
||||
}
|
||||
return workSpaceSize;
|
||||
}
|
||||
|
||||
/// Computes the grid shape
|
||||
static dim3 get_grid_shape(Arguments const& args)
|
||||
{
|
||||
|
||||
return dim3(args.threadblock_count, 1, 1);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
|
||||
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Sorts each pointer passed in according to the indices that sort
|
||||
/// `problem_sizes_ptr` in descending order of problem-K dimension.
|
||||
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
|
||||
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
|
||||
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
|
||||
{
|
||||
std::vector<size_t> indices(problem_count);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::stable_sort(indices.begin(), indices.end(),
|
||||
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
|
||||
|
||||
reorder_array(problem_sizes_ptr, indices);
|
||||
reorder_array(lda_host_ptr, indices);
|
||||
reorder_array(ldb_host_ptr, indices);
|
||||
reorder_array(ldc_host_ptr, indices);
|
||||
reorder_array(ldd_host_ptr, indices);
|
||||
reorder_array(offset_A_ptr, indices);
|
||||
reorder_array(offset_B_ptr, indices);
|
||||
reorder_array(offset_C_ptr, indices);
|
||||
reorder_array(offset_D_ptr, indices);
|
||||
}
|
||||
|
||||
/// Computes the number of threadblocks to launch for the grouped kernel
|
||||
static int sufficient(
|
||||
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
|
||||
{
|
||||
// Determine the number of blocks that would be launched to fill up a single
|
||||
// wave on the GPU with each SM having maximum occupancy.
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int multiprocessor_count;
|
||||
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
|
||||
if (override_sm_count)
|
||||
{
|
||||
available_sm_count = multiprocessor_count;
|
||||
}
|
||||
|
||||
int max_active_blocks = maximum_active_blocks();
|
||||
if (max_active_blocks <= 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int occupancy_based_block_count = available_sm_count * max_active_blocks;
|
||||
|
||||
if (problem_sizes_ptr == nullptr || problem_count == 0)
|
||||
{
|
||||
return occupancy_based_block_count;
|
||||
}
|
||||
|
||||
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
|
||||
|
||||
// If the group contains a single problem, launching the exact number of
|
||||
// threadblocks needed to cover the problem minimizes the work performed
|
||||
// per threadblock in finding the next tile to compute. We return total_tiles
|
||||
// unless the user has provided the SM count.
|
||||
if (problem_count == 1 && override_sm_count)
|
||||
{
|
||||
return total_tiles;
|
||||
}
|
||||
|
||||
// Choose between the full wave of threadblocks and the tile count. If there
|
||||
// are fewer tiles in the group than threadblocks in the full wave, only
|
||||
// some threadblocks will be assigned tiles. Those threadblocks
|
||||
// which are not assigned tiles still need to perform the work of iterating through
|
||||
// problem sizes to determine that they have no work to do. This competes for cycles
|
||||
// with those threadblocks that are assigned tiles to compute.
|
||||
return std::min(total_tiles, occupancy_based_block_count);
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Workspace
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_params_ = typename BaseKernel::Params(args, workspace);
|
||||
}
|
||||
|
||||
// Specify shared memory capacity for kernel.
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
|
||||
if (smem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t result
|
||||
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const& args, void* workspace = nullptr)
|
||||
{
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
|
||||
if (workspace_bytes && !workspace)
|
||||
{
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
|
||||
{
|
||||
int32_t tile_count = group_tile_count(args);
|
||||
Status status = precompute(args, tile_count, workspace);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
gemm_params_.update(args, workspace, tile_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_params_.update(args, workspace);
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr)
|
||||
{
|
||||
if (!gemm_params_.problem_visitor.problem_count)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
//
|
||||
// Launch kernel
|
||||
//
|
||||
|
||||
// Launch splitk grouped gemm
|
||||
{
|
||||
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
|
||||
dim3 block(BaseKernel::kThreadCount, 1, 1);
|
||||
|
||||
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
|
||||
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
// Launch splitkReduction
|
||||
{
|
||||
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
|
||||
dim3 block(256);
|
||||
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
|
||||
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
|
||||
gemm_params_.splitk_buffer_offsets);
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr)
|
||||
{
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Initializes and runs the kernel.
|
||||
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
|
||||
{
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
if (status == Status::kSuccess)
|
||||
{
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM Grouped
|
||||
template <typename GemmKernel_>
|
||||
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
|
||||
{
|
||||
public:
|
||||
using GemmKernel = GemmKernel_;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,162 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
|
||||
struct MixedGemmArchTraits
|
||||
{
|
||||
static_assert(dependent_false<arch>, "Unrecognised parameterization");
|
||||
};
|
||||
|
||||
template <typename Arch>
|
||||
struct MixedGemmArchTraits<float, float, Arch>
|
||||
{
|
||||
static constexpr int Stages = 2;
|
||||
using OperatorClass = cutlass::arch::OpClassSimt;
|
||||
using AccType = float;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 1;
|
||||
static constexpr int ElementsPerAccessB = 1;
|
||||
static constexpr int ElementsPerAccessC = 1;
|
||||
static constexpr int ThreadblockK = 8;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
// ======================= Turing Traits ==============================
|
||||
// Note that turing does not have native bfloat support so weights and activations will be casted to fp16
|
||||
// and compute will happen in fp16 then will be converted for bf16 output.
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// ======================= Ada Traits ==============================
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
// FP8 A/B = fp8, C/D = fp32
|
||||
template <typename TypeA, typename TypeB>
|
||||
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
|
||||
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
|
||||
|| cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
|
||||
{
|
||||
private:
|
||||
using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;
|
||||
|
||||
public:
|
||||
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
||||
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using AccType = float;
|
||||
// be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
|
||||
using TypeC = __nv_bfloat16;
|
||||
using LayoutB = typename LayoutDetails::Layout;
|
||||
|
||||
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
||||
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;
|
||||
|
||||
using Operator = typename LayoutDetails::Operator;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename arch>
|
||||
struct Int8GemmArchTraits
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassSimt;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
};
|
||||
|
||||
// ======================= Turing Traits ==============================
|
||||
template <>
|
||||
struct Int8GemmArchTraits<cutlass::arch::Sm75>
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
};
|
||||
|
||||
// ======================= Ampere Traits ==============================
|
||||
template <>
|
||||
struct Int8GemmArchTraits<cutlass::arch::Sm80>
|
||||
{
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,207 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
#include "splitk_gemm_grouped.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator>::Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout = layout::NoPermute,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DefaultSplitkGemmGrouped;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Real-valued GEMM kernels
|
||||
//
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout>
|
||||
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
|
||||
ComplexTransform::kNone, // transform A
|
||||
kAlignmentA, ElementB, LayoutB,
|
||||
ComplexTransform::kNone, // transform B
|
||||
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
|
||||
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
|
||||
{
|
||||
|
||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
|
||||
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
|
||||
|
||||
// Define the default GEMM kernel
|
||||
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
|
||||
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
|
||||
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
|
||||
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
|
||||
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
|
||||
false, /*GatherB*/
|
||||
false, /*ScatterD*/
|
||||
PermuteDLayout>::GemmKernel;
|
||||
|
||||
/// Define the kernel in terms of the default kernel
|
||||
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
|
||||
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,566 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
template <typename>
|
||||
inline constexpr bool dependent_false_v = false;
|
||||
}
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
>
|
||||
struct GemmFpAIntB
|
||||
{
|
||||
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Element;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Mma::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformA;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
|
||||
/// Parameters structure
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode = GemmUniversalMode::kGemm;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int group_size;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Mma::IteratorScale::TensorRef ref_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_zero;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
|
||||
// Control serial split-k
|
||||
int batch_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
// For gather+scatter operations
|
||||
int const* gather_A_indices;
|
||||
int const* gather_B_indices;
|
||||
int const* scatter_D_indices;
|
||||
|
||||
// Included so we can use Gemm Universal
|
||||
int batch_stride_D = 0;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
|
||||
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
|
||||
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
|
||||
int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
|
||||
int const* scatter_D_indices = nullptr)
|
||||
: problem_size(problem_size)
|
||||
, group_size(group_size)
|
||||
, ref_A(ref_A)
|
||||
, ref_B(ref_B)
|
||||
, ref_scale(ref_scale)
|
||||
, ref_zero(ref_zero)
|
||||
, ref_C(ref_C)
|
||||
, ref_D(ref_D)
|
||||
, batch_count(serial_split_k_factor)
|
||||
, output_op(output_op)
|
||||
, gather_A_indices(gather_A_indices)
|
||||
, gather_B_indices(gather_B_indices)
|
||||
, scatter_D_indices(scatter_D_indices)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int group_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Mma::IteratorScale::Params params_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_zero;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int* semaphore;
|
||||
int gemm_k_size;
|
||||
// For gather+scatter operations
|
||||
int const* gather_A_indices;
|
||||
int const* gather_B_indices;
|
||||
int const* scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0)
|
||||
, semaphore(0)
|
||||
, gemm_k_size(0)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
|
||||
void* workspace = nullptr)
|
||||
: problem_size(args.problem_size)
|
||||
, group_size(args.group_size)
|
||||
, grid_tiled_shape(grid_tiled_shape)
|
||||
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
|
||||
, params_A(args.ref_A.layout())
|
||||
, ref_A(args.ref_A)
|
||||
, params_B(args.ref_B.layout())
|
||||
, ref_B(args.ref_B)
|
||||
, params_scale(args.ref_scale.layout())
|
||||
, ref_scale(args.ref_scale)
|
||||
, ref_zero(args.ref_zero)
|
||||
, params_C(args.ref_C.layout())
|
||||
, ref_C(args.ref_C)
|
||||
, params_D(args.ref_D.layout())
|
||||
, ref_D(args.ref_D)
|
||||
, output_op(args.output_op)
|
||||
, semaphore(static_cast<int*>(workspace))
|
||||
, gemm_k_size(gemm_k_size)
|
||||
, gather_A_indices(args.gather_A_indices)
|
||||
, gather_B_indices(args.gather_B_indices)
|
||||
, scatter_D_indices(args.scatter_D_indices)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmFpAIntB() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
static int const kAlignmentA
|
||||
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB
|
||||
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!args.ref_scale.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
if constexpr (hasZero(Mma::QuantOp))
|
||||
{
|
||||
if (!args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (isFinegrained(Mma::QuantOp))
|
||||
{
|
||||
if (args.group_size != 64 && args.group_size != 128)
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
|
||||
// has a different constructor signature than a regular cutlass iterator
|
||||
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
|
||||
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
|
||||
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
|
||||
typename IteratorScale::TensorCoord extent, int thread_id,
|
||||
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
|
||||
{
|
||||
|
||||
return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
|
||||
}
|
||||
|
||||
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
|
||||
CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
|
||||
typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
|
||||
typename IteratorScale::TensorCoord extent, int thread_id,
|
||||
typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
|
||||
{
|
||||
|
||||
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
||||
|
||||
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
|
||||
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
|
||||
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
|
||||
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
|
||||
params.gather_B_indices);
|
||||
|
||||
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
|
||||
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
|
||||
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
|
||||
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0)
|
||||
{
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k())
|
||||
{
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
|
||||
{
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ == 890)
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,218 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cutlass/gemm/kernel/gemm_grouped_problem_visitor.h>
|
||||
#include <cutlass/trace.h>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
|
||||
int TileK_, int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_sm80
|
||||
{
|
||||
static constexpr int kMaxTileM = MaxTileM_;
|
||||
static constexpr int kTileN = isGateActivation(activation_type_) ? TileN_ / 2 : TileN_;
|
||||
static constexpr int kTileK = TileK_;
|
||||
static constexpr int kStages = Stages_;
|
||||
static constexpr Activation_Type activation_type = activation_type_;
|
||||
|
||||
using ElementInput = ElementInput_;
|
||||
using ElementWeight = ElementWeight_;
|
||||
using ElementOutput = ElementOutput_;
|
||||
using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, kMaxTileM, kTileN,
|
||||
kTileK, kStages, activation_type>;
|
||||
using Routine_Arguments = Routine_Arguments<ElementInput, ElementWeight, ElementOutput>;
|
||||
using Routine_Params = Routine_Params<ElementInput, ElementWeight, ElementOutput>;
|
||||
using ProblemVisitor
|
||||
= cutlass::gemm::kernel::MoeProblemVisitor<cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<
|
||||
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, false>,
|
||||
cutlass::gemm::GemmShape<kMaxTileM, kTileN, kTileK>, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
|
||||
BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>;
|
||||
|
||||
struct Arguments
|
||||
{
|
||||
Routine_Arguments routine_args;
|
||||
int problem_count{};
|
||||
int threadblock_count{};
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
Routine_Params routine_params;
|
||||
int threadblock_count{};
|
||||
typename ProblemVisitor::Params problem_visitor_param;
|
||||
};
|
||||
|
||||
using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80<ElementInput, ElementWeight, ElementOutput, 16, kTileN,
|
||||
kTileK, kStages, activation_type>;
|
||||
static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64
|
||||
|
||||
static constexpr int kSmemSize = use_m16
|
||||
? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize
|
||||
: BaseKernelTraits_m16::kSmemSize)
|
||||
: BaseKernelTraits::kSmemSize;
|
||||
static constexpr int kThreadCount = BaseKernelTraits::kThreadCount;
|
||||
|
||||
static constexpr bool can_implement(int const avaliable_smem_size)
|
||||
{
|
||||
return BaseKernelTraits::can_implement(avaliable_smem_size);
|
||||
}
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args)
|
||||
{
|
||||
return {
|
||||
{args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias,
|
||||
args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n,
|
||||
args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast},
|
||||
args.threadblock_count,
|
||||
{args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k,
|
||||
args.problem_count, nullptr, 0}};
|
||||
}
|
||||
|
||||
CUTE_DEVICE
|
||||
void run_device(Params const& params)
|
||||
{
|
||||
#define ROUTINE_PATH(kTileM_size) \
|
||||
{ \
|
||||
constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 32 : (kTileM_size)); \
|
||||
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput, kTileM, \
|
||||
kTileN, kTileK, kStages, activation_type>; \
|
||||
RoutineTraits routine{}; \
|
||||
int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \
|
||||
routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \
|
||||
}
|
||||
typename ProblemVisitor::SharedStorage dummy_storage{};
|
||||
ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x);
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
auto problem_size = problem_visitor.problem_size();
|
||||
auto grid_size = problem_visitor.grid_shape(problem_size);
|
||||
auto problem_index = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
int const gemm_m = problem_size.m();
|
||||
const int32_t block_m_idx_temp = cta_idx / grid_size.n();
|
||||
const int32_t block_n_idx = cta_idx % grid_size.n();
|
||||
|
||||
int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp;
|
||||
if (residue_m > kMaxTileM / 2)
|
||||
{
|
||||
using RoutineTraits = Fused_Moe_Kernel_routine_sm80<ElementInput, ElementWeight, ElementOutput,
|
||||
kMaxTileM, kTileN, kTileK, kStages, activation_type>;
|
||||
RoutineTraits routine{};
|
||||
routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr (kMaxTileM >= 128)
|
||||
{
|
||||
if (residue_m > 32)
|
||||
{
|
||||
ROUTINE_PATH(64);
|
||||
}
|
||||
else if (residue_m > 16)
|
||||
{
|
||||
ROUTINE_PATH(32);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
else if (kMaxTileM == 64)
|
||||
{
|
||||
if (residue_m > 16)
|
||||
{
|
||||
ROUTINE_PATH(32);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
else if (kMaxTileM == 32)
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: use cuda core gemm here
|
||||
ROUTINE_PATH(16);
|
||||
}
|
||||
}
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
#undef ROUTINE_PATH
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmType>
|
||||
__global__ void run_global(__grid_constant__ typename GemmType::Params const params)
|
||||
{
|
||||
GemmType gemm;
|
||||
gemm.run_device(params);
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
template <typename GemmType>
|
||||
static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
|
||||
|
||||
constexpr int smem_size = GemmType::kSmemSize;
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
result = cudaFuncSetAttribute(run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, run_global<GemmType>, GemmType::kThreadCount, smem_size);
|
||||
|
||||
if (result != cudaSuccess)
|
||||
{
|
||||
// Call cudaGetLastError() to clear the error bit
|
||||
result = cudaGetLastError();
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
} // namespace fused_moe
|
||||
@@ -0,0 +1,799 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_, typename Enable = void>
|
||||
struct Fused_Moe_Kernel_routine_sm80;
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
|
||||
activation_type_, std::enable_if_t<isGateActivation(activation_type_)>>
|
||||
{
|
||||
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
|
||||
Stages_, activation_type_>;
|
||||
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
|
||||
|
||||
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
|
||||
{
|
||||
using X = cute::Underscore;
|
||||
|
||||
int const M = gemm_m;
|
||||
int const N1 = params.gemm_n;
|
||||
int const K1 = params.gemm_k;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
|
||||
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
|
||||
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
|
||||
typename KT::ElementWeight const* ptr_fc1_gate_
|
||||
= params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation..
|
||||
typename KT::ElementWeight const* ptr_fc1_
|
||||
= params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation..
|
||||
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1);
|
||||
typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1
|
||||
: params.ptr_bias + (2 * row_jump + 1) * N1);
|
||||
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
|
||||
|
||||
cute::Tensor mInput_mk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
|
||||
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_gate_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_gate_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mBias_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mBias_gate_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_gate_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mOutput_mn
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
|
||||
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
|
||||
|
||||
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn);
|
||||
}
|
||||
|
||||
// be careful, m_idx will change when use another tile shape..
|
||||
CUTE_DEVICE void run_routine(
|
||||
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
|
||||
{
|
||||
extern __shared__ char smem_[];
|
||||
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
|
||||
int const thread_idx = threadIdx.x;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
// gmem tensor partition ..
|
||||
auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn]
|
||||
= gmem_tensor_init(problem_index, gemm_m, params);
|
||||
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
|
||||
auto const n_tile_count = cute::size<2>(gfc1_gate_nk);
|
||||
|
||||
// smem tensor ..
|
||||
cute::Tensor sInput = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sfc1_gate_weight
|
||||
= cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sO = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
|
||||
|
||||
// (1) first step, get the fc1_res and fc1_gate
|
||||
|
||||
// (1.1) get partition for gmem -> smem
|
||||
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
|
||||
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
|
||||
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
|
||||
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
|
||||
cute::Tensor tInputpInput
|
||||
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
|
||||
cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sInput
|
||||
cute::Tensor cInput = make_identity_tensor(
|
||||
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
|
||||
{
|
||||
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
|
||||
// (1.2) prefetch gmem -> smem
|
||||
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
|
||||
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0
|
||||
int k_tile_count = cute::size<2>(gInput);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
|
||||
{
|
||||
if (k_tile_count <= 0)
|
||||
{
|
||||
cute::clear(tInputpInput);
|
||||
}
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
// use copy_if
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::cp_async_fence();
|
||||
k_tile_count--;
|
||||
if (k_tile_count > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// (1.3) get partition for rf
|
||||
typename KT::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
|
||||
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
cute::Tensor accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::Tensor accum_gate
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::clear(accum);
|
||||
cute::clear(accum_gate);
|
||||
// checkout the shape
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
|
||||
|
||||
// (1.4)retiling the smem and rf for copy..
|
||||
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
|
||||
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
|
||||
cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K
|
||||
|
||||
// (1.5) mainloop
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = KT::Stages - 1;
|
||||
|
||||
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
|
||||
// prefetch register pipeline
|
||||
if constexpr (K_BLOCK_MAX > 1)
|
||||
{
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
}
|
||||
// k loop for mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
|
||||
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::cp_async_fence();
|
||||
if (k_tile_count - 1 > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
|
||||
accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
}
|
||||
|
||||
// load tail
|
||||
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
|
||||
[&](auto WaitIndex)
|
||||
{
|
||||
k_tile_count--;
|
||||
using WaitIndex_t = decltype(WaitIndex);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
// only update smem_pipe_read
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
});
|
||||
// mma tail
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1g_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(
|
||||
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1g(cute::_, cute::_, k_block), accum_gate);
|
||||
});
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(accum_gate(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (2) add bias if it has..
|
||||
if (params.ptr_bias != nullptr)
|
||||
{
|
||||
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
|
||||
cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate);
|
||||
for (int i = 0; i < cute::size(accum); i++)
|
||||
{
|
||||
accum(i) += tOgBias(i);
|
||||
accum_gate(i) += tOgBiasg(i);
|
||||
}
|
||||
}
|
||||
|
||||
// (3) calculate swiglu
|
||||
using ActivationFn = typename KT::ActivationFn;
|
||||
ActivationFn fn{};
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
|
||||
{
|
||||
accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter);
|
||||
}
|
||||
|
||||
// (4) push all the result to smem
|
||||
// (4.1) convert result from ElementAccum to ElementInput
|
||||
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(temp_accum(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (4.2) retile rf and smem for copy back..
|
||||
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// cute::clear(sO);
|
||||
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
|
||||
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
|
||||
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.4) sO -> rO -> gO
|
||||
|
||||
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
|
||||
// remember, for all the threads in the same col, they have the same idx for bias..
|
||||
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
|
||||
// cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row..
|
||||
auto tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
auto tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
// auto tOgBias = gmem_thr_copy_O.partition_D(gBias);
|
||||
cute::Tensor cOutput = cute::make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
|
||||
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
|
||||
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(tOgO); ++m)
|
||||
{
|
||||
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
|
||||
{
|
||||
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type_>
|
||||
struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_, Stages_,
|
||||
activation_type_, std::enable_if_t<!isGateActivation(activation_type_)>>
|
||||
{
|
||||
|
||||
using KT = Fused_Moe_Kernel_traits_sm80<ElementInput_, ElementWeight_, ElementOutput_, TileM_, TileN_, TileK_,
|
||||
Stages_, activation_type_>;
|
||||
using Params = Routine_Params<ElementInput_, ElementWeight_, ElementOutput_>;
|
||||
|
||||
CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params)
|
||||
{
|
||||
using X = cute::Underscore;
|
||||
|
||||
int const M = gemm_m;
|
||||
int const N1 = params.gemm_n;
|
||||
int const K1 = params.gemm_k;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
|
||||
int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]);
|
||||
typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1;
|
||||
typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1;
|
||||
typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr)
|
||||
? nullptr
|
||||
: (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1);
|
||||
typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1;
|
||||
|
||||
cute::Tensor mInput_mk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_input_)),
|
||||
cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mfc1_nk
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementWeight const*>(ptr_fc1_)),
|
||||
cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{}));
|
||||
|
||||
cute::Tensor mBias_mn = cute::make_tensor(
|
||||
cute::make_gmem_ptr(static_cast<typename KT::ElementInput const*>(ptr_bias_)), cute::make_shape(M, N1),
|
||||
cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1,
|
||||
cute::_1{})); // trick: bias shape is [1, N], but we use [M, N].
|
||||
|
||||
cute::Tensor mOutput_mn
|
||||
= cute::make_tensor(cute::make_gmem_ptr(static_cast<typename KT::ElementInput*>(ptr_output_)),
|
||||
cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{}));
|
||||
|
||||
cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, X, cute::_1>{}); // (BLK_M, BLK_K, m, k)
|
||||
cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<X, cute::_1, cute::_1>{}); // (BLK_N, BLK_K, n, k)
|
||||
|
||||
cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{},
|
||||
cute::make_coord(cute::_, cute::_, cute::_), cute::Step<cute::_1, cute::_1, X>{}); // (BLK_M, BLK_N, m, n)
|
||||
|
||||
return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn);
|
||||
}
|
||||
|
||||
// be careful, m_idx will change when use another tile shape..
|
||||
CUTE_DEVICE void run_routine(
|
||||
Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m)
|
||||
{
|
||||
extern __shared__ char smem_[];
|
||||
typename KT::SharedStorage& shared_storage = *reinterpret_cast<typename KT::SharedStorage*>(smem_);
|
||||
int const thread_idx = threadIdx.x;
|
||||
bool const bias_is_broadcast = params.bias_is_broadcast;
|
||||
// gmem tensor partition ..
|
||||
auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params);
|
||||
int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk);
|
||||
auto const n_tile_count = cute::size<2>(gfc1_nk);
|
||||
|
||||
// smem tensor ..
|
||||
cute::Tensor sInput = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage)
|
||||
cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()),
|
||||
typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage)
|
||||
cute::Tensor sO = cute::make_tensor(
|
||||
cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N)
|
||||
|
||||
// (1) first step, get the fc1_res and fc1_gate
|
||||
|
||||
// (1.1) get partition for gmem -> smem
|
||||
cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k)
|
||||
cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k)
|
||||
|
||||
typename KT::GmemTiledCopyA gmem_tiled_copy_A;
|
||||
typename KT::GmemTiledCopyB gmem_tiled_copy_B;
|
||||
auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx);
|
||||
auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx);
|
||||
|
||||
cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k)
|
||||
cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage)
|
||||
cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k)
|
||||
cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage)
|
||||
|
||||
// Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor)
|
||||
cute::Tensor tInputpInput
|
||||
= cute::make_tensor<bool>(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)),
|
||||
cute::Stride<cute::_1, cute::_0>{});
|
||||
// Construct identity layout for sInput
|
||||
cute::Tensor cInput = make_identity_tensor(
|
||||
make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
|
||||
// Set predicates for m bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<0>(tInputpInput); ++m)
|
||||
{
|
||||
tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m
|
||||
}
|
||||
|
||||
// (1.2) prefetch gmem -> smem
|
||||
cute::clear(tInputsInput); // we don't need to clear tfc1sfc1..
|
||||
auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0
|
||||
int k_tile_count = cute::size<2>(gInput);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
|
||||
{
|
||||
if (k_tile_count <= 0)
|
||||
{
|
||||
cute::clear(tInputpInput);
|
||||
}
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
// use copy_if
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
|
||||
cute::cp_async_fence();
|
||||
k_tile_count--;
|
||||
if (k_tile_count > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
}
|
||||
|
||||
// (1.3) get partition for rf
|
||||
typename KT::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K)
|
||||
cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
cute::Tensor accum
|
||||
= cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
|
||||
cute::clear(accum);
|
||||
// checkout the shape
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
|
||||
CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
|
||||
|
||||
// (1.4)retiling the smem and rf for copy..
|
||||
auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma);
|
||||
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage)
|
||||
cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K
|
||||
|
||||
auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma);
|
||||
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
|
||||
cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage)
|
||||
cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K)
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K
|
||||
|
||||
// (1.5) mainloop
|
||||
// Current pipe index in smem to read from
|
||||
int smem_pipe_read = 0;
|
||||
// Current pipe index in smem to write to
|
||||
int smem_pipe_write = KT::Stages - 1;
|
||||
|
||||
cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
|
||||
constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput);
|
||||
// prefetch register pipeline
|
||||
if constexpr (K_BLOCK_MAX > 1)
|
||||
{
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Prefetch the first rmem from the first k-tile
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}),
|
||||
tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{}));
|
||||
}
|
||||
// k loop for mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Copy gmem to smem before computing gemm on each k-pipe
|
||||
if (k_block == 0)
|
||||
{
|
||||
// cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
// tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy_if(gmem_tiled_copy_A, tInputpInput,
|
||||
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
|
||||
tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write));
|
||||
cute::cp_async_fence();
|
||||
if (k_tile_count - 1 > 0)
|
||||
{
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
// Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe)
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block),
|
||||
accum);
|
||||
});
|
||||
}
|
||||
// load tail
|
||||
cute::for_each(cute::make_int_sequence<KT::Stages - 2>{},
|
||||
[&](auto WaitIndex)
|
||||
{
|
||||
k_tile_count--;
|
||||
using WaitIndex_t = decltype(WaitIndex);
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
if (k_block == K_BLOCK_MAX - 1)
|
||||
{
|
||||
tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read);
|
||||
cute::cp_async_wait<KT::Stages - 3 - WaitIndex_t::value>();
|
||||
__syncthreads();
|
||||
}
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
// only update smem_pipe_read
|
||||
++smem_pipe_read;
|
||||
smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read;
|
||||
}
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block),
|
||||
tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
});
|
||||
});
|
||||
// mma tail
|
||||
cute::for_each(cute::make_int_sequence<K_BLOCK_MAX>{},
|
||||
[&](auto k_block)
|
||||
{
|
||||
// Load A, B shmem->regs for k_block+1
|
||||
auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX;
|
||||
cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next),
|
||||
tOrInput_copy_view(cute::_, cute::_, k_block_next));
|
||||
cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next),
|
||||
tOrfc1_copy_view(cute::_, cute::_, k_block_next));
|
||||
// Thread-level register gemm for k_block
|
||||
cute::gemm(
|
||||
tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum);
|
||||
});
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(accum_gate(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (2) add bias if it has..
|
||||
if (params.ptr_bias != nullptr)
|
||||
{
|
||||
cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx);
|
||||
cute::Tensor tOgBias = thr_mma.partition_C(gBias);
|
||||
for (int i = 0; i < cute::size(accum); i++)
|
||||
{
|
||||
accum(i) += tOgBias(i);
|
||||
}
|
||||
}
|
||||
// (3) calculate swiglu
|
||||
using ActivationFn = typename KT::ActivationFn;
|
||||
ActivationFn fn{};
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++)
|
||||
{
|
||||
accum(temp_iter) = fn(accum(temp_iter));
|
||||
}
|
||||
|
||||
// (4) push all the result to smem
|
||||
// (4.1) convert result from ElementAccum to ElementInput
|
||||
cute::Tensor temp_accum = util_convert_type<KT::ElementOutput>(accum);
|
||||
// if (cute::thread0()) {
|
||||
// cute::print(temp_accum(0, 0, 0));
|
||||
// printf("\n");
|
||||
// }
|
||||
// (4.2) retile rf and smem for copy back..
|
||||
auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// cute::clear(sO);
|
||||
cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum);
|
||||
cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
// (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..)
|
||||
cute::copy(smem_tiled_copy_O, taccumrO, taccumsO);
|
||||
__syncthreads();
|
||||
|
||||
// (4.4) sO -> rO -> gO
|
||||
|
||||
typename KT::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
// auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); //
|
||||
cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx);
|
||||
auto tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
auto tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
cute::Tensor cOutput = cute::make_identity_tensor(
|
||||
cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{})));
|
||||
cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput);
|
||||
cute::Tensor tOrO = cute::make_tensor<KT::ElementOutput>(cute::shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(tOgO); ++m)
|
||||
{
|
||||
if (cute::get<0>(tOcO(0, m, 0)) < residue_m)
|
||||
{
|
||||
cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace fused_moe
|
||||
@@ -0,0 +1,215 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/epilogue/thread/activation.h>
|
||||
#include <cutlass_extensions/epilogue_helpers.h>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
|
||||
#include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>
|
||||
|
||||
namespace fused_moe
|
||||
{
|
||||
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
|
||||
struct Routine_Arguments
|
||||
{
|
||||
ElementInput* ptr_input{};
|
||||
ElementWeight* ptr_fc1{};
|
||||
ElementInput* ptr_bias{};
|
||||
ElementOutput* ptr_output{};
|
||||
int64_t const* total_tokens_including_expert{};
|
||||
int gemm_n{};
|
||||
int gemm_k{};
|
||||
int num_expert{};
|
||||
bool bias_is_broadcast{};
|
||||
};
|
||||
|
||||
template <typename ElementInput, typename ElementWeight, typename ElementOutput>
|
||||
struct Routine_Params
|
||||
{
|
||||
ElementInput* ptr_input{};
|
||||
ElementWeight* ptr_fc1{};
|
||||
ElementInput* ptr_bias{};
|
||||
ElementOutput* ptr_output{};
|
||||
int64_t const* total_tokens_including_expert{};
|
||||
int gemm_n{};
|
||||
int gemm_k{};
|
||||
int num_expert{};
|
||||
bool bias_is_broadcast{};
|
||||
};
|
||||
|
||||
enum class Activation_Type
|
||||
{
|
||||
Gelu = 0,
|
||||
Relu,
|
||||
Silu,
|
||||
Swiglu,
|
||||
Geglu,
|
||||
Identity,
|
||||
InvalidType
|
||||
};
|
||||
|
||||
constexpr bool isGateActivation(Activation_Type const& activation_type)
|
||||
{
|
||||
return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu;
|
||||
}
|
||||
|
||||
template <typename CutlassExtensionEpilogueTag>
|
||||
constexpr Activation_Type EpilogueRouting(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::InvalidType;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefault>(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::Identity;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultReLU>(bool /*is_gate*/)
|
||||
{
|
||||
return Activation_Type::Relu;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu>(bool is_gate)
|
||||
{
|
||||
return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr Activation_Type EpilogueRouting<tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu>(bool is_gate)
|
||||
{
|
||||
return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
|
||||
}
|
||||
|
||||
/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/
|
||||
template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
|
||||
int Stages_, Activation_Type activation_type>
|
||||
struct Fused_Moe_Kernel_traits_sm80
|
||||
{
|
||||
using ElementInput = ElementInput_;
|
||||
using ElementWeight = ElementWeight_;
|
||||
using ElementAccum = float;
|
||||
using ElementOutput = ElementOutput_;
|
||||
|
||||
using index_t = uint32_t;
|
||||
static_assert(TileM_ % 16 == 0);
|
||||
static_assert(TileN_ % 32 == 0);
|
||||
static_assert(TileK_ % 32 == 0);
|
||||
static constexpr int Stages = Stages_;
|
||||
static constexpr int kTileM = TileM_;
|
||||
static constexpr int kTileN = TileN_;
|
||||
static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64);
|
||||
|
||||
// tile shape
|
||||
using TileShape = cute::Shape<cute::Int<kTileM>, cute::Int<kTileN>, cute::Int<kTileK>>;
|
||||
static constexpr int kWarpsCount = 4;
|
||||
static constexpr int kThreadCount = kWarpsCount * 32;
|
||||
|
||||
// MMA atom arch and layout
|
||||
using MMA_Atom_Arch = std::conditional_t<std::is_same_v<ElementInput, cutlass::half_t>,
|
||||
cute::MMA_Atom<cute::SM80_16x8x16_F32F16F16F32_TN>, cute::MMA_Atom<cute::SM80_16x8x16_F32BF16BF16F32_TN>>;
|
||||
// using ValLayoutMNK = cute::Layout<cute::Shape<cute::_1, cute::_2, cute::_1>>;
|
||||
using ThreadLayoutMNK
|
||||
= std::conditional_t<kTileM == 16, cute::Layout<cute::Shape<cute::_1, cute::Int<kWarpsCount / 1>, cute::_1>>,
|
||||
cute::Layout<cute::Shape<cute::_2, cute::Int<kWarpsCount / 2>, cute::_1>>>;
|
||||
using ValLayoutMNK = std::conditional_t<kTileM == 16, cute::Tile<cute::_16, cute::_64, cute::_16>,
|
||||
cute::Tile<cute::_32, cute::_32, cute::_16>>;
|
||||
using TiledMma = cute::TiledMMA<MMA_Atom_Arch, ThreadLayoutMNK,
|
||||
ValLayoutMNK>; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4
|
||||
static constexpr int kAlignment = 8;
|
||||
static constexpr int kBlcokKSmem = (kTileM == 16) ? 64 : 32;
|
||||
// A memory copy operand
|
||||
using DefaultOperandA
|
||||
= DefaultGemm_TensorOpSm80_OperandA<ElementInput, cutlass::layout::RowMajor, kAlignment, kBlcokKSmem>;
|
||||
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
|
||||
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
||||
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
||||
|
||||
// B memory copy operand
|
||||
using DefaultOperandB
|
||||
= DefaultGemm_TensorOpSm80_OperandB<ElementWeight, cutlass::layout::ColumnMajor, kAlignment, kBlcokKSmem>;
|
||||
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
|
||||
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
||||
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
||||
|
||||
// Output memory copy operand
|
||||
using SmemLayoutAtomO = SmemLayoutAtomA;
|
||||
using SmemCopyAtomO = cute::Copy_Atom<cute::DefaultCopy, ElementOutput>;
|
||||
static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput);
|
||||
static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad;
|
||||
using GmemLayoutAtomO
|
||||
= cute::Layout<cute::Shape<cute::Int<kThreadCount / kGmemTrheadsPerRow>, cute::Int<kGmemTrheadsPerRow>>,
|
||||
cute::Stride<cute::Int<kGmemTrheadsPerRow>, cute::_1>>;
|
||||
using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, ElementOutput>{},
|
||||
GmemLayoutAtomO{}, cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2);
|
||||
static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2);
|
||||
static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N
|
||||
static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K
|
||||
|
||||
using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{},
|
||||
cute::make_shape(
|
||||
cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_M, BLK_K, Stages
|
||||
using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{},
|
||||
cute::make_shape(
|
||||
cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int<Stages>{}))); // BLK_N, BLK_K, Stages
|
||||
using SmemLayoutO = decltype(cute::tile_to_shape(
|
||||
SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N
|
||||
|
||||
// we need at least 2 stages..
|
||||
static_assert(Stages >= 2);
|
||||
|
||||
struct SharedStorageNormal : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
|
||||
struct SharedStorageGate : cute::aligned_struct<128>
|
||||
{
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutA>> smem_input;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_gate_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutB>> smem_fc1_weight;
|
||||
cute::array_aligned<ElementInput, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
|
||||
using SharedStorage = std::conditional_t<isGateActivation(activation_type), SharedStorageGate, SharedStorageNormal>;
|
||||
|
||||
using ActivationFn = std::conditional_t<activation_type == Activation_Type::Gelu
|
||||
|| activation_type == Activation_Type::Geglu,
|
||||
cutlass::epilogue::thread::GELU<float>,
|
||||
std::conditional_t<activation_type == Activation_Type::Relu, cutlass::epilogue::thread::ReLU<float>,
|
||||
std::conditional_t<activation_type == Activation_Type::Silu || activation_type == Activation_Type::Swiglu,
|
||||
cutlass::epilogue::thread::SiLu<float>, cutlass::epilogue::thread::Identity<float>>>>;
|
||||
|
||||
static constexpr int kSmemSize = static_cast<int>(sizeof(SharedStorage));
|
||||
|
||||
static constexpr bool can_implement(int const avaliable_smem_size)
|
||||
{
|
||||
return avaliable_smem_size > kSmemSize;
|
||||
}
|
||||
|
||||
// #endif
|
||||
};
|
||||
} // namespace fused_moe
|
||||
@@ -0,0 +1,73 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped GEMM
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
|
||||
bool Transposed = false>
|
||||
struct GemmMoeProblemVisitor
|
||||
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
|
||||
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
|
||||
{
|
||||
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
||||
using Base
|
||||
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
||||
using Params = typename Base::Params;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, shared_storage_, block_idx)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,70 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
* Stateless universal device GEMM kernel type that treats GEMM as
|
||||
* a composition of a collective mainloop and a collective epilogue.
|
||||
*
|
||||
* Supports both the 2.x and 3.x APIs based on whether the first type is
|
||||
* a cute::tuple<> or not.
|
||||
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
|
||||
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
|
||||
*
|
||||
* In the following declaration, the name preceding the 'Or' refers to
|
||||
* 3.x API type argument order, and the name succeeding the 'Or' refers to
|
||||
* 2.x API type argument order. Template arguments without two names
|
||||
* belong to the 3.x API only.
|
||||
**/
|
||||
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
|
||||
class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, class TileScheduler_ = void,
|
||||
class Enable = void>
|
||||
class GemmUniversalGated;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
|
||||
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,585 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief GEMM kernel to support the epilogue visitor model
|
||||
for customized softmax partial reduction epilogue fusion.
|
||||
|
||||
This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
|
||||
its usage has been stabilized. For now, it is included in this example to demonstrate
|
||||
some basic output fusion options.
|
||||
|
||||
original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct GemmWithEpilogueVisitor
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueVisitor = typename Epilogue::Visitor;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using TensorRefB = TensorRef<ElementB, LayoutB>;
|
||||
|
||||
using ElementCompute = typename EpilogueVisitor::ElementCompute;
|
||||
using LayoutAlphaCol = cutlass::layout::RowMajor;
|
||||
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
|
||||
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
|
||||
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
|
||||
|
||||
using ElementC = typename EpilogueVisitor::ElementOutput;
|
||||
using LayoutC = typename Epilogue::Layout;
|
||||
using TensorRefC = TensorRef<ElementC, LayoutC>;
|
||||
|
||||
static ComplexTransform const kTransformA = Mma::kTransformA;
|
||||
static ComplexTransform const kTransformB = Mma::kTransformB;
|
||||
using Operator = typename Mma::Operator;
|
||||
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
using EpilogueOutputOp =
|
||||
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
/// Split-K preserves splits that are 128b aligned
|
||||
static int const kSplitKAlignment
|
||||
= const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
|
||||
TensorRefA ref_A;
|
||||
TensorRefB ref_B;
|
||||
tk::QuantMode quant_option;
|
||||
TensorRefAlphaCol ref_alpha_col;
|
||||
TensorRefAlphaRow ref_alpha_row;
|
||||
TensorRefC ref_C;
|
||||
TensorRefC ref_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename EpilogueVisitor::Arguments epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
Arguments()
|
||||
: mode(GemmUniversalMode::kGemm)
|
||||
, batch_count(1)
|
||||
{
|
||||
}
|
||||
|
||||
/// constructs an arguments structure
|
||||
Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
|
||||
TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
|
||||
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
|
||||
int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
|
||||
: mode(mode_)
|
||||
, problem_size(problem_size_)
|
||||
, batch_count(batch_count_)
|
||||
, ref_A(ref_A_)
|
||||
, ref_B(ref_B_)
|
||||
, quant_option(quant_option_)
|
||||
, ref_alpha_col(ref_alpha_col_)
|
||||
, ref_alpha_row(ref_alpha_row_)
|
||||
, ref_C(ref_C_)
|
||||
, ref_D(ref_D_)
|
||||
, batch_stride_A(batch_stride_A_)
|
||||
, batch_stride_B(batch_stride_B_)
|
||||
, batch_stride_D(0)
|
||||
, epilogue_visitor(epilogue_visitor_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_C;
|
||||
typename EpilogueVisitor::OutputTileIterator::Params params_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
int batch_count;
|
||||
int gemm_k_size;
|
||||
|
||||
void* ptr_A;
|
||||
void* ptr_B;
|
||||
tk::QuantMode quant_option;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
|
||||
typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
|
||||
typename EpilogueVisitor::Params epilogue_visitor;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0)
|
||||
, params_A(0)
|
||||
, params_B(0)
|
||||
, params_alpha_col(0)
|
||||
, params_C(0)
|
||||
, params_D(0)
|
||||
, batch_count(0)
|
||||
, gemm_k_size(0)
|
||||
, mode(cutlass::gemm::GemmUniversalMode::kGemm)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_alpha_col(nullptr)
|
||||
, ptr_alpha_row(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, batch_stride_A(0)
|
||||
, batch_stride_B(0)
|
||||
{
|
||||
}
|
||||
|
||||
Params(
|
||||
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
|
||||
: problem_size(args.problem_size)
|
||||
, swizzle_log_tile(0)
|
||||
, params_A(args.ref_A.layout())
|
||||
, params_B(args.ref_B.layout())
|
||||
, params_alpha_col(args.ref_alpha_col.layout())
|
||||
, params_alpha_row(args.ref_alpha_col.layout())
|
||||
, params_C(args.ref_C.layout())
|
||||
, params_D(args.ref_D.layout())
|
||||
, mode(args.mode)
|
||||
, batch_count(args.batch_count)
|
||||
, gemm_k_size(args.problem_size.k())
|
||||
, ptr_A(args.ref_A.data())
|
||||
, ptr_B(args.ref_B.data())
|
||||
, quant_option(args.quant_option)
|
||||
, ptr_alpha_col(args.ref_alpha_col.data())
|
||||
, ptr_alpha_row(args.ref_alpha_row.data())
|
||||
, ptr_C(args.ref_C.data())
|
||||
, ptr_D(args.ref_D.data())
|
||||
, batch_stride_A(args.batch_stride_A)
|
||||
, batch_stride_B(args.batch_stride_B)
|
||||
, epilogue_visitor(args.epilogue_visitor)
|
||||
{
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
|
||||
|
||||
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
int const kAlignK
|
||||
= const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
|
||||
|
||||
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
|
||||
|
||||
if (gemm_k_size)
|
||||
{
|
||||
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
|
||||
}
|
||||
}
|
||||
|
||||
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
|
||||
typename Mma::SharedStorage main_loop;
|
||||
|
||||
struct
|
||||
{
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
typename EpilogueVisitor::SharedStorage visitor;
|
||||
} epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GemmWithEpilogueVisitor() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
|
||||
|
||||
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
||||
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
bool isAMisaligned = false;
|
||||
bool isBMisaligned = false;
|
||||
bool isCMisaligned = false;
|
||||
|
||||
if (platform::is_same<LayoutA, layout::RowMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||
}
|
||||
else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutB, layout::RowMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
|
||||
{
|
||||
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||
}
|
||||
|
||||
if (platform::is_same<LayoutC, layout::RowMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||
}
|
||||
else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
|
||||
{
|
||||
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||
}
|
||||
|
||||
if (isAMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isBMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (isCMisaligned)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning kSuccess");
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
return can_implement(args.problem_size);
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define SPLIT_K_ENABLED 1
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int offset_k = 0;
|
||||
int problem_size_k = params.problem_size.k();
|
||||
|
||||
ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
|
||||
ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);
|
||||
|
||||
#if SPLIT_K_ENABLED
|
||||
//
|
||||
// Fetch pointers based on mode.
|
||||
//
|
||||
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
|
||||
{
|
||||
|
||||
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
|
||||
{
|
||||
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched)
|
||||
{
|
||||
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
|
||||
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kArray)
|
||||
{
|
||||
ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
|
||||
ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
|
||||
}
|
||||
#endif
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
offset_k,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
//
|
||||
// Construct the epilogue visitor
|
||||
//
|
||||
|
||||
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
|
||||
params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
|
||||
params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
|
||||
params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm)
|
||||
{
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
|
||||
{
|
||||
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Construct the epilogue
|
||||
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
|
||||
run_kernel<arch::Sm72>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,143 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
|
||||
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
|
||||
to be consumed by CUTLASS.
|
||||
|
||||
Note that for int4, ThreadBlockK MUST be 64.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
|
||||
struct LayoutDetailsB
|
||||
{
|
||||
};
|
||||
|
||||
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
|
||||
// TODO - Switch this to column major for weights since gemms should be more performant.
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA>
|
||||
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
|
||||
{
|
||||
static constexpr int ThreadblockK = 64;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
// for fast accumulation
|
||||
// using Operator = cutlass::arch::OpMultiplyAddFastAccum;
|
||||
};
|
||||
|
||||
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
|
||||
// which signals that we want to dequantize after loading from smem.
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint8_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB < TypeA,
|
||||
uint4b_t, Arch,
|
||||
typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
||||
|
||||
public:
|
||||
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::ColumnMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,185 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/atom/copy_atom.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA;
|
||||
|
||||
template <typename Element, typename Layout, int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB;
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_64>, cute::Stride<cute::_64, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_8, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
/// Operand A - Column-major (M-major)
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
template <int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<3, 3, 3>{}, cute::Layout<cute::Shape<cute::_64, cute::_8>, cute::Stride<cute::_1, cute::_64>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_16, cute::_8>, cute::Stride<cute::_1, cute::_16>>{},
|
||||
cute::Layout<cute::Shape<cute::_8, cute::_1>>{}));
|
||||
};
|
||||
|
||||
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
|
||||
|
||||
// Operand B - Column-Major (K-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
// Operand B - Row-Major (N-major)
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::half_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
template <int Alignment, int SizeK>
|
||||
struct DefaultGemm_TensorOpSm80_OperandB<cute::bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
|
||||
: DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
|
||||
{
|
||||
};
|
||||
|
||||
//
|
||||
// F16: 128-by-128-by-32 (small k-block)
|
||||
//
|
||||
|
||||
/// Operand A - Row-major (K-Major)
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::half_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::half_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DefaultGemm_TensorOpSm80_OperandA<cute::bfloat16_t, cutlass::layout::RowMajor, 8, 32>
|
||||
{
|
||||
// Smem
|
||||
using SmemLayoutAtom = decltype(cute::composition(
|
||||
cute::Swizzle<2, 3, 3>{}, cute::Layout<cute::Shape<cute::_8, cute::_32>, cute::Stride<cute::_32, cute::_1>>{}));
|
||||
using SmemCopyAtom = cute::Copy_Atom<cute::SM75_U32x4_LDSM_N, cute::bfloat16_t>;
|
||||
|
||||
// Gmem
|
||||
using GmemTiledCopy = decltype(cute::make_tiled_copy(
|
||||
cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::bfloat16_t>{},
|
||||
cute::Layout<cute::Shape<cute::_32, cute::_4>, cute::Stride<cute::_4, cute::_1>>{},
|
||||
cute::Layout<cute::Shape<cute::_1, cute::_8>>{}));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
CUTE_DEVICE auto util_convert_type(cute::Tensor<Engine, Layout> const& tensor)
|
||||
{
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(cute::size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<cutlass::Array<From_type, numel> const*>(tensor.data()));
|
||||
return cute::make_tensor(cute::make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
CUTE_DEVICE void util_copy(
|
||||
TiledCopy const& tiled_copy, cute::Tensor<Engine0, Layout0> const& S, cute::Tensor<Engine1, Layout1>& D)
|
||||
{
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D));
|
||||
CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < cute::size<1>(S); ++m)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < cute::size<2>(S); ++k)
|
||||
{
|
||||
cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,553 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms.
|
||||
// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
|
||||
template <typename...>
|
||||
using void_t = void;
|
||||
|
||||
template <typename Mma, typename = void>
|
||||
struct use_dq_gemm : platform::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Mma>
|
||||
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
|
||||
>
|
||||
struct MoeFCGemm
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = false;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
static_assert(!kTransposed, "Transpose problem not supported");
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
bool C_is_broadcast;
|
||||
|
||||
int64_t const* total_tokens_including_expert;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments()
|
||||
: problem_count(0)
|
||||
, threadblock_count(0)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, total_tokens_including_expert(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
, C_is_broadcast{true}
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
|
||||
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
|
||||
bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n,
|
||||
int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count)
|
||||
, threadblock_count(threadblock_count)
|
||||
, group_size(group_size)
|
||||
, output_op(output_op)
|
||||
, ptr_A(const_cast<ElementA*>(ptr_A))
|
||||
, ptr_B(const_cast<ElementB*>(ptr_B))
|
||||
, weight_scales(const_cast<ElementScale*>(weight_scales))
|
||||
, ptr_C(const_cast<ElementC*>(ptr_C))
|
||||
, C_is_broadcast{C_is_broadcast}
|
||||
, ptr_D(ptr_D)
|
||||
, total_tokens_including_expert(total_tokens_including_expert)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, host_problem_sizes(nullptr)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
assert(weight_scales);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
bool C_is_broadcast;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, C_is_broadcast(true)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(
|
||||
args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, group_size(args.group_size)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, weight_scales(args.weight_scales)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
, C_is_broadcast(args.C_is_broadcast)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n,
|
||||
args.gemm_k, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
weight_scales = args.weight_scales;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
C_is_broadcast = args.C_is_broadcast;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MoeFCGemm() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
if (args.weight_scales == nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
else if (args.weight_scales != nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
else if (args.group_size != args.gemm_k)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
// Handle the case the input is too short
|
||||
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
int loop = 0;
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
loop++;
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
const int64_t rows_to_jump
|
||||
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B
|
||||
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
|
||||
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
auto CreateMMA = [&]()
|
||||
{
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
else
|
||||
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
};
|
||||
Mma mma = CreateMMA();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
|
||||
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
{
|
||||
const MatrixCoord scale_extent = {1, problem_size.n()};
|
||||
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
|
||||
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
|
||||
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C)
|
||||
+ (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n;
|
||||
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
|
||||
|
||||
// lora need to set as layout_C(gemm_n)
|
||||
LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n);
|
||||
LayoutC layout_D(gemm_n);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
if constexpr (platform::is_same<EpilogueOutputOp,
|
||||
cutlass::epilogue::thread::LinearCombination<typename EpilogueOutputOp::ElementOutput,
|
||||
EpilogueOutputOp::kCount, typename EpilogueOutputOp::ElementAccumulator,
|
||||
typename EpilogueOutputOp::ElementCompute, EpilogueOutputOp::kScale,
|
||||
EpilogueOutputOp::kRound>>::value)
|
||||
{
|
||||
EpilogueOutputOp output_op(params.output_op, problem_idx);
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
}
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900)
|
||||
constexpr bool isFp8 = platform::is_same<ElementA, cutlass::float_e4m3_t>::value
|
||||
|| platform::is_same<ElementA, cutlass::float_e5m2_t>::value;
|
||||
if constexpr (isFp8)
|
||||
{
|
||||
run_kernel<arch::Sm89>(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{ // reuse sm80 kernel for other types, align with dispatchToArch
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
}
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,344 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Base scheduler for grouped problems, using MoE
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape_>
|
||||
struct BaseMoeProblemVisitor
|
||||
{
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
|
||||
struct ProblemInfo
|
||||
{
|
||||
static int32_t const kNoPrefetchEntry = -1;
|
||||
int32_t problem_idx;
|
||||
int32_t problem_start;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo()
|
||||
: problem_idx(kNoPrefetchEntry)
|
||||
, problem_start(kNoPrefetchEntry)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
|
||||
: problem_idx(problem_idx_)
|
||||
, problem_start(problem_start_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
int64_t const* last_row_for_problem;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
int32_t problem_count;
|
||||
void const* workspace;
|
||||
int32_t tile_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: last_row_for_problem(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, problem_count(0)
|
||||
, workspace(nullptr)
|
||||
, tile_count(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
|
||||
void const* workspace = nullptr, int32_t tile_count = 0)
|
||||
: last_row_for_problem(last_row_for_problem)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, problem_count(problem_count)
|
||||
, workspace(workspace)
|
||||
, tile_count(tile_count)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
Params const& params;
|
||||
int32_t tile_idx;
|
||||
int32_t problem_tile_start;
|
||||
int32_t problem_idx;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
|
||||
: params(params_)
|
||||
, tile_idx(block_idx)
|
||||
, problem_tile_start(0)
|
||||
, problem_idx(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Get the grid shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
|
||||
{
|
||||
|
||||
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
|
||||
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
|
||||
}
|
||||
|
||||
/// Gets the global tile index
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t tile_index() const
|
||||
{
|
||||
return tile_idx;
|
||||
}
|
||||
|
||||
/// Gets the index of the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t problem_index() const
|
||||
{
|
||||
return problem_idx;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t threadblock_idx() const
|
||||
{
|
||||
return tile_idx - problem_tile_start;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance(int32_t grid_size)
|
||||
{
|
||||
tile_idx += grid_size;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
|
||||
{
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
}
|
||||
|
||||
/// Returns the problem size for the current problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size() const
|
||||
{
|
||||
return problem_size(problem_idx);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::gemm::GemmCoord problem_size(int idx) const
|
||||
{
|
||||
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
|
||||
const int64_t current_problem_row = params.last_row_for_problem[idx];
|
||||
const int64_t gemm_m = current_problem_row - prev_problem_row;
|
||||
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
return problem;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
|
||||
{
|
||||
return ProblemSizeHelper::tile_count(grid);
|
||||
}
|
||||
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
|
||||
{
|
||||
int32_t total_tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
auto problem = host_problem_sizes_ptr[i];
|
||||
possibly_transpose_problem(problem);
|
||||
auto grid = grid_shape(problem);
|
||||
total_tiles += tile_count(grid);
|
||||
}
|
||||
|
||||
return total_tiles;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
|
||||
int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ProblemVisitor that performs all scheduling on device
|
||||
//
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
|
||||
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
|
||||
{
|
||||
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
|
||||
using Params = typename Base::Params;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
static bool const kRequiresPrecomputation = false;
|
||||
static int const kThreadsPerWarp = 32;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
// Final tile of the problem loaded by this thread. Each thread will hold
|
||||
// a separate value.
|
||||
int32_t problem_ending_tile;
|
||||
|
||||
SharedStorage& shared_storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, block_idx)
|
||||
, problem_ending_tile(0)
|
||||
, shared_storage(shared_storage_)
|
||||
{
|
||||
this->problem_idx = -1 * kThreadsPerWarp;
|
||||
this->problem_tile_start = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool next_tile()
|
||||
{
|
||||
// Check whether the tile to compute is within the range of the current problem.
|
||||
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
|
||||
if (this->tile_idx < problem_tile_end)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether the tile to compute is within the current group of problems fetched by the warp.
|
||||
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
|
||||
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
|
||||
|
||||
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
|
||||
// register pressure. The starting problem for this group is simply the first problem
|
||||
// in the group most recently fetched by the warp.
|
||||
int32_t& group_problem_start = this->problem_idx;
|
||||
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
|
||||
|
||||
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
|
||||
// register pressure.
|
||||
int32_t& group_tile_start = this->problem_tile_start;
|
||||
|
||||
// Each thread in the warp processes a separate problem to advance until
|
||||
// reaching a problem whose starting tile is less less than tile_idx.
|
||||
while (group_tile_end <= this->tile_idx)
|
||||
{
|
||||
group_problem_start += kThreadsPerWarp;
|
||||
if (group_problem_start > this->params.problem_count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
|
||||
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
|
||||
// is also set here is used later in `next_tile`.
|
||||
group_tile_start = group_tile_end;
|
||||
|
||||
int lane_idx = threadIdx.x % kThreadsPerWarp;
|
||||
int32_t lane_problem = group_problem_start + lane_idx;
|
||||
|
||||
// Compute the number of tiles in the problem assigned to each thread.
|
||||
problem_ending_tile = 0;
|
||||
if (lane_problem < this->params.problem_count)
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
|
||||
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
|
||||
problem_ending_tile = this->tile_count(grid);
|
||||
}
|
||||
|
||||
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
|
||||
// each thread's problem.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
|
||||
{
|
||||
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
|
||||
if (lane_idx >= i)
|
||||
{
|
||||
problem_ending_tile += val;
|
||||
}
|
||||
}
|
||||
|
||||
// The total tile count for this group is now in the final position of the prefix sum
|
||||
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
|
||||
|
||||
problem_ending_tile += group_tile_start;
|
||||
group_tile_end += tiles_in_group;
|
||||
}
|
||||
|
||||
// The next problem to process is the first one that does not have ending tile position
|
||||
// that is greater than or equal to tile index.
|
||||
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
|
||||
|
||||
this->problem_idx = group_problem_start + problem_idx_in_group;
|
||||
|
||||
// The starting tile for this problem is the ending tile of the previous problem. In cases
|
||||
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
|
||||
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
|
||||
// loop above.
|
||||
if (problem_idx_in_group > 0)
|
||||
{
|
||||
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(
|
||||
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
|
||||
int32_t block_count, void* host_workspace_ptr)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,646 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_base_of_v<KernelTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>
|
||||
&& CollectiveMainloop_::isGated>>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock
|
||||
= CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16>
|
||||
{
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
void* workspace{nullptr};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = nullptr;
|
||||
// Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used
|
||||
// in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means
|
||||
// subtile will not be used, therefore separate reduction will not be enabled.
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{},
|
||||
ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
|
||||
return {args.mode, problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
|
||||
scheduler, workspace};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args)
|
||||
{
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm)
|
||||
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t workspace_size = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
|
||||
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
|
||||
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const& params)
|
||||
{
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
|
||||
{
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape()
|
||||
{
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
|
||||
static_assert(size<0>(TileShape{}) >= 128,
|
||||
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
|
||||
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
|
||||
enum class WarpGroupRole
|
||||
{
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2
|
||||
};
|
||||
enum class ProducerWarpRole
|
||||
{
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate)
|
||||
{
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = size(TiledMma{});
|
||||
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = []()
|
||||
{
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1)
|
||||
{
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
}
|
||||
else
|
||||
{
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
|
||||
{
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the
|
||||
// work.
|
||||
auto work_k_tile_count
|
||||
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
|
||||
auto k_tile_iter
|
||||
= cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(work_k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive)
|
||||
{
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
|
||||
{
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
if (!TileScheduler::requires_separate_reduction(params.scheduler))
|
||||
{
|
||||
load_order_barrier.wait();
|
||||
}
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx());
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it
|
||||
bool do_store_tail = false;
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
auto work_k_tile_count
|
||||
= TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info))
|
||||
{
|
||||
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop,
|
||||
params.mainloop);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count);
|
||||
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
TileScheduler::fixup(
|
||||
params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++)
|
||||
{
|
||||
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
|
||||
}
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler))
|
||||
{
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
|
||||
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
|
||||
do_store_tail = true;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
if (do_store_tail)
|
||||
{
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state);
|
||||
}
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
// Kernel helper function to get next work unit
|
||||
CUTLASS_DEVICE
|
||||
typename TileScheduler::WorkTileInfo fetch_next_work(
|
||||
typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const
|
||||
{
|
||||
// Check whether we should continue on with the current work unit. If this is the case,
|
||||
// the work unit will have been updated in continue_current_work to reflect the new
|
||||
// tile to be computed.
|
||||
if (scheduler.continue_current_work(work_tile_info))
|
||||
{
|
||||
return work_tile_info;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
return scheduler.get_current_work();
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -0,0 +1,621 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cute/util/debug.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<
|
||||
cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, typename CollectiveMainloop_::DispatchPolicy::Schedule>
|
||||
&& CollectiveMainloop_::isGated>>
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
|
||||
"Ping-pong kernel does not currently support stream-K scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t MaxThreadsPerBlock
|
||||
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128>
|
||||
{
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16>
|
||||
{
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params
|
||||
{
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
(void) workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = nullptr;
|
||||
|
||||
return {args.mode, problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info,
|
||||
TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args)
|
||||
{
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm)
|
||||
or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t workspace_size = 0;
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr,
|
||||
cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr)
|
||||
{
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
|
||||
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const& params)
|
||||
{
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>)
|
||||
{
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape()
|
||||
{
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, char* smem_buf)
|
||||
{
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
||||
|
||||
enum class WarpGroupRole
|
||||
{
|
||||
Producer = 0,
|
||||
Consumer0 = 1,
|
||||
Consumer1 = 2
|
||||
};
|
||||
enum class ProducerWarpRole
|
||||
{
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate)
|
||||
{
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(
|
||||
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&]()
|
||||
{
|
||||
// We need this to guarantee that the Pipeline init is visible
|
||||
// To all producers and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1)
|
||||
{
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
}
|
||||
else
|
||||
{
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
|
||||
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
|
||||
// get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
}
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
{
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster,
|
||||
shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive)
|
||||
{
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
|
||||
{
|
||||
load_order_barrier.wait();
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state
|
||||
= collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL,
|
||||
blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue);
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Order two Math WG's MMA one after the other, helps hide Epilogue
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1,
|
||||
k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++)
|
||||
{
|
||||
accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]);
|
||||
}
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next]
|
||||
= collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue);
|
||||
|
||||
// TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels
|
||||
// we need to wait for all TMA stores to complete before issuing consumer order barrier arrives
|
||||
// to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer.
|
||||
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_]
|
||||
= collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next);
|
||||
|
||||
// Update starting load/store pipeline states for the next tile
|
||||
// state has already been incremented by 1 tile in collective calls, advance once again for ping pong
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work(NumMmaWarpGroups);
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -0,0 +1,494 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
|
||||
bool Transposed = false>
|
||||
struct SplitkGemmGrouped
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
|
||||
using ElementFinalOutput = typename MapArguments::ElementA;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord* problem_sizes;
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA** ptr_A;
|
||||
ElementB** ptr_B;
|
||||
ElementFinalOutput** ptr_C;
|
||||
ElementFinalOutput** ptr_D;
|
||||
|
||||
typename LayoutA::Stride::LongIndex* lda;
|
||||
typename LayoutB::Stride::LongIndex* ldb;
|
||||
typename LayoutC::Stride::LongIndex* ldc;
|
||||
typename LayoutC::Stride::LongIndex* ldd;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
// splitK
|
||||
int split_k_slices;
|
||||
int64_t* splitk_buffer_offsets;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments()
|
||||
: problem_count(0)
|
||||
, threadblock_count(0)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, lda(nullptr)
|
||||
, ldb(nullptr)
|
||||
, ldc(nullptr)
|
||||
, ldd(nullptr)
|
||||
, host_problem_sizes(nullptr)
|
||||
, split_k_slices(1)
|
||||
, splitk_buffer_offsets(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
|
||||
typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
|
||||
ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
|
||||
typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
|
||||
typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
|
||||
int64_t* splitk_buffer_offsets)
|
||||
: problem_sizes(problem_sizes)
|
||||
, problem_count(problem_count)
|
||||
, threadblock_count(threadblock_count)
|
||||
, output_op(output_op)
|
||||
, ptr_A(ptr_A)
|
||||
, ptr_B(ptr_B)
|
||||
, ptr_C(ptr_C)
|
||||
, ptr_D(ptr_D)
|
||||
, lda(lda)
|
||||
, ldb(ldb)
|
||||
, ldc(ldc)
|
||||
, ldd(ldd)
|
||||
, host_problem_sizes(host_problem_sizes)
|
||||
, split_k_slices(split_k_slices)
|
||||
, splitk_buffer_offsets(splitk_buffer_offsets)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA** ptr_A;
|
||||
ElementB** ptr_B;
|
||||
ElementFinalOutput** ptr_C;
|
||||
ElementFinalOutput** ptr_D;
|
||||
ElementC* ptr_C_split;
|
||||
ElementC* ptr_D_split;
|
||||
|
||||
typename LayoutA::Stride::LongIndex* lda;
|
||||
typename LayoutB::Stride::LongIndex* ldb;
|
||||
typename LayoutC::Stride::LongIndex* ldc;
|
||||
typename LayoutC::Stride::LongIndex* ldd;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// splitk
|
||||
GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
int gemm_k_size;
|
||||
GemmCoord* host_problem_sizes;
|
||||
int split_k_slices;
|
||||
int64_t* splitk_buffer_offsets;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, ptr_C_split(nullptr)
|
||||
, ptr_D_split(nullptr)
|
||||
, lda(nullptr)
|
||||
, ldb(nullptr)
|
||||
, ldc(nullptr)
|
||||
, ldd(nullptr)
|
||||
, swizzle_log_tile(0)
|
||||
, gemm_k_size(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
, split_k_slices(1)
|
||||
, splitk_buffer_offsets(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
|
||||
, host_problem_sizes(args.host_problem_sizes)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
, ptr_C_split((ElementC*) workspace)
|
||||
, ptr_D_split((ElementC*) workspace)
|
||||
, lda(args.lda)
|
||||
, ldb(args.ldb)
|
||||
, ldc(args.ldc)
|
||||
, ldd(args.ldd)
|
||||
, split_k_slices(args.split_k_slices)
|
||||
, splitk_buffer_offsets(args.splitk_buffer_offsets)
|
||||
{
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
|
||||
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
|
||||
|
||||
// only support same k
|
||||
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
|
||||
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
|
||||
|
||||
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor =
|
||||
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
ptr_C_split = workspace;
|
||||
ptr_D_split = workspace;
|
||||
|
||||
lda = args.lda;
|
||||
ldb = args.ldb;
|
||||
ldc = args.ldc;
|
||||
ldd = args.ldd;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
struct SharedStorage
|
||||
{
|
||||
union
|
||||
{
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
} kernel;
|
||||
|
||||
// ProblemVisitor shared storage can't be overlapped with others
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
SplitkGemmGrouped() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
ElementA* ptr_A
|
||||
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
|
||||
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
|
||||
|
||||
ElementB* ptr_B
|
||||
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
|
||||
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
|
||||
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k;
|
||||
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
|
||||
{
|
||||
problem_size_k = problem_size.k();
|
||||
}
|
||||
else
|
||||
{
|
||||
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
|
||||
}
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(
|
||||
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C = params.ptr_C_split;
|
||||
ElementC* ptr_D = params.ptr_D_split;
|
||||
|
||||
LayoutC layout_C(params.ldc[problem_idx]);
|
||||
LayoutC layout_D(params.ldd[problem_idx]);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
|
||||
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
|
||||
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
|
||||
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
|
||||
|
||||
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,125 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We need to distinguish here, since we want volta support. It is too much effort
|
||||
// to write shared memory iterators that are probably needed for volta to function
|
||||
// properly. As a result, we allow converters both after the LDG (for volta) and after
|
||||
// the LDS for Turing+.
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Warp level Mma
|
||||
typename MmaOperator,
|
||||
/// Math operation perform by warp level operator
|
||||
typename MathOperator>
|
||||
struct SetConverters
|
||||
{
|
||||
};
|
||||
|
||||
// Dequantize after LDG, so set transforms accordingly
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Mma Policy
|
||||
typename MmaOperator>
|
||||
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
|
||||
{
|
||||
using TransformAfterLDG
|
||||
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename IteratorB::Element, IteratorB::Fragment::kElements>;
|
||||
|
||||
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
|
||||
};
|
||||
|
||||
// Dequantize after LDS, so set transforms accordingly
|
||||
|
||||
template <
|
||||
/// Iterator for B matrix in global memory
|
||||
typename IteratorB,
|
||||
/// Mma Policy
|
||||
typename MmaOperator>
|
||||
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
|
||||
{
|
||||
using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
|
||||
IteratorB::Fragment::kElements>;
|
||||
|
||||
using TransformAfterLDS
|
||||
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
||||
typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale_,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale_,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DqMma;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,302 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsMultistage;
|
||||
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
Layout, 0, Alignment>;
|
||||
|
||||
using SmemIteratorScale = IteratorScale;
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
|
||||
private:
|
||||
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
|
||||
MmaShape::kN / Alignment, Alignment>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
|
||||
using SmemIteratorScale = IteratorScale;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
kStages, Operator_, SharedMemoryClear,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
||||
"Mma multistage must dequantize after ldsm");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
|
||||
Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
// Specialization to handle column major interleave B
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
/// Operator performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
kStages, Operator_, SharedMemoryClear,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|
||||
|| platform::is_same<ElementA, float_e4m3_t>::value,
|
||||
"Element A must be fp16, fp8 or bf16");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
||||
"Mma multistage must dequantize after ldsm");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
private:
|
||||
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
|
||||
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
|
||||
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
||||
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
||||
|
||||
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
||||
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
||||
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
||||
|
||||
using GmemIteratorShape
|
||||
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
||||
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
|
||||
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
||||
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
||||
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
|
||||
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
|
||||
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,284 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
|
||||
typename Enable = void>
|
||||
struct DefaultScaleIteratorsPipelined;
|
||||
|
||||
// Fine grained iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<isFinegrained(QuantOp)>>
|
||||
{
|
||||
private:
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
using IteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
|
||||
Layout, 0, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
SmemScaleType, Layout, 0, Alignment>;
|
||||
};
|
||||
|
||||
// Per column iterators
|
||||
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
|
||||
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
|
||||
std::enable_if_t<!isFinegrained(QuantOp)>>
|
||||
{
|
||||
static_assert((MmaShape::kN % Alignment) == 0, "");
|
||||
|
||||
private:
|
||||
// ThreadMap for scale iterator
|
||||
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
|
||||
MmaShape::kN / Alignment, Alignment>;
|
||||
using SmemScaleType = half_t;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
|
||||
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
|
||||
using SmemIteratorScale
|
||||
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
|
||||
Layout, 0, IteratorScaleThreadMap, Alignment>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
|
||||
Operator_, SharedMemoryClearOption::kNone,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
|
||||
Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
|
||||
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
|
||||
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
|
||||
};
|
||||
|
||||
// Specialization to handle column major interleave B
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for the input scale
|
||||
typename ElementScale,
|
||||
/// Layout for the scale operand
|
||||
typename LayoutScale,
|
||||
/// Access granularity of Scales in unit of elements
|
||||
int kAlignmentScale,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_>
|
||||
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
|
||||
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
|
||||
Operator_, SharedMemoryClearOption::kNone,
|
||||
typename platform::enable_if<(
|
||||
ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
|
||||
{
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
||||
"Element B must be uint8 or uint4");
|
||||
|
||||
using OperatorInfo = arch::DetagOperator<Operator_>;
|
||||
using Operator = typename OperatorInfo::Operator;
|
||||
|
||||
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
||||
using MmaCoreElementA = half_t;
|
||||
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
private:
|
||||
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
|
||||
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
|
||||
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
||||
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
||||
|
||||
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
||||
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
||||
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
||||
|
||||
using GmemIteratorShape
|
||||
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
||||
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
|
||||
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
||||
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
||||
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
|
||||
layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
|
||||
|
||||
// ThreadMap for scale iterator
|
||||
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
||||
using IteratorScaleThreadMap
|
||||
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
||||
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
|
||||
|
||||
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
|
||||
OperatorInfo::QuantOp, kAlignmentScale>;
|
||||
|
||||
// Define iterators over tiles from the scale operand
|
||||
using IteratorScale = typename ScaleIterators::IteratorScale;
|
||||
|
||||
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
|
||||
|
||||
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
|
||||
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
|
||||
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,351 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
|
||||
|
||||
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
|
||||
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
|
||||
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
||||
// large tile when not enough shared mem is present to do 3+ stage
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
|
||||
SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
// Define the MmaCore components
|
||||
// 3 is used on purpose here to trigger components for mma multistage
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
|
||||
GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
|
||||
GatherB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,353 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
|
||||
SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
private:
|
||||
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
|
||||
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
||||
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
||||
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
|
||||
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;
|
||||
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
|
||||
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
|
||||
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
|
||||
layout::RowMajor, typename MmaCore::MmaPolicy>;
|
||||
};
|
||||
|
||||
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
|
||||
// large tile when not enough shared mem is present to do 3+ stage
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
/// Gather operand B by using an index array
|
||||
bool GatherB>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
|
||||
false, SharedMemoryClear, GatherA, GatherB>
|
||||
{
|
||||
|
||||
// Define the MmaCore components
|
||||
// 3 is used on purpose here to trigger components for mma multistage
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA, GatherA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB, GatherB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
|
||||
private:
|
||||
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
|
||||
|
||||
using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
|
||||
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,257 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
|
||||
// correct warp level mma. On volta, all data is stored to shared memory as FP16.
|
||||
template <typename WarpMma, int kExpansionFactor = 1>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
|
||||
int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C);
|
||||
}
|
||||
|
||||
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
|
||||
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C, warp_tileB_k_offset);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// The type of the scales
|
||||
typename ElementScale_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// The dequantizing op to be performed.
|
||||
WeightOnlyQuantOp DequantOp,
|
||||
/// Used for partial specialization,
|
||||
typename Enable = bool>
|
||||
class DqMmaBase
|
||||
{
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
///< Type of the scale to be loaded
|
||||
using ElementScale = ElementScale_;
|
||||
|
||||
static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");
|
||||
|
||||
// Finegrained scales get streamed in via cp.async
|
||||
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
|
||||
// We always have scales.
|
||||
static constexpr int ScaleElementsPerStage = Shape::kN;
|
||||
// We sometimes have a bias
|
||||
static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
static constexpr int kNumKIterationsPerWarpBLoad
|
||||
= Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
|
||||
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
|
||||
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA
|
||||
= MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB
|
||||
= MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
/// Shape of the shared memory buffer for the scales for the B matrix.
|
||||
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
|
||||
/// Shape of the shared memory buffer for the biases of the B matrix.
|
||||
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer to hold scales for threadblock
|
||||
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;
|
||||
|
||||
/// Buffer to hold scales for threadblock
|
||||
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA()
|
||||
{
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB()
|
||||
{
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref()
|
||||
{
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref()
|
||||
{
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage& shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
|
||||
, warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,110 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type for the scales
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = void>
|
||||
class DqMmaMultistage;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
|
||||
@@ -0,0 +1,708 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applied immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
|
||||
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
|
||||
SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
|
||||
LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail
|
||||
{
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA
|
||||
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB
|
||||
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
/// The group size for quantization
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
|
||||
{
|
||||
static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");
|
||||
|
||||
typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
|
||||
typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();
|
||||
|
||||
typename IteratorScale::AccessType* smem_scale_ptr
|
||||
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
|
||||
typename IteratorScale::AccessType* smem_zero_ptr
|
||||
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
|
||||
}
|
||||
|
||||
if (iterator_scale.group_size_ == 64)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
|
||||
this->smem_iterator_scale_.add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
||||
{
|
||||
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
|
||||
{
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
|
||||
{
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC& accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale operand in global memory
|
||||
IteratorScale iterator_scale,
|
||||
///< initial value of accumulator
|
||||
FragmentC const& src_accum)
|
||||
{
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
|
||||
{
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
//
|
||||
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
||||
// so that all accumulator elements outside the GEMM footprint are zero.
|
||||
//
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
|
||||
{
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
|
||||
typename IteratorA::AccessType zero_A;
|
||||
zero_A.clear();
|
||||
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
typename Dequantizer::FragmentScale warp_frag_scales;
|
||||
typename Dequantizer::FragmentZero warp_frag_zeros;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
|
||||
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// This is the first group of a given stage, so we issue the loads for the B scales immediately.
|
||||
if (group_start_iteration_B == 0)
|
||||
{
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
}
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
|
||||
// #committed)
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
|
||||
smem_read_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the scale needed for the next tile iteration.
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
|
||||
// Update internal pointer to set of scales in shared memory.
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,647 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
|
||||
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
|
||||
SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
|
||||
LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail
|
||||
{
|
||||
|
||||
static_assert(Base::kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA
|
||||
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB
|
||||
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
};
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
///< Group size for quantization. Not used by this main loop since it assumes per-column
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance(
|
||||
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
||||
{
|
||||
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
|
||||
{
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
|
||||
{
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC& accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale operand in global memory
|
||||
IteratorScale iterator_scale,
|
||||
///< initial value of accumulator
|
||||
FragmentC const& src_accum)
|
||||
{
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
// NOTE - switch to ldg.sts
|
||||
// Issue this first, so cp.async.commit_group will commit this load as well.
|
||||
// Note: we do not commit here and this load will commit in the same group as
|
||||
// the first load of A.
|
||||
FragmentScale tb_frag_scales;
|
||||
tb_frag_scales.clear();
|
||||
iterator_scale.load(tb_frag_scales);
|
||||
this->smem_iterator_scale_.store(tb_frag_scales);
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
|
||||
{
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
//
|
||||
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
||||
// so that all accumulator elements outside the GEMM footprint are zero.
|
||||
//
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
|
||||
{
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
|
||||
typename IteratorA::AccessType zero_A;
|
||||
zero_A.clear();
|
||||
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
|
||||
{
|
||||
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
|
||||
{
|
||||
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math
|
||||
// instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
typename Dequantizer::FragmentScale warp_frag_scales;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
warp_dequantizer_.load(warp_frag_scales);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
// Computes a warp-level GEMM on data held in shared memory
|
||||
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if
|
||||
// this is the last group as the case may be.
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
|
||||
using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
|
||||
constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
|
||||
static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
|
||||
|
||||
using Converter
|
||||
= cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;
|
||||
|
||||
FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
|
||||
run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
|
||||
warp_tileB_k_compute_offset);
|
||||
|
||||
// Issue global->shared copies for the this stage
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
|
||||
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
}
|
||||
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
|
||||
{
|
||||
int group_start_iteration_A, group_start_iteration_B;
|
||||
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
|
||||
// #committed)
|
||||
arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// Move to the next stage
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
|
||||
this->smem_iterator_A_.add_tile_offset({0, 1});
|
||||
this->smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the
|
||||
// circular buffer in shared memory
|
||||
if (smem_write_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_write_stage_idx;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx == (Base::kStages - 1))
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
smem_read_stage_idx = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,106 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Data type for the scales
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
/// Used for partial specialization
|
||||
typename Enable = void>
|
||||
class DqMmaPipelined;
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
|
||||
@@ -0,0 +1,486 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
|
||||
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
using WarpFragmentZero = typename Dequantizer::FragmentZero;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< The group size for quantization
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
|
||||
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_scales_and_advance(IteratorScale& iterator_scale)
|
||||
{
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
FragmentScale tb_frag_scales;
|
||||
FragmentScale tb_frag_zeros;
|
||||
tb_frag_scales.clear();
|
||||
tb_frag_zeros.clear();
|
||||
|
||||
TransformScale transformScale;
|
||||
|
||||
using FragmentElement = typename FragmentScale::Element;
|
||||
|
||||
auto gmem_scale_ptr = iterator_scale.get_scale();
|
||||
auto gmem_zero_ptr = iterator_scale.get_zero();
|
||||
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
|
||||
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
|
||||
}
|
||||
|
||||
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
|
||||
typename TransformScale::result_type tb_frag_zeros_fp16;
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
|
||||
|
||||
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
|
||||
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
|
||||
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
|
||||
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
|
||||
|
||||
if (iterator_scale.valid())
|
||||
{
|
||||
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
|
||||
|
||||
if (gmem_zero_ptr != nullptr)
|
||||
{
|
||||
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
|
||||
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
|
||||
}
|
||||
}
|
||||
|
||||
if (iterator_scale.group_size_ == 64)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if (iterator_scale.group_size_ == 128)
|
||||
{
|
||||
if constexpr (Shape::kK == 128)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
else if constexpr (Shape::kK == 64)
|
||||
{
|
||||
if (iterator_scale.row_groupsize64_ & 0x1)
|
||||
{
|
||||
iterator_scale.add_tile_offset({1, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
|
||||
}
|
||||
}
|
||||
|
||||
iterator_scale.row_groupsize64_++;
|
||||
|
||||
this->smem_iterator_scale_.add_tile_offset({1, 0});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
WarpFragmentZero warp_frag_zero;
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
copy_scales_and_advance(iterator_scale);
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_scale.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
|
||||
// Load the scales needed for the next tile iteration
|
||||
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
|
||||
// Update internal pointer to the set of scales in shared memory
|
||||
warp_dequantizer_.add_pointer_offset(Shape::kN);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,399 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<!isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy::Operator::ArchTag;
|
||||
|
||||
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
|
||||
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
// staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
using WarpFragmentA = typename Operator::FragmentA;
|
||||
using WarpFragmentB = typename Operator::FragmentB;
|
||||
Dequantizer warp_dequantizer_;
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool RequiresTileInterleave
|
||||
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
protected:
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
||||
SmemIteratorScale smem_iterator_scale_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
|
||||
///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
|
||||
///< argument is not added, it does not affect compilation for sm>=80.
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
|
||||
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
|
||||
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
|
||||
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
|
||||
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
||||
{
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
||||
FragmentC& accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
||||
FragmentC const& src_accum)
|
||||
{ ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
TransformBAfterLDG ldg_converter;
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
using TransformA
|
||||
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
||||
|
||||
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
|
||||
typename FragmentScale::Element, FragmentScale::kElements>;
|
||||
|
||||
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
||||
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
||||
TransformA transformA;
|
||||
TransformScale transformScale;
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
accum = src_accum;
|
||||
|
||||
FragmentA tb_frag_A;
|
||||
FragmentB tb_frag_B;
|
||||
FragmentScale tb_frag_scales;
|
||||
|
||||
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
||||
WarpFragmentScale warp_frag_scales;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B.clear();
|
||||
tb_frag_scales.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
iterator_scale.load(tb_frag_scales);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
warp_dequantizer_.load(warp_frag_scales);
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA warp_frag_A[2];
|
||||
WarpFragmentB warp_frag_B[2];
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
Operator warp_mma;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > 0; --gemm_k_iterations)
|
||||
{
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1)
|
||||
{
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
||||
|
||||
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
||||
|
||||
__syncthreads();
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1)
|
||||
{
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
if (warp_mma_k == 0)
|
||||
{
|
||||
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B.load(tb_frag_B);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
typename TransformBAfterLDS::result_type converted_frag_B
|
||||
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
||||
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
||||
run_warp_mma(
|
||||
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,107 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op.h"
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for m-by-n-by-kgroup
|
||||
template <
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A elements,
|
||||
typename ElementA,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA,
|
||||
/// Data type of B elements
|
||||
typename ElementB,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB,
|
||||
/// Element type of C matrix
|
||||
typename ElementC,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
|
||||
arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
|
||||
{
|
||||
|
||||
private:
|
||||
// Shape for computing the FP16s
|
||||
using ComputeInstructionShape = InstructionShape_;
|
||||
|
||||
// Chosen so we get K=16 for int8 and K=32 for int4.
|
||||
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
|
||||
|
||||
// Shape for loading the narrow data type from shared memory
|
||||
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
|
||||
|
||||
public:
|
||||
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
|
||||
cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
|
||||
cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
|
||||
cutlass::MatrixShape<1, 1>>;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,306 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
|
||||
Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/mma_sm80.h"
|
||||
#include "cutlass/arch/mma_sm89.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
|
||||
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Data type of B elements
|
||||
typename ElementB_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
typename Policy_,
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
typename SharedMemoryInstructionShape_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_ = 1,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaTensorOpComputeBWithF16
|
||||
{
|
||||
public:
|
||||
/// Shape of warp-level matrix operation (concept: GemmShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of multiplicand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Layout of multiplicand A
|
||||
using LayoutA = LayoutA_;
|
||||
|
||||
/// Data type of multiplicand B
|
||||
using ElementB = ElementB_;
|
||||
|
||||
/// Layout of multiplicand B
|
||||
using LayoutB = LayoutB_;
|
||||
|
||||
/// Data type of accumulator matrix C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Layout of accumulator matrix C
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Indicates math operator
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
|
||||
|| (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
|
||||
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
|
||||
|
||||
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
|
||||
|
||||
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
public:
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA
|
||||
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
|
||||
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
|
||||
kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
|
||||
|
||||
public:
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpComputeBWithF16() {}
|
||||
|
||||
/// Performs a warp-level matrix multiply-accumulate operation
|
||||
CUTLASS_DEVICE
|
||||
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
|
||||
int const warp_tileB_k_offset) const
|
||||
{
|
||||
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
static_assert(
|
||||
TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
|
||||
"Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
|
||||
"B");
|
||||
|
||||
D = C;
|
||||
|
||||
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
|
||||
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
|
||||
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
// Serpentine visitation order maximizing reuse of Rb
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
|
||||
|
||||
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
|
||||
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
|
||||
ptr_D[m_serpentine + n * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
// Serpentine visitation order maximizing reuse of Ra
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
|
||||
|
||||
int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
|
||||
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
|
||||
ptr_D[m + n_serpentine * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
assert(0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,463 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Matrix multiply operator
|
||||
typename MmaOperator_,
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand,
|
||||
/// Data type of Scale elements
|
||||
typename Element_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Number of threads participating in one matrix operation
|
||||
int Threads,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
///
|
||||
typename Enable = void>
|
||||
class MmaTensorOpDequantizer;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat specialization for Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_,
|
||||
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
|
||||
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
|
||||
public:
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture specific mma ooperator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
/// Type of the scales
|
||||
using ElementScale = bfloat16_t;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kColsPerMmaPerThread = 1;
|
||||
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<ElementScale, Layout>;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
pointer_zero_ = smem_zeros.data() + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
|
||||
|
||||
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
|
||||
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
|
||||
{
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
|
||||
|
||||
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
|
||||
__nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]);
|
||||
__nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
|
||||
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
|
||||
{
|
||||
operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
|
||||
// Adds a pointer offset in units of elements.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset)
|
||||
{
|
||||
static_assert(sizeof(ElementScale) > 1, "");
|
||||
pointer_scale_ += offset;
|
||||
pointer_zero_ += offset;
|
||||
}
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Specialization for Turing & Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
|
||||
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 75
|
||||
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
|
||||
public:
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture specific mma ooperator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
/// Type of the scales
|
||||
using ElementScale = half_t;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kColsPerMmaPerThread = 1;
|
||||
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<ElementScale, Layout>;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
pointer_zero_ = smem_zeros.data() + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
|
||||
{
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
plus<ExpandedMmaOperandB> plus_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter]
|
||||
= plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a pointer offset in units of elements.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset)
|
||||
{
|
||||
static_assert(sizeof(ElementScale) > 1, "");
|
||||
pointer_scale_ += offset;
|
||||
pointer_zero_ += offset;
|
||||
}
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
224
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
vendored
Normal file
224
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
vendored
Normal file
@@ -0,0 +1,224 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
|
||||
// in the kernel layout details when doing weight only quantization.
|
||||
enum class CutlassTileConfig
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// SiMT config
|
||||
CtaShape128x128x8_WarpShape64x64x8,
|
||||
|
||||
// TensorCore configs CTA_N = 128, CTA_K = 64
|
||||
// Warp configs for M=16
|
||||
CtaShape16x128x64_WarpShape16x32x64,
|
||||
// Warp configs for M=32
|
||||
CtaShape32x128x64_WarpShape32x32x64,
|
||||
|
||||
// Warp configs for M=64
|
||||
CtaShape64x128x64_WarpShape32x64x64,
|
||||
CtaShape64x64x128_WarpShape32x64x64,
|
||||
CtaShape64x128x64_WarpShape64x32x64,
|
||||
|
||||
// Warp configs for M=128
|
||||
CtaShape128x64x64_WarpShape64x32x64,
|
||||
CtaShape128x128x64_WarpShape64x32x64,
|
||||
CtaShape128x128x64_WarpShape64x64x64,
|
||||
CtaShape128x128x64_WarpShape128x32x64,
|
||||
CtaShape128x256x64_WarpShape64x64x64,
|
||||
|
||||
// Warp configs for M=256
|
||||
CtaShape256x128x64_WarpShape64x64x64,
|
||||
|
||||
// TensorCore config CTA_N = 64, CTA_K = 128
|
||||
CtaShape128x64x128_WarpShape64x32x128,
|
||||
|
||||
// TensorCore config CTA_N = 256, CTA_K = 64
|
||||
CtaShape16x256x64_WarpShape16x64x64,
|
||||
|
||||
// TensorCore config CTA_N = 256, CTA_K = 128
|
||||
CtaShape16x256x128_WarpShape16x64x128
|
||||
|
||||
};
|
||||
|
||||
enum class SplitKStyle
|
||||
{
|
||||
NO_SPLIT_K,
|
||||
SPLIT_K_SERIAL,
|
||||
STREAM_K, // Sm80+
|
||||
// SPLIT_K_PARALLEL // Not supported yet
|
||||
};
|
||||
|
||||
enum class CutlassTileConfigSM90
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// CTA configs for M=64
|
||||
CtaShape64x16x128B,
|
||||
CtaShape64x32x128B,
|
||||
CtaShape64x64x128B,
|
||||
CtaShape64x128x128B,
|
||||
CtaShape64x256x128B,
|
||||
|
||||
// CTA configs for M=128
|
||||
CtaShape128x16x128B,
|
||||
CtaShape128x32x128B,
|
||||
CtaShape128x64x128B,
|
||||
CtaShape128x128x128B,
|
||||
CtaShape128x256x128B,
|
||||
|
||||
// CTA configs for M=128
|
||||
CtaShape256x128x128B,
|
||||
};
|
||||
|
||||
enum class MainloopScheduleType
|
||||
{
|
||||
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
|
||||
// defaults to the "legacy" main loop schedule.
|
||||
};
|
||||
|
||||
enum class EpilogueScheduleType
|
||||
{
|
||||
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
|
||||
// architectures older than hopper, the epilogue is always performed by the same thread block as the main loop.
|
||||
};
|
||||
|
||||
enum class ClusterShape
|
||||
{
|
||||
ClusterShape_1x1x1,
|
||||
ClusterShape_2x1x1,
|
||||
ClusterShape_1x2x1,
|
||||
ClusterShape_2x2x1,
|
||||
ClusterShape_1x8x1,
|
||||
ClusterShape_8x1x1
|
||||
};
|
||||
|
||||
struct CutlassGemmConfig
|
||||
{
|
||||
enum CandidateConfigTypeParam : int
|
||||
{
|
||||
NONE = 0,
|
||||
WEIGHT_ONLY = 1u << 0,
|
||||
SIMT_ONLY = 1u << 1,
|
||||
INT8_ONLY = 1u << 2,
|
||||
HOPPER = 1u << 3,
|
||||
GROUPED_GEMM = 1u << 4,
|
||||
FP8_ONLY = 1u << 5,
|
||||
};
|
||||
|
||||
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
|
||||
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
|
||||
int split_k_factor = -1;
|
||||
int stages = -1;
|
||||
|
||||
// config options for sm90
|
||||
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
|
||||
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
|
||||
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
|
||||
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
|
||||
bool is_sm90 = false;
|
||||
|
||||
CutlassGemmConfig() {}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
|
||||
: tile_config(tile_config)
|
||||
, split_k_style(split_k_style)
|
||||
, split_k_factor(split_k_factor)
|
||||
, stages(stages)
|
||||
, is_sm90(false)
|
||||
{
|
||||
}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
|
||||
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
|
||||
: tile_config_sm90(tile_config_sm90)
|
||||
, mainloop_schedule(mainloop_schedule)
|
||||
, epilogue_schedule(epilogue_schedule)
|
||||
, cluster_shape(cluster_shape)
|
||||
, is_sm90(true)
|
||||
{
|
||||
}
|
||||
|
||||
std::string toString() const
|
||||
{
|
||||
std::stringstream tactic;
|
||||
tactic << "Cutlass GEMM Tactic";
|
||||
if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=TMA"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
assert(!is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=compatible"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages
|
||||
<< "\n\tsplit k: " << (int) split_k_factor;
|
||||
}
|
||||
else
|
||||
{
|
||||
tactic << "\n\tundefined";
|
||||
}
|
||||
tactic << "\n";
|
||||
return tactic.str();
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
|
||||
{
|
||||
// clang-format off
|
||||
if (config.is_sm90)
|
||||
{
|
||||
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
|
||||
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
|
||||
<< ", cluster_shape_enum: " << int(config.cluster_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
out << "tile_config_enum: " << int(config.tile_config)
|
||||
<< ", split_k_style_enum: " << int(config.split_k_style)
|
||||
<< ", split_k_factor: " << config.split_k_factor
|
||||
<< ", stages: " << config.stages;
|
||||
}
|
||||
// clang-format on
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace cutlass_extensions
|
||||
} // namespace tensorrt_llm
|
||||
@@ -0,0 +1,447 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
|
||||
// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
|
||||
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
|
||||
// This converter will uninterleave the data and subtract the bias while converting to the result type.
|
||||
template <typename T, typename S, int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
|
||||
{
|
||||
using result_type = Array<half_t, 4>;
|
||||
using source_type = Array<uint8_t, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
|
||||
|
||||
// Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 4;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
|
||||
|
||||
using result_type = Array<half_t, N>;
|
||||
using source_type = Array<uint8_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
|
||||
{
|
||||
using result_type = Array<bfloat16_t, 4>;
|
||||
using source_type = Array<uint8_t, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
float fp32_intermediates[4];
|
||||
|
||||
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
|
||||
|
||||
// Subtract out fp32_base + 128 to make the unsigned integer signed.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 4; ++ii)
|
||||
{
|
||||
fp32_intermediates[ii] -= 8388736.f;
|
||||
}
|
||||
|
||||
// Truncate the fp32 representation and pack up as bfloat16s.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii)
|
||||
{
|
||||
bf16_result_ptr[ii]
|
||||
= __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
|
||||
}
|
||||
#else
|
||||
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
|
||||
// HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
|
||||
result.clear(); // Suppress compiler warning
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 4;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
|
||||
|
||||
using result_type = Array<bfloat16_t, N>;
|
||||
using source_type = Array<uint8_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
|
||||
{
|
||||
using result_type = Array<half_t, 8>;
|
||||
using source_type = Array<uint4b_t, 8>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 8;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<half_t, N>;
|
||||
using source_type = Array<uint4b_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
|
||||
{
|
||||
using result_type = Array<bfloat16_t, 8>;
|
||||
using source_type = Array<uint4b_t, 8>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
result_type result;
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
|
||||
|
||||
// We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
|
||||
// No shift needed for first item.
|
||||
uint32_t i4s = source_i4s;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 1; ii < result_type::kElements / 2; ++ii)
|
||||
{
|
||||
i4s >>= sizeof_bits<typename source_type::Element>::value;
|
||||
// (i4s & 0x000f000f) | 0x43004300
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[ii])
|
||||
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
|
||||
}
|
||||
|
||||
// This is the BF16 {-136, -136} represented as an integer.
|
||||
static constexpr uint32_t BF16_BIAS = 0xC308C308;
|
||||
static constexpr uint32_t BF16_ONE = 0x3F803F80;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < result_type::kElements / 2; ++ii)
|
||||
{
|
||||
// Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
|
||||
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
|
||||
}
|
||||
#else
|
||||
// Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
|
||||
// HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
|
||||
arch::device_breakpoint();
|
||||
result.clear(); // Suppress compiler warning.
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
|
||||
{
|
||||
static constexpr int VEC_WIDTH = 8;
|
||||
static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<bfloat16_t, N>;
|
||||
using source_type = Array<uint4b_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
|
||||
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / VEC_WIDTH; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s)
|
||||
{
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,66 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines new layouts needed for MoE
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/pitch_linear_coord.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace layout
|
||||
{
|
||||
|
||||
template <int RowsPerTile, int ColumnsInterleaved>
|
||||
struct ColumnMajorTileInterleave
|
||||
{
|
||||
static constexpr int kRowsPerTile = RowsPerTile;
|
||||
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct IsColumnMajorTileInterleave
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <int U, int V>
|
||||
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
} // namespace layout
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
|
||||
quantization.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/tensor_view.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace transform
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
|
||||
class FineGrainedScaleZeroIterator;
|
||||
|
||||
template <typename Shape_, typename Element_, int Alignment_>
|
||||
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
|
||||
{
|
||||
public:
|
||||
using Shape = Shape_;
|
||||
using Element = Element_;
|
||||
using Layout = layout::RowMajor;
|
||||
static int const kAdvanceRank = 0;
|
||||
static int const kAlignment = Alignment_;
|
||||
|
||||
static int const kAccessesPerVector = 1;
|
||||
|
||||
/// Row index of scales corresponding to the groupsize of 64
|
||||
int row_groupsize64_;
|
||||
int group_size_;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using TensorView = TensorView<Element, Layout>;
|
||||
using TensorCoord = typename Layout::TensorCoord;
|
||||
using Pointer = Element*;
|
||||
using NonConstPointer = typename platform::remove_const<Element>::type*;
|
||||
|
||||
using AccessType = AlignedArray<Element, kAlignment>;
|
||||
|
||||
using Fragment = cutlass::Array<Element, kAlignment>;
|
||||
|
||||
// For compatibility with existing iterator interface
|
||||
struct Params
|
||||
{
|
||||
LongIndex stride_ = 0;
|
||||
|
||||
/// amount (in byte) to increment pointer from first access of current tile
|
||||
/// to first access of next tile
|
||||
LongIndex inc_advance_ = 0;
|
||||
|
||||
// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Construct the Params object given a pitch-linear tensor's layout
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const& layout)
|
||||
: stride_(layout.stride(0))
|
||||
{
|
||||
inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
/// Internal pointer type permits fast address arithmetic
|
||||
using BytePointer = char*;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object with precomputed internal state
|
||||
Params const params_;
|
||||
|
||||
/// Internal pointer to first access of tile
|
||||
BytePointer pointer_scale_;
|
||||
BytePointer pointer_zero_;
|
||||
|
||||
bool is_valid_ = false;
|
||||
|
||||
public:
|
||||
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
||||
/// and thread ID
|
||||
CUTLASS_DEVICE
|
||||
FineGrainedScaleZeroIterator(
|
||||
///< Precomputed parameters object
|
||||
Params const& params,
|
||||
///< Pointer to start of scale tensor
|
||||
Pointer pointer_scale,
|
||||
///< Pointer to start of zero tensor
|
||||
Pointer pointer_zero,
|
||||
///< Extent of the scale and bias
|
||||
TensorCoord extent,
|
||||
///< ID of each participating thread
|
||||
int thread_id,
|
||||
///< Initial offset of threadblock
|
||||
TensorCoord const& threadblock_offset,
|
||||
///< Group size
|
||||
int group_size)
|
||||
: params_(params)
|
||||
, pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
|
||||
, pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
|
||||
{
|
||||
row_groupsize64_ = threadblock_offset.row();
|
||||
group_size_ = group_size;
|
||||
|
||||
const LongIndex tb_row_byte_offset
|
||||
= threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
|
||||
const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);
|
||||
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
|
||||
}
|
||||
|
||||
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
|
||||
|
||||
int const thread_row = thread_id / THREADS_PER_ROW;
|
||||
int const thread_col = thread_id % THREADS_PER_ROW;
|
||||
|
||||
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
|
||||
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
|
||||
}
|
||||
|
||||
// For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
|
||||
// a given iteration. The same threads will be responsible for issues reads since the number of scales
|
||||
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
|
||||
// outside of the constructor.
|
||||
int const global_row = threadblock_offset.row() + thread_row;
|
||||
int const global_col = threadblock_offset.column() + thread_col * kAlignment;
|
||||
|
||||
bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
|
||||
bool const col_in_bounds = global_col < extent.column();
|
||||
|
||||
is_valid_ = row_in_bounds && col_in_bounds;
|
||||
}
|
||||
|
||||
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
||||
CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
|
||||
Pointer pointer_scale, ///< Pointer to start of scale tensor
|
||||
Pointer pointer_zero, ///< Pointer to start of zero tensor
|
||||
TensorCoord extent, ///< Extent of tensor
|
||||
int thread_id, ///< ID of each participating thread
|
||||
int group_size)
|
||||
: FineGrainedScaleZeroIterator(
|
||||
params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& tile_offset)
|
||||
{
|
||||
const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
|
||||
const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
|
||||
pointer_scale_ += row_byte_offset + col_byte_offset;
|
||||
if (pointer_zero_ != nullptr)
|
||||
{
|
||||
pointer_zero_ += row_byte_offset + col_byte_offset;
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
|
||||
{
|
||||
is_valid_ &= (!enable);
|
||||
}
|
||||
|
||||
/// Returns whether access is valid or not
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool valid() const
|
||||
{
|
||||
return is_valid_;
|
||||
}
|
||||
|
||||
/// Returns a scale pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType* get_scale() const
|
||||
{
|
||||
return reinterpret_cast<AccessType*>(pointer_scale_);
|
||||
}
|
||||
|
||||
/// Returns a zero pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
AccessType* get_zero() const
|
||||
{
|
||||
return reinterpret_cast<AccessType*>(pointer_zero_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace transform
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,181 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/util/print.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/// Function object that applies an index to its argument
|
||||
template <class Iter>
|
||||
struct IndexedGather
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
|
||||
: indices_(indices)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename I>
|
||||
CUTE_HOST_DEVICE constexpr auto operator()(I i) const
|
||||
{
|
||||
return indices_[i];
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
|
||||
{
|
||||
cute::print("Indexed{");
|
||||
print(s.indices_);
|
||||
print("}");
|
||||
}
|
||||
|
||||
Iter indices_;
|
||||
};
|
||||
|
||||
/// Custom stride object that applies a function followed by a stride
|
||||
template <class Func, class Stride>
|
||||
struct CustomStride
|
||||
{
|
||||
CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
|
||||
: func_(func)
|
||||
, stride_(stride)
|
||||
{
|
||||
}
|
||||
|
||||
template <class I>
|
||||
CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
|
||||
{
|
||||
return s.func_(i) * s.stride_;
|
||||
}
|
||||
|
||||
template <class I>
|
||||
CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
|
||||
{
|
||||
return s.func_(i) * s.stride_;
|
||||
}
|
||||
|
||||
CUTE_HOST_DEVICE friend void print(CustomStride const& s)
|
||||
{
|
||||
cute::print("Custom{");
|
||||
print(s.func_);
|
||||
cute::print(",");
|
||||
print(s.stride_);
|
||||
cute::print("}");
|
||||
}
|
||||
|
||||
template <class Div>
|
||||
CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
|
||||
{
|
||||
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
|
||||
}
|
||||
|
||||
// Circumvent the requirement on make_layout that shape and stride are integral
|
||||
template <class Shape>
|
||||
CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
|
||||
{
|
||||
return Layout<Shape, CustomStride>(shape, stride);
|
||||
}
|
||||
|
||||
Func func_;
|
||||
Stride stride_;
|
||||
};
|
||||
|
||||
template <class Stride, class Func>
|
||||
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
|
||||
{
|
||||
// Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
|
||||
auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
|
||||
constexpr int I = decltype(idx)::value;
|
||||
return make_layout(
|
||||
repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
|
||||
}
|
||||
|
||||
/// Helper function to optionally create a gather tensor
|
||||
template <class Iterator, class Shape, class Stride, class Func>
|
||||
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
|
||||
{
|
||||
Layout matrix_layout = make_identity_layout(shape);
|
||||
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
|
||||
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
|
||||
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
|
||||
}
|
||||
|
||||
namespace cute
|
||||
{
|
||||
|
||||
template <int N, int I, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
|
||||
{
|
||||
if constexpr (is_tuple<Shape>::value)
|
||||
{
|
||||
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
|
||||
}
|
||||
else if constexpr (is_scaled_basis<Stride>::value)
|
||||
{
|
||||
if constexpr (Stride::mode() == I)
|
||||
{
|
||||
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_layout(shape, stride);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return upcast<N>(shape, stride);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
|
||||
CUTE_HOST_DEVICE constexpr auto upcast(
|
||||
ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
|
||||
{
|
||||
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
|
||||
auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
|
||||
constexpr int I = decltype(idx)::value;
|
||||
|
||||
// Upcast the outer layout (works as expected)
|
||||
auto outer = upcast<N>(layout.layout_a());
|
||||
|
||||
// Upcast the accumulated offset along stride-1 mode
|
||||
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
|
||||
|
||||
// Upcast the inner layout's shape along stride-1 mode
|
||||
auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());
|
||||
|
||||
return composition(outer, offset, inner);
|
||||
}
|
||||
|
||||
} // namespace cute
|
||||
@@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
enum class WeightOnlyQuantOp
|
||||
{
|
||||
UNDEFINED,
|
||||
PER_COLUMN_SCALE_ONLY,
|
||||
FINEGRAINED_SCALE_ONLY,
|
||||
FINEGRAINED_SCALE_AND_ZEROS
|
||||
};
|
||||
|
||||
constexpr bool isFinegrained(WeightOnlyQuantOp op)
|
||||
{
|
||||
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
}
|
||||
|
||||
constexpr bool hasZero(WeightOnlyQuantOp op)
|
||||
{
|
||||
return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
@@ -223,3 +223,208 @@ BSD 3-Clause "New" License
|
||||
|
||||
3rdparty/cutlass
|
||||
include/flashinfer/attention/hopper/block_sparse_gather.cuh
|
||||
|
||||
Notice for NVIDIA/TensorRT-LLM
|
||||
-------------------------------
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
@@ -39,6 +39,8 @@ cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
turbomind = root / "3rdparty" / "turbomind"
|
||||
tensorrt_llm_parent = root / "3rdparty"
|
||||
tensorrt_llm = root / "3rdparty" / "tensorrt_llm"
|
||||
include_dirs = [
|
||||
cutlass.resolve() / "include",
|
||||
cutlass.resolve() / "tools" / "util" / "include",
|
||||
@@ -51,6 +53,8 @@ include_dirs = [
|
||||
"cublasLt",
|
||||
turbomind.resolve(),
|
||||
turbomind.resolve() / "src",
|
||||
tensorrt_llm_parent.resolve(),
|
||||
tensorrt_llm.resolve() / "cutlass_extensions" / "include",
|
||||
]
|
||||
|
||||
nvcc_flags = [
|
||||
|
||||
Reference in New Issue
Block a user