# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
import contextlib
import functools
from typing import Callable

import torch

from vllm.triton_utils import tl, triton


def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return per-sequence lengths from cumulative sequence boundaries."""
    return cu_seqlens[1:] - cu_seqlens[:-1]


def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """Return one `(sequence index, chunk index within sequence)` row per
    chunk of size `chunk_size` in a packed varlen batch."""
    indices = torch.cat([
        torch.arange(n)
        for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    ])
    # `indices.eq(0).cumsum(0) - 1` recovers the sequence index from the
    # points where the per-sequence chunk counter resets to zero.
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
                       1).to(cu_seqlens)


def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    """Return cumulative chunk counts, the chunk-level analogue of
    `cu_seqlens`."""
    return torch.cat([
        cu_seqlens.new_tensor([0]),
        triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    ]).cumsum(-1)


def input_guard(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that ensures all input tensors are contiguous and sets the
    current device from the first tensor argument.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else
                           i.contiguous() for i in args)
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        # Pick the first tensor argument (positional first, then keyword)
        # to select the device context.
        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = torch.cuda.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


@triton.jit
def safe_exp(x):
    # exp(x) for x <= 0, otherwise exp(-inf) == 0.0, so the exponential
    # never overflows for large positive inputs.
    return tl.exp(tl.where(x <= 0, x, float("-inf")))
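

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `_demo_chunk_helpers` is not part of the
# original module). It shows what the chunking helpers produce for a packed
# varlen batch, where `cu_seqlens` follows the cumulative-boundary convention
# used by varlen attention kernels.
# ---------------------------------------------------------------------------
def _demo_chunk_helpers():
    # Two packed sequences of lengths 5 and 7.
    cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.long)

    print(prepare_lens(cu_seqlens))
    # tensor([5, 7])

    # With chunk_size=4, each sequence is split into ceil(len / 4) = 2 chunks;
    # each row is a (sequence index, chunk index within sequence) pair.
    print(prepare_chunk_indices(cu_seqlens, chunk_size=4))
    # tensor([[0, 0],
    #         [0, 1],
    #         [1, 0],
    #         [1, 1]])

    # Chunk-level analogue of cu_seqlens: sequence i owns chunks
    # offsets[i]:offsets[i + 1].
    print(prepare_chunk_offsets(cu_seqlens, chunk_size=4))
    # tensor([0, 2, 4])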
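

# ---------------------------------------------------------------------------
# Usage sketch for `input_guard` (illustrative; `_scaled_copy` is a
# hypothetical stand-in for a kernel entry point). The decorator makes the
# wrapped function safe to call with non-contiguous tensors such as
# transposed views, and runs it under the device of the first tensor
# argument.
# ---------------------------------------------------------------------------
@input_guard
def _scaled_copy(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # `x` is guaranteed contiguous here, so a Triton kernel launched from
    # this body could assume dense strides.
    return x * scale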
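

# ---------------------------------------------------------------------------
# Minimal kernel sketch using `safe_exp` (illustrative; `_safe_exp_kernel` is
# not part of the original module and needs a Triton-supported device to run).
# ---------------------------------------------------------------------------
@triton.jit
def _safe_exp_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(y_ptr + offs, safe_exp(tl.load(x_ptr + offs)))


# e.g. on a CUDA device, with BLOCK a power of two matching the input size:
#   x = torch.randn(8, device="cuda")
#   y = torch.empty_like(x)
#   _safe_exp_kernel[(1, )](x, y, BLOCK=8)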