Files
xc-llm-ascend/vllm_ascend/ops/triton/fla/utils.py
shiyuan680 1c4a0468ee 【OPS】qwen3-next support triton chunk_gated_delta_rule ops (#4070)
### What this PR does / why we need it?
qwen3-next suppot  triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <917935075@qq.com>
2025-11-28 20:55:43 +08:00

80 lines
2.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
from typing import Callable
import torch
from vllm.triton_utils import tl, triton
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
indices = torch.cat([
torch.arange(n)
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
1).to(cu_seqlens)
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
return torch.cat([
cu_seqlens.new_tensor([0]),
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
]).cumsum(-1)
def input_guard(
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else
i.contiguous() for i in args)
contiguous_kwargs = {
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
for k, v in kwargs.items()
}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = torch.npu.device(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
@triton.jit
def safe_exp(x):
return tl.exp(tl.where(x <= 0, x, float("-inf")))