[Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
@@ -407,7 +407,7 @@ def add_rmsnorm(
 ) -> None:
     kunlun_ops.add_rmsnorm(
         x,
-        y,  # was written as residual; this is actually y
+        y,
         residual_output=residual_output,
         weight=weight,
         eps=eps,
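Both hunks follow PyTorch's torch.library custom-op pattern: an eager wrapper registered under the `_C` namespace, plus a fake (meta) implementation so torch.compile can trace shapes without running the kernel. A minimal self-contained sketch of that pattern follows; the demo namespace and op name are illustrative, not part of this commit. Note also that PyTorch expects arguments an op writes into (such as `output` here) to be listed in `mutates_args`, as the sketch does for its out-buffer.

import torch
from torch.library import custom_op


@custom_op("_C_demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, out: torch.Tensor, alpha: float = 2.0) -> None:
    # Eager implementation: write into the caller's preallocated buffer.
    out.copy_(x * alpha)


@scale_into.register_fake
def _(x: torch.Tensor, out: torch.Tensor, alpha: float = 2.0) -> None:
    # Fake (meta) implementation: nothing to allocate or compute, because the
    # op only mutates `out`, whose shape/dtype the caller already fixed.
    return None


x = torch.randn(4)
out = torch.empty_like(x)
scale_into(x, out, alpha=3.0)  # dispatches through torch.library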
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
 add_rmsnorm.register_fake(_fake_add_rmsnorm)
 
 
+@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
+def gemma_add_rmsnorm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_add_rmsnorm wrapper")
+    kunlun_ops.gemma_add_rmsnorm(
+        x,
+        y,
+        weight=weight,
+        output=output,
+        eps=eps,
+        enable_pdl=enable_pdl,
+        interweaved=interweaved,
+        store_output_before_norm=store_output_before_norm,
+        bias=bias,
+        smooth=smooth,
+        residual_output=residual_output,
+        force_sdnn=force_sdnn,
+    )
+
+
+@impl("_C::gemma_add_rmsnorm", "CUDA")
+def gemma_add_rmsnorm_cuda(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_add_rmsnorm_cuda wrapper")
+    kunlun_ops.gemma_add_rmsnorm(
+        x,
+        y,
+        weight=weight,
+        output=output,
+        eps=eps,
+        enable_pdl=enable_pdl,
+        interweaved=interweaved,
+        store_output_before_norm=store_output_before_norm,
+        bias=bias,
+        smooth=smooth,
+        residual_output=residual_output,
+        force_sdnn=force_sdnn,
+    )
+
+
+def _fake_gemma_add_rmsnorm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+):
+    output.fake_shape = x.shape
+    output.fake_dtype = x.dtype
+    return None
+
+
+gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
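For orientation, a plain-PyTorch reference of what a fused gemma_add_rmsnorm is conventionally assumed to compute. The (1 + weight) scaling and the pre-norm residual output are assumptions based on the op name and the usual Gemma convention; the Kunlun kernel source is not shown in this commit.

import torch


def gemma_add_rmsnorm_ref(
    x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor]:
    residual = x + y  # the fused op can also store this pre-norm sum
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    # Gemma-style RMSNorm scales by (1 + weight) instead of plain weight.
    return (normed * (1.0 + weight.float())).to(x.dtype), residual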
@custom_op("_C::gemma_rmsnorm", mutates_args=())
|
||||
def gemma_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_rmsnorm wrapper")
|
||||
kunlun_ops.gemma_rmsnorm(
|
||||
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||
)
|
||||
|
||||
|
||||
@impl("_C::gemma_rmsnorm", "CUDA")
|
||||
def gemma_rmsnorm_cuda(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
) -> None:
|
||||
# print("gemma_rmsnorm_cuda wrapper")
|
||||
kunlun_ops.gemma_rmsnorm(
|
||||
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
|
||||
)
|
||||
|
||||
|
||||
def _fake_gemma_rmsnorm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
enable_pdl: bool = False,
|
||||
interweave: bool = False,
|
||||
bias: torch.Tensor = None,
|
||||
force_sdnn: bool = False,
|
||||
):
|
||||
# 设置 shape/dtype,但不返回值
|
||||
output.fake_shape = x.shape
|
||||
output.fake_dtype = x.dtype
|
||||
return None
|
||||
|
||||
|
||||
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
|
||||
|
||||
|
||||
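Since these ops return None and write into a caller-provided `output`, a call site allocates first. A hedged usage sketch; the shapes, dtype, device, and hidden size below are illustrative, not taken from this commit.

import torch

hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
out = torch.empty_like(hidden)  # preallocated buffer the op mutates
gemma_rmsnorm(hidden, gamma, out, eps=1e-6)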
@custom_op("_C::split_norm_rope_neox", mutates_args=())
|
||||
def split_norm_rope_neox(
|
||||
q_emb: torch.Tensor,
|
||||
|
||||
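The diff is truncated at the start of `split_norm_rope_neox`. Judging by the name alone, it fuses a QKV split, normalization, and NeoX-style rotary embedding; that reading is an assumption, since the rest of the signature is not shown. For reference, a minimal unfused sketch of the NeoX (half-rotation) RoPE layout:

import torch


def rope_neox_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., head_dim]; cos/sin broadcastable to [..., head_dim // 2].
    # NeoX style rotates the two contiguous halves of the head dimension,
    # unlike the GPT-J style, which interleaves even/odd channels.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)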