[CPU] Fix build issue (#6419)

This commit is contained in:
blzheng
2025-05-22 02:17:10 +08:00
committed by GitHub
parent d4c038daed
commit cfe48c5902
14 changed files with 157 additions and 143 deletions

View File

@@ -2,12 +2,10 @@ import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import convert_weight_packed
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
from utils import (
BLOCK_K,
BLOCK_N,
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res = shared_expert(
res = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states,
w1,
w2,
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(),
routed_scaling_factor,
).to(dtype=dtype)
res2 = shared_expert(
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states2,
w1_q,
w2_q,
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
ref_out = shared_out + fused_out.float() * routed_scaling_factor
ref_out = ref_out.to(dtype=dtype)
w1 = convert_weight_packed(w1) # [2N, K]
w2 = convert_weight_packed(w2) # [K, N]
out = shared_expert(
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
out = torch.ops.sgl_kernel.shared_expert_cpu(
a2,
w1,
w2,