### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/triton/activation/swiglu_quant.py` |
| `vllm_ascend/ops/triton/batch_invariant/matmul.py` |
| `vllm_ascend/ops/triton/batch_invariant/mean.py` |
| `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` |
| `vllm_ascend/ops/triton/fla/chunk.py` |
| `vllm_ascend/ops/triton/fla/chunk_delta_h.py` |
| `vllm_ascend/ops/triton/fla/chunk_o.py` |
| `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` |
| `vllm_ascend/ops/triton/fla/cumsum.py` |
| `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` |
| `vllm_ascend/ops/triton/fla/l2norm.py` |
| `vllm_ascend/ops/triton/fla/layernorm_guard.py` |
| `vllm_ascend/ops/triton/fla/sigmoid_gating.py` |
| `vllm_ascend/ops/triton/fla/solve_tril.py` |
| `vllm_ascend/ops/triton/fla/utils.py` |
| `vllm_ascend/ops/triton/fla/wy_fast.py` |
| `vllm_ascend/ops/triton/fused_gdn_gating.py` |
| `vllm_ascend/ops/triton/layernorm_gated.py` |
| `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` |
| `vllm_ascend/ops/triton/mamba/causal_conv1d.py` |
| `vllm_ascend/ops/triton/reject_sample.py` |
| `vllm_ascend/ops/triton/rope.py` |
| `vllm_ascend/ops/triton/spec_decode/utils.py` |
| `vllm_ascend/ops/triton/triton_utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fla/ops/l2norm.py
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
|
|
|
|
|
# Row-wise L2 normalization kernel over a flat (M, N) buffer.
#
# Program `pid` owns the slab of rows
# [pid * NUM_CHUNKS * MBLOCK, (pid + 1) * NUM_CHUNKS * MBLOCK) and walks it
# MBLOCK rows per iteration. Each row is loaded as float32 and scaled by
# rsqrt(sum(row**2) + eps); the result is stored to Y. Rows with index >= M
# are masked on both load and store.
@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
    # Column offsets shared by every row tile: shape (1, N).
    cols = tl.arange(0, N)[None, :]
    slab_start = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)

    for step in range(NUM_CHUNKS):
        # Row indices of this tile: shape (MBLOCK, 1).
        rows = slab_start + step * MBLOCK + tl.arange(0, MBLOCK)[:, None]
        in_bounds = rows < M
        # Flat element offsets; assumes row stride N and column stride 1.
        offsets = rows * N + cols

        vals = tl.load(X + offsets, mask=in_bounds, other=0.0).to(tl.float32)
        inv_norm = tl.rsqrt(tl.sum(vals * vals, axis=1)[:, None] + eps)
        tl.store(Y + offsets, vals * inv_norm, mask=in_bounds)
|
|
|
|
|
|
def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None) -> torch.Tensor:
    """L2-normalize ``x`` along its last dimension.

    Each row ``r`` of ``x`` (viewed as ``(-1, x.shape[-1])``) is replaced by
    ``r * rsqrt(sum(r**2) + eps)``. The kernel accumulates in float32 and
    writes the result in ``output_dtype``.

    Args:
        x: Input tensor of any shape; normalization is over the last dim.
            Non-contiguous inputs are copied to a contiguous buffer first.
        eps: Added to the sum of squares before the reciprocal square root.
        output_dtype: dtype of the result. Defaults to ``x.dtype``.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        RuntimeError: If a single row is larger than 64KB, i.e.
            ``x.shape[-1] * x.element_size() > 65536``.
    """
    x_shape_og = x.shape
    # The kernel addresses element (row, col) as ``row * D + col``, so the
    # flattened input must be densely row-major. ``reshape`` alone can return
    # a strided view (e.g. a column slice of a wider 2-D tensor) that the
    # kernel would read incorrectly; ``contiguous()`` is a no-op otherwise.
    x = x.reshape(-1, x.shape[-1]).contiguous()
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # At most 64KB per feature row; larger rows are rejected.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    # Nothing to normalize: skip compiling/launching a NUM_CHUNKS=0 kernel.
    if T == 0:
        return y.view(x_shape_og)

    # Rows processed per inner-loop step of the kernel.
    # NOTE(review): 69 is an undocumented, presumably hardware-tuned value;
    # confirm the rationale before changing it.
    MBLOCK = 69
    # One program per value reported by get_vectorcore_num() (presumably the
    # NPU vector-core count). Each program owns num_sub_blocks * MBLOCK rows,
    # which is >= ceil(T / num_core), so all T rows are covered; the kernel
    # masks rows >= T.
    num_core = get_vectorcore_num()
    main_bs = triton.cdiv(T, num_core)
    num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
    grid = (num_core,)
    l2norm_fwd_kernel2_loop[grid](
        X=x,
        Y=y,
        eps=eps,
        M=T,
        N=D,
        MBLOCK=MBLOCK,
        NUM_CHUNKS=num_sub_blocks,
    )

    return y.view(x_shape_og)
|