from transformers.modeling_utils import AttentionInterface from transformers.models.llama.modeling_llama import LlamaAttention def custom_flex(x, **kwargs): """Dummy function.""" return x ALL_ATTENTION_FUNCTIONS = AttentionInterface() # This indexing statement and associated function should be exported correctly! ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex class GlobalIndexingAttention(LlamaAttention): pass