[CPU] Fix build issue (#6419)

This commit is contained in:
blzheng
2025-05-22 02:17:10 +08:00
committed by GitHub
parent d4c038daed
commit cfe48c5902
14 changed files with 157 additions and 143 deletions

View File

@@ -1,18 +1,10 @@
import itertools
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import (
convert_weight_packed,
fp8_scaled_mm_cpu,
int8_scaled_mm_cpu,
int8_scaled_mm_with_quant,
per_token_quant_int8_cpu,
weight_packed_linear,
)
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
ref = ref.bfloat16()
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
out = torch.ops.sgl_kernel.weight_packed_linear(
mat1, mat2, bias if has_bias else None, False
)
packed_mat2 = convert_weight_packed(mat2)
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
out2 = torch.ops.sgl_kernel.weight_packed_linear(
mat1, packed_mat2, bias if has_bias else None, True
)
atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
atol = rtol = precision[ref_out.dtype]
Aq2, As2 = per_token_quant_int8_cpu(A)
out = int8_scaled_mm_cpu(
Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
# test the fused version
fused_out = int8_scaled_mm_with_quant(
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
)
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
if prepack:
fp8_weight = convert_weight_packed(fp8_weight)
fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
opt = fp8_scaled_mm_cpu(
opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
data,
fp8_weight,
scales,

View File

@@ -2,12 +2,10 @@ import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import convert_weight_packed
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
from utils import (
BLOCK_K,
BLOCK_N,
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res = shared_expert(
res = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states,
w1,
w2,
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res2 = shared_expert(
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states2,
w1_q,
w2_q,
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
ref_out = shared_out + fused_out.float() * routed_scaling_factor
ref_out = ref_out.to(dtype=dtype)
w1 = convert_weight_packed(w1) # [2N, K]
w2 = convert_weight_packed(w2) # [K, N]
out = shared_expert(
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
out = torch.ops.sgl_kernel.shared_expert_cpu(
a2,
w1,
w2,