fix: remove cublas_grouped_gemm (#5307)
This commit is contained in:
@@ -121,26 +121,6 @@ def sgl_per_tensor_quant_fp8(
|
||||
)
|
||||
|
||||
|
||||
def cublas_grouped_gemm(
|
||||
inputs: List[torch.Tensor],
|
||||
weights: List[torch.Tensor],
|
||||
outputs: List[torch.Tensor],
|
||||
out_dtype: torch.dtype,
|
||||
) -> None:
|
||||
assert (
|
||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
||||
), "Inputs/weights/outputs should not be empty!"
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernel.cublas_grouped_gemm.default(
|
||||
inputs,
|
||||
weights,
|
||||
outputs,
|
||||
out_dtype,
|
||||
cublas_handle,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user