[1/2] Support Qserve (#6457)
Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -36,6 +36,8 @@ from sgl_kernel.gemm import (
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
|
||||
@@ -197,3 +197,47 @@ def scaled_fp4_quant(
|
||||
)
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Invoke the QServe W4A8 per-channel GEMM custom op.

    All tensor arguments are forwarded unchanged, in order, to
    ``torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm``; the output buffer is
    passed as the last argument and then returned.

    Args:
        in_feats: Input tensor; its first dimension and device determine
            the first dimension and device of a newly allocated output.
        kernel: Weight tensor; its first dimension determines the second
            dimension of a newly allocated output.
        wscales: Forwarded to the op (presumably weight scales -- confirm
            against the kernel implementation).
        ascales: Forwarded to the op (presumably activation scales --
            confirm against the kernel implementation).
        w_szs: Forwarded to the op (meaning not visible here; verify
            against the kernel implementation).
        a_ssums: Forwarded to the op (meaning not visible here; verify
            against the kernel implementation).
        out_feats: Optional caller-provided output buffer. When ``None``,
            an uninitialized ``torch.float16`` tensor of shape
            ``(in_feats.shape[0], kernel.shape[0])`` is created on
            ``in_feats.device``.

    Returns:
        The output buffer handed to the op (either the caller's
        ``out_feats`` or the freshly allocated one).
    """
    result = out_feats
    if result is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        out_shape = (in_feats.shape[0], kernel.shape[0])
        result = torch.empty(
            out_shape,
            dtype=torch.float16,
            device=in_feats.device,
        )

    gemm_op = torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default
    gemm_op(in_feats, kernel, wscales, ascales, w_szs, a_ssums, result)
    return result
|
||||
|
||||
|
||||
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Invoke the QServe W4A8 per-group GEMM custom op.

    All tensor arguments are forwarded unchanged, in order, to
    ``torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm``; the output buffer
    is passed as the last argument and then returned.

    Args:
        in_feats: Input tensor; its first dimension and device determine
            the first dimension and device of a newly allocated output.
        kernel: Weight tensor; its first dimension determines the second
            dimension of a newly allocated output.
        zeros: Forwarded to the op (presumably quantization zero points --
            confirm against the kernel implementation).
        scales_i8: Forwarded to the op (presumably int8 group scales --
            confirm against the kernel implementation).
        wscales: Forwarded to the op (presumably weight scales -- confirm
            against the kernel implementation).
        ascales: Forwarded to the op (presumably activation scales --
            confirm against the kernel implementation).
        out_feats: Optional caller-provided output buffer. When ``None``,
            an uninitialized ``torch.float16`` tensor of shape
            ``(in_feats.shape[0], kernel.shape[0])`` is created on
            ``in_feats.device``.

    Returns:
        The output buffer handed to the op (either the caller's
        ``out_feats`` or the freshly allocated one).
    """
    result = out_feats
    if result is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        out_shape = (in_feats.shape[0], kernel.shape[0])
        result = torch.empty(
            out_shape,
            dtype=torch.float16,
            device=in_feats.device,
        )

    gemm_op = torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default
    gemm_op(in_feats, kernel, zeros, scales_i8, wscales, ascales, result)
    return result
|
||||
|
||||
Reference in New Issue
Block a user