[feature] add_rms_norm support bias (#5790)
### What this PR does / why we need it?
This PR is to replace addRmsNorm and Add With addRmsNormBias. This way
can lead to a more effecient result.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Full Test Pass
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
This commit is contained in:
@@ -1288,6 +1288,38 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> moe_gating_top_k(
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y,expert_idx,out);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias(
|
||||
const at::Tensor& x1,
|
||||
const at::Tensor& x2,
|
||||
const at::Tensor& gamma,
|
||||
const c10::optional<at::Tensor> &beta,
|
||||
double epsilon)
|
||||
{
|
||||
int64_t dim_x = x1.dim();
|
||||
int64_t dim_gamma = gamma.dim();
|
||||
int64_t diff = dim_x - dim_gamma;
|
||||
std::vector<int64_t> new_shape;
|
||||
at::Tensor rstd;
|
||||
|
||||
if (diff > 0) {
|
||||
new_shape.reserve(dim_x);
|
||||
auto x1_sizes = x1.sizes();
|
||||
for (int64_t i = 0; i < diff; ++i) {
|
||||
new_shape.push_back(x1_sizes[i]);
|
||||
}
|
||||
for (int64_t i = 0; i < dim_gamma; ++i) {
|
||||
new_shape.push_back(1);
|
||||
}
|
||||
} else {
|
||||
new_shape.assign(dim_x, 1);
|
||||
}
|
||||
rstd = at::empty(new_shape, x1.options().dtype(at::kFloat));
|
||||
at::Tensor y = at::empty(x1.sizes(), x1.options());
|
||||
at::Tensor x = at::empty(x1.sizes(), x1.options());
|
||||
EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x);
|
||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
@@ -1453,4 +1485,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
"-> (Tensor y ,Tensor expert_idx, Tensor out)"
|
||||
);
|
||||
ops.impl("moe_gating_top_k", torch::kPrivateUse1,&vllm_ascend::moe_gating_top_k);
|
||||
|
||||
ops.def(
|
||||
"npu_add_rms_norm_bias(Tensor x1, "
|
||||
"Tensor x2, "
|
||||
"Tensor gamma, "
|
||||
"Tensor? beta=None, "
|
||||
"float epsilon=1e-6)"
|
||||
"-> (Tensor y ,Tensor rstd, Tensor x)"
|
||||
);
|
||||
ops.impl("npu_add_rms_norm_bias", torch::kPrivateUse1, &vllm_ascend::npu_add_rms_norm_bias);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user