Support FP4 gemm (1/2) (#3899)

This commit is contained in:
Trevor Morris
2025-03-24 19:50:23 -07:00
committed by GitHub
parent 22c3702e1e
commit e9f8e42318
11 changed files with 1245 additions and 5 deletions

View File

@@ -113,6 +113,13 @@ void apply_rope_pos_ids_cos_sin_cache(
* From csrc/gemm
*/
// Dequantizes AWQ-packed weights back to a dense tensor.
// NOTE(review): presumably `qweight` holds packed 4-bit weights with per-group
// `scales` and packed zero-points `qzeros` (standard AWQ layout) — confirm
// against the kernel implementation in csrc/gemm.
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
// FP4 GEMM via CUTLASS: computes a scaled matrix product of A and B and
// writes the result into D in place (D is the only non-const parameter).
//
// @param D      output tensor, written in place.
// @param A      left-hand input matrix.
// @param B      right-hand input matrix.
// @param A_sf   scale factors for A (per-block, presumably — confirm layout
//               against the kernel in csrc/gemm).
// @param B_sf   scale factors for B (same caveat as A_sf).
// @param alpha  scalar multiplier applied to the product.
// NOTE(review): FP4 element format (e.g. NVFP4 e2m1 with block scales) is not
// visible from this declaration — verify against the CUTLASS kernel.
void cutlass_scaled_fp4_mm(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha);
torch::Tensor int8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
@@ -133,6 +140,8 @@ torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype);
// Quantizes `input` to FP4, writing the quantized values into `output` and the
// computed scale factors into `output_scale` (both are non-const refs, i.e.
// outputs); `input` and `input_scale` are read-only.
// NOTE(review): the exact role of `input_scale` (e.g. a precomputed global
// scale) is not visible from this declaration — confirm in csrc/gemm.
void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
void sgl_per_token_group_quant_fp8(
at::Tensor input,
at::Tensor output_q,