init
This commit is contained in:
70
csrc_musa/quantization/gptq_marlin/gptq_marlin.muh
Normal file
70
csrc_musa/quantization/gptq_marlin/gptq_marlin.muh
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
#include <musa.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <musa_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
|
||||
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have
|
||||
// many registers per warp and small tiles.
|
||||
static constexpr int default_threads = 256;
|
||||
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace gptq_marlin
|
||||
Reference in New Issue
Block a user