[CPU] Fix build issue (#6419)
This commit is contained in:
@@ -2,12 +2,10 @@ import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import convert_weight_packed
|
||||
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res = shared_expert(
|
||||
res = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res2 = shared_expert(
|
||||
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states2,
|
||||
w1_q,
|
||||
w2_q,
|
||||
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
|
||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
|
||||
w1 = convert_weight_packed(w1) # [2N, K]
|
||||
w2 = convert_weight_packed(w2) # [K, N]
|
||||
out = shared_expert(
|
||||
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
|
||||
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
|
||||
out = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
a2,
|
||||
w1,
|
||||
w2,
|
||||
|
||||
Reference in New Issue
Block a user