CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase):
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
|
||||
hidden_states, gating_output, topk, renormalize, G, topk_group
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase):
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
|
||||
Reference in New Issue
Block a user