Reland [1/2] Optimizations and refactors for the quant kernel (#10312)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
fzyzcjy
2025-10-11 15:59:03 +08:00
committed by GitHub
parent 129d299278
commit 21337b22b9
13 changed files with 1065 additions and 178 deletions

View File

@@ -219,7 +219,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
const torch::Dtype& out_dtype);
// Quantizes `input` into FP4 `output`, writing the computed scales to
// `output_scale`; `input_scale` is a read-only scale input.
// NOTE(review): exact FP4 layout/packing is defined by the kernel
// implementation elsewhere — not visible from this declaration.
void scaled_fp4_quant(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
void sgl_per_token_group_quant_fp8(
void sgl_per_token_group_quant_8bit(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
@@ -228,14 +228,17 @@ void sgl_per_token_group_quant_fp8(
double fp8_min,
double fp8_max,
bool scale_ue8m0);
void sgl_per_token_group_quant_int8(
// Per-token-group quantization to an 8-bit format (v2).
// Generalizes the earlier fp8/int8 variants: the target range is passed in as
// [min_8bit, max_8bit] instead of being baked into separate entry points.
//
// Parameters:
//   input             - source tensor to quantize.
//   output_q          - destination for the quantized values.
//   output_s          - destination for the per-group scales.
//   group_size        - number of elements per quantization group.
//   eps               - presumably a lower bound guarding against
//                       divide-by-zero when computing scales — confirm in the
//                       kernel implementation.
//   min_8bit/max_8bit - representable range of the target 8-bit type
//                       (e.g. fp8 or int8 limits).
//   scale_ue8m0       - when true, presumably encodes scales in UE8M0 format
//                       — TODO confirm against the kernel.
//   fuse_silu_and_mul - when true, presumably fuses a SiLU-and-mul activation
//                       into the quantization pass — confirm with callers.
//   masked_m          - optional per-batch valid-row counts for masked/ragged
//                       inputs; std::nullopt disables masking.
//                       NOTE(review): semantics assumed from the name —
//                       verify against the kernel.
void sgl_per_token_group_quant_8bit_v2(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double min_8bit,
double max_8bit,
bool scale_ue8m0,
bool fuse_silu_and_mul,
const std::optional<torch::Tensor>& masked_m);
// Per-tensor FP8 quantization: quantizes `input` into `output_q` with a single
// scale in `output_s`. When `is_static` is true, presumably `output_s` is a
// caller-provided static scale rather than computed dynamically — TODO confirm.
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
// Per-token FP8 quantization: quantizes `input` into `output_q`, writing one
// dynamically computed scale per token into `output_s`.
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
void bmm_fp8(