Support FP4 gemm (1/2) (#3899)
This commit is contained in:
@@ -113,6 +113,13 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
* From csrc/gemm
|
||||
*/
|
||||
// Declaration only; the definition is not in this header (the section comment
// above points at csrc/gemm). Takes the three tensors by value and returns a
// new torch::Tensor.
// NOTE(review): presumably qweight/scales/qzeros are the packed AWQ weights,
// their scales and zero points, and the result is the dequantized weight —
// confirm against the implementation.
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
|
||||
// Declaration only; added by the "Support FP4 gemm" change. Returns void and
// takes D by non-const reference, so the result is presumably written into D
// in place — TODO confirm against the definition.
// A, B, A_sf, B_sf and alpha are const references (read-only to this call).
// NOTE(review): by the names, A_sf/B_sf look like the scale factors for A/B
// and alpha a global scale applied to the product; shapes, dtypes and layout
// requirements are not visible here — verify before relying on them.
void cutlass_scaled_fp4_mm(
|
||||
torch::Tensor& D,
|
||||
torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
torch::Tensor int8_scaled_mm(
|
||||
const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b,
|
||||
@@ -133,6 +140,8 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype);
|
||||
// Declaration only; added by the "Support FP4 gemm" change. Returns void;
// output and output_scale are non-const references, so they are presumably
// the buffers this call fills — TODO confirm against the definition.
// input and input_scale are const references (read-only to this call).
// NOTE(review): expected shapes/dtypes of the four tensors are not visible
// in this header — verify against the kernel before relying on them.
void scaled_fp4_quant(
|
||||
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
|
||||
void sgl_per_token_group_quant_fp8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
|
||||
Reference in New Issue
Block a user