[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)

Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
Author: chanzhennan
Date: 2026-02-28 11:15:50 +08:00 (committed by GitHub)
Parent: 153093d3b3
Commit: 82544aa0cc
17 changed files with 2668 additions and 1532 deletions


@@ -407,7 +407,7 @@ def add_rmsnorm(
 ) -> None:
     kunlun_ops.add_rmsnorm(
         x,
-        y,  # originally written as residual; this is actually y
+        y,
         residual_output=residual_output,
         weight=weight,
         eps=eps,
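# Illustration (not part of this commit): a minimal pure-PyTorch sketch of
# what the fused kernel is assumed to compute, which is why the second
# positional argument is the residual input `y` rather than a tensor named
# `residual`. The function name and return convention here are hypothetical.
import torch

def add_rmsnorm_reference(x, y, weight, eps=1e-5):
    # Residual add first; the fused op is assumed to also write this sum
    # out through its residual_output argument.
    residual = x + y
    # RMSNorm over the last dimension, computed in float32 for stability.
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return normed.to(residual.dtype) * weight, residual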
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
add_rmsnorm.register_fake(_fake_add_rmsnorm)


@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
def gemma_add_rmsnorm(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    force_sdnn: bool = False,
) -> None:
    # print("gemma_add_rmsnorm wrapper")
    kunlun_ops.gemma_add_rmsnorm(
        x,
        y,
        weight=weight,
        output=output,
        eps=eps,
        enable_pdl=enable_pdl,
        interweaved=interweaved,
        store_output_before_norm=store_output_before_norm,
        bias=bias,
        smooth=smooth,
        residual_output=residual_output,
        force_sdnn=force_sdnn,
    )
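# Illustration (not part of this commit): once registered through
# torch.library.custom_op, the wrapper above is reachable through the
# dispatcher under the `_C` namespace. Tensor shapes below are hypothetical,
# and the real kernel needs Kunlun hardware, so this is left commented:
#
#     out = torch.empty_like(x)
#     torch.ops._C.gemma_add_rmsnorm(x, y, weight, out)
#
# The @impl("_C::gemma_add_rmsnorm", "CUDA") registration that follows binds
# the same kunlun_ops kernel for tensors on CUDA devices.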
@impl("_C::gemma_add_rmsnorm", "CUDA")
def gemma_add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm_cuda wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
def _fake_gemma_add_rmsnorm(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    force_sdnn: bool = False,
):
    # Record output metadata only; no real computation during fake tracing.
    output.fake_shape = x.shape
    output.fake_dtype = x.dtype
    return None


gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
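# Illustration (not part of this commit): the register_fake call above is
# what lets torch.compile trace the op without running the real kernel. A
# minimal, self-contained sketch of the same three-part pattern, with a
# hypothetical op name ("demo::scale"):
import torch
from torch.library import custom_op

@custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@scale.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Only shape/dtype propagation; no real computation happens here.
    return torch.empty_like(x)

# During tracing the fake kernel supplies output metadata; at runtime the
# real kernel executes.
compiled = torch.compile(lambda t: torch.ops.demo.scale(t, 2.0))
out = compiled(torch.randn(4))
# Note: an out-variant op that writes into a preallocated `output` would
# typically declare it, e.g. mutates_args={"output"}, so torch.compile is
# aware of the mutation.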
@custom_op("_C::gemma_rmsnorm", mutates_args=())
def gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
@impl("_C::gemma_rmsnorm", "CUDA")
def gemma_rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm_cuda wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
def _fake_gemma_rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweave: bool = False,
    bias: torch.Tensor = None,
    force_sdnn: bool = False,
):
    # Set shape/dtype only; return nothing.
    output.fake_shape = x.shape
    output.fake_dtype = x.dtype
    return None


gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
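# Illustration (not part of this commit): "gemma"-style RMSNorm variants
# commonly scale by (1 + weight) rather than weight, matching the Gemma
# checkpoints' zero-centered gamma. A minimal sketch of that assumed
# reference semantics (the function name is hypothetical):
import torch

def gemma_rmsnorm_reference(x, weight, eps=1e-5):
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps)
    # Gemma applies the learned scale as (1 + weight).
    return (normed * (1.0 + weight.float())).to(x.dtype)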
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
    q_emb: torch.Tensor,