Refactor the ops PyTorch adapter and clean up csrc/torch_binding.cpp (#6732)

### What this PR does / why we need it?
Refactor the ops PyTorch adapter and clean up csrc/torch_binding.cpp. For
more details, see
https://github.com/vllm-project/vllm-ascend/issues/6486
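
The cleanup moves each op's adapter out of csrc/torch_binding.cpp into its own
header, leaving the binding file with schema definitions and dispatch
registrations only. A minimal sketch of that pattern, assuming the usual torch
library macros (the `_C` library name, schema string, and header path here are
illustrative, not the PR's exact code):

```cpp
// Sketch only: library name, schema text, and include path are assumptions.
#include <torch/library.h>

#include "add_rms_norm_bias.h"  // per-op adapter header (path assumed)

// Declare the op schema once ...
TORCH_LIBRARY_FRAGMENT(_C, m) {
    m.def(
        "npu_add_rms_norm_bias(Tensor x1, Tensor x2, Tensor gamma, "
        "Tensor? beta, float epsilon) -> (Tensor, Tensor, Tensor)");
}

// ... and bind the NPU (PrivateUse1) implementation to it.
TORCH_LIBRARY_IMPL(_C, PrivateUse1, m) {
    m.impl("npu_add_rms_norm_bias", &vllm_ascend::npu_add_rms_norm_bias);
}
```

With this split, adding an op touches one new header plus two registration
lines in torch_binding.cpp.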

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Installed the new package to test the modification; here is the
result:
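
For illustration, a smoke test along these lines can exercise the new adapter
end to end (a hedged sketch, not the PR's actual test run: the header path,
`at::kPrivateUse1` device routing via torch_npu, and the shape checks are
assumptions on my part):

```cpp
// Hypothetical smoke test for the new adapter; paths/devices are assumptions.
#include <ATen/ATen.h>

#include "add_rms_norm_bias.h"

int main() {
    // torch_npu exposes Ascend devices through the PrivateUse1 backend.
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kPrivateUse1);
    at::Tensor x1 = at::randn({4, 16, 128}, opts);
    at::Tensor x2 = at::randn({4, 16, 128}, opts);
    at::Tensor gamma = at::ones({128}, opts);
    c10::optional<at::Tensor> beta = at::zeros({128}, opts);

    auto [y, rstd, x] = vllm_ascend::npu_add_rms_norm_bias(x1, x2, gamma,
                                                           beta, 1e-6);

    // Per the adapter's shape logic: y and x match x1, while rstd collapses
    // the gamma dims to 1 -> [4, 16, 1].
    TORCH_CHECK(y.sizes().equals({4, 16, 128}));
    TORCH_CHECK(x.sizes().equals({4, 16, 128}));
    TORCH_CHECK(rstd.sizes().equals({4, 16, 1}));
    return 0;
}
```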


- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: luomin2005 <luomin2005@huawei.com>
Co-authored-by: liziyu <56102866+liziyu179@users.noreply.github.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Author: luomin2005
Date: 2026-02-24 09:12:43 +08:00
Committed by: GitHub
Parent: f0caeeadcb
Commit: f41eeeb11e
15 changed files with 1037 additions and 735 deletions


@@ -0,0 +1,53 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ADD_RMS_NORM_BIAS_TORCH_ADPT_H
#define ADD_RMS_NORM_BIAS_TORCH_ADPT_H

#include <tuple>
#include <vector>

#include <ATen/ATen.h>

namespace vllm_ascend {

// Adapter for the fused aclnnAddRmsNormBias kernel: computes x = x1 + x2,
// applies RmsNorm weighted by gamma (with optional bias beta), and returns
// (y, rstd, x). Defined inline because this definition lives in a header.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_add_rms_norm_bias(
    const at::Tensor &x1,
    const at::Tensor &x2,
    const at::Tensor &gamma,
    const c10::optional<at::Tensor> &beta,
    double epsilon)
{
    int64_t dim_x = x1.dim();
    int64_t dim_gamma = gamma.dim();
    int64_t diff = dim_x - dim_gamma;
    std::vector<int64_t> new_shape;
    at::Tensor rstd;
    if (diff > 0) {
        // rstd keeps the leading (non-normalized) dims of x1 and collapses
        // the gamma dims to 1: x1 [B, S, H] with gamma [H] -> rstd [B, S, 1].
        new_shape.reserve(dim_x);
        auto x1_sizes = x1.sizes();
        for (int64_t i = 0; i < diff; ++i) {
            new_shape.push_back(x1_sizes[i]);
        }
        for (int64_t i = 0; i < dim_gamma; ++i) {
            new_shape.push_back(1);
        }
    } else {
        // gamma spans every dim of x1, so a single statistic is broadcast.
        new_shape.assign(dim_x, 1);
    }
    // The inverse-RMS statistic is kept in float32 regardless of input dtype.
    rstd = at::empty(new_shape, x1.options().dtype(at::kFloat));
    at::Tensor y = at::empty(x1.sizes(), x1.options());
    at::Tensor x = at::empty(x1.sizes(), x1.options());
    // EXEC_NPU_CMD (torch_npu op-api glue) dispatches the aclnn kernel.
    EXEC_NPU_CMD(aclnnAddRmsNormBias, x1, x2, gamma, beta, epsilon, y, rstd, x);
    return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
}

} // namespace vllm_ascend
#endif // ADD_RMS_NORM_BIAS_TORCH_ADPT_H
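
Note on the adapter's outputs: `y` and `x` mirror the shape and dtype of `x1`,
while `rstd` is allocated in float32 so the inverse-RMS statistic stays
numerically stable for fp16/bf16 activations; its shape keeps the leading dims
of `x1` and collapses the `gamma` dims to 1, i.e. one statistic per token for
the usual `[B, S, H]` / `[H]` layout.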