2025-01-26 02:55:08 +08:00
|
|
|
#pragma once
|
2025-01-26 14:29:58 +08:00
|
|
|
|
2025-01-26 02:55:08 +08:00
|
|
|
#include <Python.h>

#include <torch/extension.h>

#include <limits>
#include <vector>
|
|
|
|
|
|
2025-01-26 02:55:08 +08:00
|
|
|
// Two-level token pasting: CONCAT expands its arguments first (so macro
// arguments like NAME are substituted), then _CONCAT performs the ## paste.
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
|
|
|
|
|
|
|
|
|
|
// Two-level stringification: STRINGIFY expands its argument first, then
// _STRINGIFY turns the expanded tokens into a string literal.
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
|
|
|
|
|
|
|
|
|
|
// Indirection so that NAME/MODULE are macro-expanded before being handed to
// TORCH_LIBRARY (which would otherwise use the unexpanded tokens verbatim).
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
|
|
|
|
|
|
|
|
|
// Defines the CPython module entry point PyInit_<NAME> that makes this
// extension importable from Python. The module itself is empty (no methods,
// m_size = 0); operators are exposed through torch library registrations
// rather than through the Python module object.
#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }
|
|
|
|
|
|
2024-12-30 18:07:01 +08:00
|
|
|
// trt_reduce
// Opaque handle type: native pointers (all-reduce context, device buffers)
// are round-tripped through Python as plain int64_t values, since torch
// extension schemas cannot carry raw pointers.
using fptr_t = int64_t;
|
2025-01-16 03:04:25 +08:00
|
|
|
// Creates the custom all-reduce context for this rank and returns an opaque
// handle (fptr_t) to it. `buffers`, `tmp_result_buffers`, `barrier_in` and
// `barrier_out` carry one pointer-sized value per rank, encoded as fptr_t.
// NOTE(review): these are presumably per-rank device/IPC buffer pointers —
// confirm against the kernel implementation.
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
                      const std::vector<fptr_t>& barrier_out);
|
2024-12-30 18:07:01 +08:00
|
|
|
// Destroys the custom all-reduce context identified by the handle `_fa`.
void dispose(fptr_t _fa);

// Runs an all-reduce on `inp` using the context `_fa`, writing the reduced
// result into `out`.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
2025-01-16 03:04:25 +08:00
|
|
|
// Returns the IPC metadata (handles, offsets) describing the graph buffers
// owned by context `_fa`, for exchange with the other ranks.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);

// Registers the graph buffers of all ranks with context `_fa`; `handles` and
// `offsets` hold one entry per rank, as produced by
// get_graph_buffer_ipc_meta on that rank.
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
|
2024-12-30 18:07:01 +08:00
|
|
|
|
|
|
|
|
// moe_align_block_size
// Builds MoE dispatch metadata from the per-token expert assignment in
// `topk_ids`: fills `sorted_token_ids`, `experts_ids` and
// `num_tokens_post_pad`. `token_cnts_buffer` and `cumsum_buffer` are
// caller-provided scratch workspaces.
// NOTE(review): the exact padded-grouping semantics (each expert's token
// range padded to a multiple of `block_size`) are inferred from parameter
// names — confirm against the kernel.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
|
|
|
|
|
2025-01-06 22:51:22 +08:00
|
|
|
// int8_scaled_mm
// Scaled int8 matrix multiply: returns mat_a @ mat_b dequantized with
// `scales_a` / `scales_b` and cast to `out_dtype`; `bias` is added when
// provided. NOTE(review): scale granularity (per-tensor vs per-channel) is
// not visible here — confirm against the kernel.
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);
|
|
|
|
|
|
2025-01-26 15:46:51 +08:00
|
|
|
// fp8_scaled_mm
// Scaled fp8 matrix multiply; same contract as int8_scaled_mm but for fp8
// operands: returns mat_a @ mat_b dequantized with `scales_a` / `scales_b`,
// cast to `out_dtype`, with optional `bias`.
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                            const c10::optional<torch::Tensor>& bias);
|
|
|
|
|
|
2025-01-23 15:29:20 +08:00
|
|
|
// lightning_attention_decode
// Single-step (decode) lightning attention: consumes q/k/v and the previous
// KV state `past_kv`, writing the attention result to `output` and the
// updated state to `new_kv`. NOTE(review): `slope` is presumably a per-head
// decay factor — confirm against the kernel.
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv);
|
|
|
|
|
|
2025-01-21 20:44:49 +08:00
|
|
|
// rms norm
// RMSNorm of `input` scaled by `weight` (epsilon `eps`), written to
// `output`. `cuda_stream` is the launch stream encoded as int64_t (torch
// extension schemas cannot carry raw stream handles).
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
|
|
|
|
|
2025-01-22 21:32:48 +08:00
|
|
|
// fused rms norm
// Fused residual-add + RMSNorm over `input`/`residual` with `weight` and
// epsilon `eps`. NOTE(review): appears to update input/residual in place
// (there is no output parameter) — confirm against the kernel.
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
2025-01-22 21:32:48 +08:00
|
|
|
|
|
|
|
|
// gemma rms norm
// Gemma-variant RMSNorm of `input` with `weight` and epsilon `eps`, written
// to `output`; launched on the stream encoded in `cuda_stream`.
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// fused gemma rms norm
// Fused residual-add + Gemma RMSNorm over `input`/`residual`; launched on
// the stream encoded in `cuda_stream`. NOTE(review): like
// sgl_fused_add_rmsnorm, presumably updates both tensors in place — confirm.
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
                             int64_t cuda_stream);
