add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216)
Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
148  sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h  vendored  Normal file
@@ -0,0 +1,148 @@
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <optional>
#include <string>

namespace tensorrt_llm
{
namespace common
{

class CublasMMWrapper
{
protected:
    std::shared_ptr<cublasHandle_t> mCublasHandle;
    std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;

    cudaDataType_t mAType{};
    cudaDataType_t mBType{};
    cudaDataType_t mCType{};
    cublasComputeType_t mComputeType{};
    cudaDataType_t mScaleType{};

    cublasLtMatmulDesc_t mOperationDesc{NULL};
    cublasLtMatrixLayout_t mADesc{NULL};
    cublasLtMatrixLayout_t mBDesc{NULL};
    cublasLtMatrixLayout_t mCDesc{NULL};

    cudaStream_t mStream;

    void* mCublasWorkspace = nullptr;

private:
    bool descriptorsCreated() const
    {
        return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
    }

public:
    CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
        cudaStream_t stream, void* workspace);

    ~CublasMMWrapper();

    CublasMMWrapper(CublasMMWrapper const& wrapper);

    /********************** GEMMs **********************/
    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc);

    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc,
        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);

    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);

    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);

    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
        cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);

    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
        void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
        float const f_beta = 0.0f);

    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
        cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
        int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);

    /********************** Tactic selection helpers **********************/
    bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);

    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
        int const m, int const n, int const k, int const lda, int const ldb, int const ldc);

    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
        cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
        cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc);

    using MatrixLayout = std::tuple<cudaDataType_t, cublasLtOrder_t, uint64_t, uint64_t>;
    using cache_idx_t = std::tuple<cublasLtMatmulDesc_t, std::array<MatrixLayout, 4>>;

    MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc);

    /********************** Utils **********************/
    void setWorkspace(void* workspace);

    void setFP32GemmConfig();
    void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#ifdef ENABLE_BF16
    void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
#endif
#ifdef ENABLE_FP8
    void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
#endif

    void setStream(cudaStream_t stream);

    void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType);

    CublasDataType getCublasDataType(cudaDataType_t data_type);

    void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
        int const lda, int const ldb, int const ldc, int8_t fastAcc = 0);
    void setScaleDescriptors(void* scale_a, void* scale_b);
    void destroyDescriptors();

    cublasHandle_t getCublasHandle()
    {
        return *(this->mCublasHandle);
    }

    cublasLtHandle_t getCublasLtHandle() const
    {
        return *(this->mCublasLtHandle);
    }
};

} // namespace common

} // namespace tensorrt_llm
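For orientation, below is a minimal, hypothetical sketch of how a caller might drive this vendored wrapper; it is not part of the vendored sources or of sgl-kernel. It uses only declarations visible in the header above (the constructor, setFP16GemmConfig, and the simplest Gemm overload). The handle and workspace setup, the 32 MiB workspace size, and the column-major lda/ldb/ldc choices are assumptions made for illustration, not requirements taken from TensorRT-LLM.

// Hypothetical usage sketch (illustrative only, not vendored code).
#include "tensorrt_llm/common/cublasMMWrapper.h"

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <memory>

using tensorrt_llm::common::CublasMMWrapper;

// Computes C = A * B in FP16, assuming column-major buffers already on the device.
void runFp16Gemm(void const* A, void const* B, void* C, int m, int n, int k, cudaStream_t stream)
{
    // The wrapper only holds shared_ptr references; the caller owns the handles.
    auto cublasHandle = std::make_shared<cublasHandle_t>();
    auto cublasLtHandle = std::make_shared<cublasLtHandle_t>();
    cublasCreate(cublasHandle.get());
    cublasLtCreate(cublasLtHandle.get());
    cublasSetStream(*cublasHandle, stream);

    // Scratch space for cublasLt; 32 MiB is an illustrative size, not one mandated by the header.
    void* workspace = nullptr;
    cudaMalloc(&workspace, 32 << 20);

    CublasMMWrapper wrapper(cublasHandle, cublasLtHandle, stream, workspace);
    wrapper.setFP16GemmConfig();                      // A/B/C dtypes set to FP16 per the declaration above
    wrapper.Gemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,   // simplest overload: default alpha = 1, beta = 0
        A, m, B, k, C, m);                            // lda/ldb/ldc assume column-major, non-transposed operands

    cudaFree(workspace);
    cublasLtDestroy(*cublasLtHandle);
    cublasDestroy(*cublasHandle);
}

The other Gemm overloads expose explicit alpha/beta, an optional cublasLt heuristic result, or a pre-selected cublasLtMatmulAlgo_t for tactic caching; the stridedBatchedGemm variants cover batched cases.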