[feature] add_rms_norm support bias (#5790)

### What this PR does / why we need it?
This PR is to replace addRmsNorm and Add With addRmsNormBias. This way
can lead to a more effecient result.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Full Test Pass

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
This commit is contained in:
yjmyl
2026-01-23 21:09:54 +08:00
committed by GitHub
parent 6c73b88dd6
commit e90b14140b
24 changed files with 3537 additions and 13 deletions

View File

@@ -403,6 +403,37 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> moe_gating_top_k_meta(
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
}
std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
const at::Tensor& x1,
const at::Tensor& x2,
const at::Tensor& gamma,
const c10::optional<at::Tensor> &beta,
double epsilon)
{
int64_t dim_x = x1.dim();
int64_t dim_gamma = gamma.dim();
int64_t diff = dim_x - dim_gamma;
c10::SymDimVector new_shape;
at::Tensor rstd;
if (diff > 0) {
new_shape.reserve(dim_x);
auto x1_sizes = x1.sym_sizes();
for (int64_t i = 0; i < diff; ++i) {
new_shape.push_back(x1_sizes[i]);
}
for (int64_t i = 0; i < dim_gamma; ++i) {
new_shape.push_back(c10::SymInt(1));
}
} else {
new_shape.assign(dim_x, c10::SymInt(1));
}
rstd = at::empty_symint(new_shape, x1.options().dtype(at::kFloat));
at::Tensor y = at::empty_symint(x1.sym_sizes(), x1.options());
at::Tensor x = at::empty_symint(x1.sym_sizes(), x1.options());
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
}
} // namespace meta
} // namespace vllm_ascend
@@ -441,5 +472,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta);
// Moe_gating_top_k
ops.impl("moe_gating_top_k", &vllm_ascend::meta::moe_gating_top_k_meta);
// Add_Rms_Norm_Bias
ops.impl("npu_add_rms_norm_bias", &vllm_ascend::meta::npu_add_rms_norm_bias_meta);
}
}