Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com> Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -248,38 +248,6 @@ struct brgemm {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm<scalar_t, scalar_t, has_bias> {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
UNUSED(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
@@ -469,6 +437,46 @@ void fp8_scaled_mm_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
|
||||
Reference in New Issue
Block a user