init
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
from transformers.modeling_utils import AttentionInterface
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
|
||||
def custom_flex(x, **kwargs):
|
||||
"""Dummy function."""
|
||||
return x
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
|
||||
# This indexing statement and associated function should be exported correctly!
|
||||
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
|
||||
|
||||
|
||||
class GlobalIndexingAttention(LlamaAttention):
|
||||
pass
|
||||
Reference in New Issue
Block a user