fix: remove cublas_grouped_gemm (#5307)
@@ -25,7 +25,6 @@ from sgl_kernel.elementwise import (
 from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
-    cublas_grouped_gemm,
     cutlass_scaled_fp4_mm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
@@ -121,26 +121,6 @@ def sgl_per_tensor_quant_fp8(
     )
 
 
-def cublas_grouped_gemm(
-    inputs: List[torch.Tensor],
-    weights: List[torch.Tensor],
-    outputs: List[torch.Tensor],
-    out_dtype: torch.dtype,
-) -> None:
-    assert (
-        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
-    ), "Inputs/weights/outputs should not be empty!"
-    cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
-        inputs,
-        weights,
-        outputs,
-        out_dtype,
-        cublas_handle,
-        get_cuda_stream(),
-    )
-
-
 def sgl_per_token_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
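Note for downstream users (not part of the patch itself): the commit deletes the Python wrapper without pointing callers at a replacement. As a rough illustration only, code that still imported cublas_grouped_gemm could shim it with one torch.matmul per input/weight pair. This is a minimal sketch under stated assumptions: the name grouped_gemm_fallback is hypothetical, and the [N, K] weight layout (so each product is x @ w.T) is an assumption, not something this commit specifies.

from typing import List

import torch


def grouped_gemm_fallback(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    # Same signature as the removed wrapper, but plain per-pair matmuls
    # instead of the batched cuBLAS grouped-GEMM kernel.
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    for x, w, out in zip(inputs, weights, outputs):
        # Assumption: w is stored [N, K] as in a linear layer; drop the
        # .t() if call sites pass pre-transposed weights.
        out.copy_(torch.matmul(x, w.t()).to(out_dtype))

The loop trades the removed kernel's single grouped launch for len(inputs) separate GEMMs, so it preserves semantics (given the layout assumption above) but not the performance of the cuBLAS grouped path.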