adapt to dsv32 on dcu

This commit is contained in:
maxiao
2025-09-30 18:37:31 +08:00
parent 8f7453e3af
commit 852a49c5cc
159 changed files with 7211 additions and 7782 deletions

View File

@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
assert (
x.shape == residual.shape and x.dtype == residual.dtype
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune: