CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
from utils import make_non_contiguous, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase):
|
||||
def _norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
x = make_non_contiguous(x)
|
||||
hidden_size = x.size(-1)
|
||||
weight = torch.randn(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase):
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
ref_x = x.clone()
|
||||
residual = torch.randn([m, n], dtype=dtype)
|
||||
residual = torch.randn([m, hidden_size], dtype=dtype)
|
||||
ref_residual = residual.clone()
|
||||
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
|
||||
|
||||
Reference in New Issue
Block a user