This commit is contained in:
@@ -207,17 +207,23 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
const torch::Dtype& out_dtype);
|
||||
// Quantize `input` into FP4, writing quantized values to `output` and
// block scales to `output_scale`; `input_scale` supplies a global input scale.
// NOTE(review): semantics inferred from parameter names — confirm against the
// kernel implementation.
void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
|
||||
void sgl_per_token_group_quant_8bit(
|
||||
void sgl_per_token_group_quant_fp8(
|
||||
at::Tensor input,
|
||||
at::Tensor output_q,
|
||||
at::Tensor output_s,
|
||||
int64_t group_size,
|
||||
double eps,
|
||||
double min_8bit,
|
||||
double max_8bit,
|
||||
bool scale_ue8m0,
|
||||
bool fuse_silu_and_mul,
|
||||
const std::optional<torch::Tensor>& masked_m);
|
||||
double fp8_min,
|
||||
double fp8_max,
|
||||
bool scale_ue8m0);
|
||||
// Per-token-group INT8 quantization: quantizes `input` in groups of
// `group_size` into `output_q`, with one scale per group written to
// `output_s`. `eps` presumably bounds the group amax; `int8_min`/`int8_max`
// give the clamp range. NOTE(review): semantics inferred from names — confirm
// against the kernel implementation.
void sgl_per_token_group_quant_int8(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double int8_min,
    double int8_max);
|
||||
// Per-tensor FP8 quantization: quantizes `input` into `output_q` with a single
// scale in `output_s`. `is_static` presumably selects a precomputed (static)
// scale over a dynamically computed one — confirm against the implementation.
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
||||
// Per-token FP8 quantization: quantizes `input` into `output_q` with one scale
// per token written to `output_s`. NOTE(review): granularity inferred from the
// name — confirm against the kernel implementation.
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||
void bmm_fp8(
|
||||
|
||||
Reference in New Issue
Block a user