Files
enginex-mlu370-any2any/transformers/examples/modular-transformers/modular_global_indexing.py
2025-10-09 16:47:16 +08:00

17 lines
441 B
Python

from transformers.modeling_utils import AttentionInterface
from transformers.models.llama.modeling_llama import LlamaAttention
def custom_flex(x, **kwargs):
    """Identity placeholder for an attention function.

    Hands back its first argument untouched; any keyword arguments are
    accepted only so the call signature is permissive, and are ignored.

    Args:
        x: Any value; returned as-is (same object, no copy).
        **kwargs: Ignored.

    Returns:
        ``x`` itself.
    """
    del kwargs  # accepted for signature compatibility only; intentionally unused
    return x
# Creates a new AttentionInterface instance in this module (it is not imported
# from transformers.modeling_utils); the item assignment below registers
# custom_flex on this instance under the "flex_attention" key.
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
# This indexing statement and associated function should be exported correctly!
# NOTE(review): keep the literal `registry[key] = func` form. Per the comment
# above, this statement is what the modular converter is expected to carry over,
# so rewriting it (e.g. as a method call) would change what the example tests.
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
class GlobalIndexingAttention(LlamaAttention):
    """Attention layer that inherits everything from ``LlamaAttention`` unchanged.

    It adds no attributes or overrides. Presumably it exists only so the
    modular converter emits a renamed copy of the Llama attention alongside the
    module-level registry defined in this file (assumption; confirm against
    the generated modeling file).
    """