use default for torch.ops (#4835)
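This change switches every wrapper from calling the `OpOverloadPacket` (`torch.ops.sgl_kernel.foo(...)`) to the concrete default overload (`torch.ops.sgl_kernel.foo.default(...)`). In PyTorch, `torch.ops.ns.op` is an `OpOverloadPacket` that resolves the matching overload on every call, while `.default` is the already-resolved `OpOverload`, so calling it skips that per-call resolution and hands tooling such as `torch.compile` a single concrete overload. Below is a minimal sketch of the pattern using a hypothetical `demo::scale` op registered from Python purely for illustration (the real `sgl_kernel` ops are registered from C++):

```python
import torch

# Hypothetical toy op, for illustration only (requires PyTorch >= 2.4).
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

x = torch.ones(4)

# OpOverloadPacket: overload resolution happens on every call.
y1 = torch.ops.demo.scale(x, 2.0)

# OpOverload: the resolved default overload is invoked directly.
y2 = torch.ops.demo.scale.default(x, 2.0)

assert torch.equal(y1, y2)
```

Both forms compute the same result; `.default` is simply the direct handle to the sole registered overload.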
@@ -7,11 +7,11 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 def awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
 
 
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
 
 
 def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
 
 
 def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
         mat_a,
         mat_b,
         scales_a,
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
     B_scale: torch.Tensor,
 ) -> None:
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
         A,
         B,
         D,
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
     fp8_min: float,
     fp8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
         input, output_q, output_s, group_size, eps, fp8_min, fp8_max
     )
 
@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
     int8_min: float,
     int8_max: float,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
         input, output_q, output_s, group_size, eps, int8_min, int8_max
     )
 
@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
     output_s: torch.Tensor,
     is_static: bool,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )
 
 
 def cublas_grouped_gemm(
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
         len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
     ), "Inputs/weights/outputs should not be empty!"
     cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
         inputs,
         weights,
         outputs,
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
     output_q: torch.Tensor,
     output_s: torch.Tensor,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
 
 
 def cutlass_scaled_fp4_mm(
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
     assert a.ndim == 2 and b.ndim == 2
     m, n = a.shape[0], b.shape[0]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
         out, a, b, block_scale_a, block_scale_b, alpha
     )
     return out
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
 
-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
         output, input, output_scale, input_global_scale
     )
     output_scale = output_scale.view(torch.float8_e4m3fn)