[CPU] Fix build issue (#6419)
This commit is contained in:
@@ -1,18 +1,10 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import (
|
||||
convert_weight_packed,
|
||||
fp8_scaled_mm_cpu,
|
||||
int8_scaled_mm_cpu,
|
||||
int8_scaled_mm_with_quant,
|
||||
per_token_quant_int8_cpu,
|
||||
weight_packed_linear,
|
||||
)
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
|
||||
|
||||
ref = ref.bfloat16()
|
||||
|
||||
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
|
||||
out = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, mat2, bias if has_bias else None, False
|
||||
)
|
||||
|
||||
packed_mat2 = convert_weight_packed(mat2)
|
||||
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
|
||||
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
|
||||
out2 = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, packed_mat2, bias if has_bias else None, True
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref.dtype]
|
||||
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
|
||||
@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
|
||||
Aq2, As2 = per_token_quant_int8_cpu(A)
|
||||
out = int8_scaled_mm_cpu(
|
||||
Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
|
||||
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
|
||||
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
# test the fused version
|
||||
fused_out = int8_scaled_mm_with_quant(
|
||||
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
||||
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
|
||||
@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
|
||||
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
|
||||
|
||||
if prepack:
|
||||
fp8_weight = convert_weight_packed(fp8_weight)
|
||||
fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
|
||||
|
||||
opt = fp8_scaled_mm_cpu(
|
||||
opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
|
||||
data,
|
||||
fp8_weight,
|
||||
scales,
|
||||
|
||||
@@ -2,12 +2,10 @@ import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import convert_weight_packed
|
||||
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res = shared_expert(
|
||||
res = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res2 = shared_expert(
|
||||
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states2,
|
||||
w1_q,
|
||||
w2_q,
|
||||
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
|
||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
|
||||
w1 = convert_weight_packed(w1) # [2N, K]
|
||||
w2 = convert_weight_packed(w2) # [K, N]
|
||||
out = shared_expert(
|
||||
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
|
||||
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
|
||||
out = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
a2,
|
||||
w1,
|
||||
w2,
|
||||
|
||||
Reference in New Issue
Block a user