Add GemmaRmsNorm ACLGraph Support (#6473)
### What this PR does / why we need it?
1. New Custom NPU Operation: Introduced npu_gemma_rms_norm in
csrc/torch_binding.cpp to provide optimized Gemma RMS Normalization
support for Ascend NPUs. This function includes logic to handle dynamic
shapes for the gamma tensor.
2. PyTorch Operator Registration: The new npu_gemma_rms_norm operation
has been registered with the PyTorch custom operator library, making it
accessible from Python.
3. Meta-Implementation for ACLGraph: A corresponding meta-implementation,
npu_gemma_rms_norm_meta, was added in csrc/torch_binding_meta.cpp. This
is crucial for symbolic tracing, allowing the custom kernel to be
captured and optimized by ACLGraph.
4. Python Frontend Integration: The vllm_ascend/ops/layernorm.py file
was updated to utilize the newly added
torch.ops._C_ascend.npu_gemma_rms_norm for Gemma RMS Normalization,
replacing the generic torch_npu.npu_rms_norm.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com>
This commit is contained in:
@@ -553,6 +553,34 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> dispatch_prefill(
|
|||||||
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
|
return {expandx_out, expand_idx_out, recv_count, num_recv_tokens_per_expert};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm(
|
||||||
|
const at::Tensor& x,
|
||||||
|
const at::Tensor& gamma,
|
||||||
|
double epsilon)
|
||||||
|
{
|
||||||
|
int64_t dim_x = x.dim();
|
||||||
|
int64_t dim_gamma = gamma.dim();
|
||||||
|
int64_t diff = dim_x - dim_gamma;
|
||||||
|
std::vector<int64_t> new_shape;
|
||||||
|
at::Tensor rstd;
|
||||||
|
if (diff > 0) {
|
||||||
|
new_shape.reserve(dim_x);
|
||||||
|
auto x_sizes = x.sizes();
|
||||||
|
for (int64_t i = 0; i < diff; ++i) {
|
||||||
|
new_shape.push_back(x_sizes[i]);
|
||||||
|
}
|
||||||
|
for (int64_t i = 0; i < dim_gamma; ++i) {
|
||||||
|
new_shape.push_back(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
new_shape.assign(dim_x, 1);
|
||||||
|
}
|
||||||
|
rstd = at::empty(new_shape, x.options().dtype(at::kFloat));
|
||||||
|
at::Tensor y = at::empty(x.sizes(), x.options());
|
||||||
|
EXEC_NPU_CMD(aclnnGemmaRmsNorm, x, gamma, epsilon, y, rstd);
|
||||||
|
return std::tuple<at::Tensor, at::Tensor>(y, rstd);
|
||||||
|
}
|
||||||
|
|
||||||
void transpose_kv_cache_by_block(
|
void transpose_kv_cache_by_block(
|
||||||
const at::TensorList &kCache,
|
const at::TensorList &kCache,
|
||||||
const at::TensorList &vCache,
|
const at::TensorList &vCache,
|
||||||
@@ -575,6 +603,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
|||||||
{
|
{
|
||||||
|
|
||||||
// vLLM-Ascend custom ops
|
// vLLM-Ascend custom ops
|
||||||
|
// Gemma RmsNorm
|
||||||
|
ops.def(
|
||||||
|
"npu_gemma_rms_norm(Tensor x, "
|
||||||
|
"Tensor gamma, "
|
||||||
|
"float epsilon=1e-6)"
|
||||||
|
"-> (Tensor y ,Tensor rstd)"
|
||||||
|
);
|
||||||
|
ops.impl("npu_gemma_rms_norm", torch::kPrivateUse1, &vllm_ascend::npu_gemma_rms_norm);
|
||||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||||
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
|
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
|
||||||
|
|
||||||
|
|||||||
@@ -418,6 +418,33 @@ std::tuple<at::Tensor,at::Tensor, at::Tensor> npu_add_rms_norm_bias_meta(
|
|||||||
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
return std::tuple<at::Tensor, at::Tensor, at::Tensor>(y, rstd, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> npu_gemma_rms_norm_meta(
|
||||||
|
const at::Tensor& x,
|
||||||
|
const at::Tensor& gamma,
|
||||||
|
double epsilon)
|
||||||
|
{
|
||||||
|
int64_t dim_x = x.dim();
|
||||||
|
int64_t dim_gamma = gamma.dim();
|
||||||
|
int64_t diff = dim_x - dim_gamma;
|
||||||
|
c10::SymDimVector new_shape;
|
||||||
|
at::Tensor rstd;
|
||||||
|
if (diff > 0) {
|
||||||
|
new_shape.reserve(dim_x);
|
||||||
|
auto x_sizes = x.sym_sizes();
|
||||||
|
for (int64_t i = 0; i < diff; ++i) {
|
||||||
|
new_shape.push_back(x_sizes[i]);
|
||||||
|
}
|
||||||
|
for (int64_t i = 0; i < dim_gamma; ++i) {
|
||||||
|
new_shape.push_back(c10::SymInt(1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
new_shape.assign(dim_x, c10::SymInt(1));
|
||||||
|
}
|
||||||
|
rstd = at::empty_symint(new_shape, x.options().dtype(at::kFloat));
|
||||||
|
at::Tensor y = at::empty_symint(x.sym_sizes(), x.options());
|
||||||
|
return std::tuple<at::Tensor, at::Tensor>(y, rstd);
|
||||||
|
}
|
||||||
|
|
||||||
void transpose_kv_cache_by_block_meta(
|
void transpose_kv_cache_by_block_meta(
|
||||||
const at::TensorList &k_cache,
|
const at::TensorList &k_cache,
|
||||||
const at::TensorList &v_cache,
|
const at::TensorList &v_cache,
|
||||||
@@ -430,7 +457,6 @@ void transpose_kv_cache_by_block_meta(
|
|||||||
{
|
{
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace meta
|
} // namespace meta
|
||||||
} // namespace vllm_ascend
|
} // namespace vllm_ascend
|
||||||
|
|
||||||
@@ -438,7 +464,8 @@ namespace {
|
|||||||
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
|
||||||
// the custom kernel been captured into aclgraph
|
// the custom kernel been captured into aclgraph
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||||
|
//Gemma rmsnorm meta implementation
|
||||||
|
ops.impl("npu_gemma_rms_norm", &vllm_ascend::meta::npu_gemma_rms_norm_meta);
|
||||||
// Masked input and mask meta implementation
|
// Masked input and mask meta implementation
|
||||||
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
|
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
|
||||||
// Bgmv expand
|
// Bgmv expand
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
|
|||||||
x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight, self.variance_epsilon)
|
x, _, residual = torch_npu.npu_add_rms_norm(x, residual, 1.0 + self.weight, self.variance_epsilon)
|
||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
|
x, _ = torch.ops._C_ascend.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user