Files

154 lines
5.1 KiB
Python
Raw Permalink Normal View History

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,  # Total number of rows in the 2D input
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.

    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight

    Each program handles a contiguous chunk of rows: the rows are split
    into ceil(n_rows / num_programs) rows per program, and the last chunk is
    clamped to n_rows. Every row is processed in two passes over its columns,
    BLOCK_SIZE columns at a time.

    Args:
        input_ptr: Pointer to the input; a row is addressed as
            ``input_ptr + row * input_row_stride`` and its columns are read
            at consecutive offsets from there.
        weight_ptr: Pointer to the per-column weight, read at consecutive
            offsets ``0 .. n_cols - 1``.
        output_ptr: Pointer to the output; a row is addressed as
            ``output_ptr + row * output_row_stride``.
        input_row_stride: Offset between consecutive input rows.
        output_row_stride: Offset between consecutive output rows.
        n_rows: Total number of rows to normalize.
        n_cols: Number of columns per row (the normalized dimension).
        eps: Constant added to the mean of squares before the square root.
        BLOCK_SIZE: Compile-time number of columns handled per inner-loop step.
    """
    pid = tl.program_id(0)
    n_programs = tl.num_programs(0)
    # Ceil-divide the rows across programs; the clamp below leaves trailing
    # programs with fewer rows (or an empty range) when n_rows does not
    # divide evenly.
    rows_per_program = (n_rows + n_programs - 1) // n_programs
    start_row = pid * rows_per_program
    end_row = tl.minimum(start_row + rows_per_program, n_rows)
    for row_idx in range(start_row, end_row):
        # NOTE(review): this offset is computed in whatever integer width
        # Triton assigns to row_idx and the stride; if that is 32-bit, very
        # large tensors could overflow here -- confirm.
        row_start_ptr = input_ptr + row_idx * input_row_stride
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        # Step 1: Compute sum of squares in float32 to avoid overflow
        sum_sq = tl.zeros([1], dtype=tl.float32)
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
            # Lanes past the end of the row are masked off and load as 0.0
            mask = col_idx < n_cols
            vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
            vals_f32 = vals.to(tl.float32)
            sq_vals = vals_f32 * vals_f32
            sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
        # Step 2: Compute RMS (root mean square) in float32
        mean_sq = sum_sq / n_cols
        rms = tl.sqrt(mean_sq + eps)
        inv_rms = 1.0 / rms
        # Step 3: Normalize and apply weight
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
            mask = col_idx < n_cols
            vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
            weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
            vals_f32 = vals.to(tl.float32)
            weight_f32 = weight.to(tl.float32)
            output_f32 = vals_f32 * inv_rms * weight_f32
            # Cast the float32 result back to the dtype the input was loaded as
            output = output_f32.to(vals.dtype)
            tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
    input_: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Compute RMS normalization using Triton kernel with fixed grid size.

    RMS Norm normalizes the input by the root mean square and scales by weight:
        output = input / sqrt(mean(input^2) + eps) * weight

    The input is flattened to 2D (rows x hidden_size) and made contiguous
    before the kernel launch. The launch grid is capped at the device's
    "num_vectorcore" property, so each program may process several rows.

    Args:
        input_: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension, with
        the same shape and dtype as ``input_``.

    Raises:
        AssertionError: If ``weight`` is not 1-dimensional or its length does
            not match the last dimension of ``input_``.
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input_.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input_.shape[-1]}) must match weight dimension ({weight.shape[0]})"
    )

    original_shape = input_.shape

    # Nothing to normalize: skip the reshape (ambiguous with -1 when a
    # dimension is zero) and the kernel launch (which would get an empty grid).
    if input_.numel() == 0:
        return torch.empty_like(input_)

    # Flatten all dimensions except the last one
    input_2d = input_.reshape(-1, original_shape[-1]).contiguous()
    weight = weight.contiguous()
    n_rows, n_cols = input_2d.shape
    output = torch.empty_like(input_2d)

    BLOCK_SIZE = 1024
    # Never launch more programs than the device property reports vector cores;
    # the kernel distributes the rows over however many programs are launched.
    max_grid_size = triton.runtime.driver.active.utils.get_device_properties(torch.npu.current_device())[
        "num_vectorcore"
    ]
    grid = (min(n_rows, max_grid_size),)
    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_rows,
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)
def rms_norm_batch_invariant(
    input_: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    RMS normalization entry point for the batch_invariant mode.

    Delegates directly to ``rms_norm`` with the same arguments, so the result
    is whatever that implementation produces.

    Args:
        input_: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    normalized = rms_norm(input_, weight, eps=eps)
    return normalized