|
|
|
|
|
|
2025-01-22 23:25:45 +08:00
|
|
|
// silu and mul
// Gated activation: applies SiLU to one half of `input` and multiplies by
// the other half, writing to `out`; launched on `cuda_stream`.
// NOTE(review): the half-split gating layout is inferred from the op name —
// confirm against the kernel.
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// gelu tanh and mul
// Gated activation using the tanh approximation of GELU; same calling
// convention as silu_and_mul.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// gelu and mul
// Gated activation using exact GELU; same calling convention as
// silu_and_mul.
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
|
|
|
|
|
2025-01-23 00:39:38 +08:00
|
|
|
// bmm fp8
// Batched fp8 matmul via cuBLAS: computes D from A and B using the
// dequantization scales `A_scale` / `B_scale`, with `workspace_buffer` as
// cuBLAS workspace. `cublas_handle` and `cuda_stream` are raw native handles
// encoded as int64_t.
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
|
|
|
|
|
|
2025-01-24 01:54:47 +08:00
|
|
|
// min p sampling from probs
// Min-p sampling: draws token ids into `samples` from the rows of `probs`,
// consuming pre-generated `uniform_samples`. The per-row threshold comes
// from `maybe_min_p_arr` when set, otherwise the scalar `min_p_val`;
// `deterministic` selects the deterministic kernel variant.
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
                               int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// top k renorm probs
// Keeps only the top-k probabilities of each row of `probs` and
// renormalizes the survivors into `renorm_probs`; per-row k comes from
// `maybe_top_k_arr` when set, otherwise the scalar `top_k_val`.
// NOTE: flashinfer declares `top_k_val` as `unsigned int`, which a torch
// extension schema cannot express (it requires int64_t) — bindings must go
// through top_k_renorm_probs_wrapper instead of this declaration.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
                        unsigned int top_k_val, int64_t cuda_stream);
|
|
|
|
|
|
2025-01-26 02:55:08 +08:00
|
|
|
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
|
|
|
|
// wrapper for binding
|
|
|
|
|
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
|
|
|
|
|
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
|
|
|
|
|
int64_t cuda_stream) {
|
|
|
|
|
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
|
|
|
|
}
|
|
|
|
|
|
2025-01-24 01:54:47 +08:00
|
|
|
// top p renorm probs
// Keeps the smallest prefix of each row of `probs` (sorted by probability)
// whose mass reaches p, zeroes the rest, and renormalizes into
// `renorm_probs`; per-row p comes from `maybe_top_p_arr` when set, otherwise
// the scalar `top_p_val`.
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
                        double top_p_val, int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// top k top p sampling from probs
// Joint top-k + top-p sampling from the rows of `probs` using pre-generated
// `uniform_samples`; drawn ids go to `samples` and per-row success flags to
// `success`. The per-row optionals `maybe_top_k_arr` / `maybe_top_p_arr`
// override the scalar `top_k_val` / `top_p_val` when set. Note k is passed
// as double here (flashinfer's schema), not int64_t.
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                                     at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                     int64_t cuda_stream);
|
|
|
|
|
|
|
|
|
|
// top p sampling from probs
// Top-p (nucleus) sampling from the rows of `probs` using pre-generated
// `uniform_samples`; drawn ids go to `samples` and per-row success flags to
// `success`. Per-row p comes from `maybe_top_p_arr` when set, otherwise the
// scalar `top_p_val`.
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
                               std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                               int64_t cuda_stream);
|
2025-01-26 23:28:00 -08:00
|
|
|
|
|
|
|
|
// apply rope (rotary position embedding) using position ids and a
// precomputed cos/sin cache
// Rotates `q`/`k` using the cos/sin rows of `cos_sin_cache` selected by
// `pos_ids`, writing results to `q_rope`/`k_rope`. NOTE(review):
// `interleave` presumably selects interleaved vs half-rotated (NeoX-style)
// dimension layout — confirm against the kernel.
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
                                      at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
                                      int64_t cuda_stream);
|