# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os

from vllm.triton_utils import tl, tldevice, triton

from .utils import is_gather_supported

# Pick the exp/log/log2 implementations once, at import time. Setting the
# environment variable FLA_USE_FAST_OPS to exactly "1" selects the `fast_*`
# variants from `tldevice`; any other value (or leaving it unset) keeps the
# standard `tl` functions.
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    log = tl.log
    log2 = tl.log2

# Expose `gather` under a single name regardless of whether `tl.gather`
# is available, so importers can always reference it.
if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Placeholder used when tl.gather is not supported.
        It performs no gathering: the arguments are ignored and None is
        always returned. It exists only to keep the triton compiler happy.
        """
        return None

else:
    gather = tl.gather

# Expose `make_tensor_descriptor` under a single name across Triton versions.
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        """
        Fallback implementation when TMA is not supported.
        Returns None to indicate TMA descriptors are unavailable; the
        arguments are ignored. Just to make triton compiler happy.
        """
        return None