# Paste metadata (not part of the script): 27 lines, 1.1 KiB, Python
"""Bypass head-size validation in vLLM's Triton attention backend.

The installed ``triton_attn.py`` defines a ``validate_head_size`` classmethod
that raises ``ValueError`` for head sizes outside
``cls.get_supported_head_sizes()`` and suggests switching to the
FlexAttention backend.  On ROCm that fallback has vmap issues, so this script
rewrites the method in place into a no-op that accepts every head size.

The patch is best-effort: if the expected method text is not found, a warning
is printed and the file is left untouched (the script does not exit with an
error).  Re-running on an already patched file changes nothing.

Usage: python <this script> [PATH_TO_triton_attn.py]
"""
import re
import sys

DEFAULT_PATH = '/usr/local/lib/python3.10/dist-packages/vllm/v1/attention/backends/triton_attn.py'

# Comment written into the patched method; finding it in the file means the
# patch has already been applied.
PATCH_MARKER = '# PATCH: allow all head sizes on ROCm (FlexAttention has vmap issues)'

# The stock method, written without its class-level indentation.  It is
# matched whitespace-insensitively (see _whitespace_tolerant_pattern), so a
# copy with different indentation or line breaks still matches, whereas an
# exact substring test would silently miss it.
_OLD_METHOD = '''\
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    supported_head_sizes = cls.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")'''

# Replacement method, one entry per line, indented relative to the method's
# own indentation (the matched indentation is prepended to every line).
_NEW_METHOD_LINES = (
    "@classmethod",
    "def validate_head_size(cls, head_size: int) -> None:",
    "    " + PATCH_MARKER,
    "    return",
)


def _whitespace_tolerant_pattern(snippet: str) -> re.Pattern[str]:
    """Compile *snippet* into a regex that ignores whitespace differences.

    Every whitespace-separated token must match literally; each run of
    whitespace between tokens matches any non-empty run of whitespace.  The
    match must start at the beginning of a line, and the named group
    ``indent`` captures that line's leading spaces/tabs.
    """
    body = r"\s+".join(re.escape(token) for token in snippet.split())
    return re.compile(r"^(?P<indent>[ \t]*)" + body, re.MULTILINE)


_OLD_PATTERN = _whitespace_tolerant_pattern(_OLD_METHOD)


def patch_file(path: str = DEFAULT_PATH) -> bool:
    """Turn ``validate_head_size`` in the file at *path* into a no-op.

    Args:
        path: Python source file to patch in place.

    Returns:
        True if the file is patched (by this call or an earlier run), False
        if the stock method was not found and the file was left unchanged.

    Raises:
        OSError: if *path* cannot be read or written (e.g. FileNotFoundError).
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    if PATCH_MARKER in content:
        # Already patched by a previous run; report it instead of the
        # misleading "pattern not found" warning.
        print('patch_triton: validate_head_size already patched, nothing to do')
        return True

    def _replacement(match: re.Match[str]) -> str:
        # Re-indent the new method to wherever the old one sat in its class.
        indent = match.group("indent")
        return "\n".join(indent + line for line in _NEW_METHOD_LINES)

    patched, count = _OLD_PATTERN.subn(_replacement, content)
    if not count:
        print('patch_triton: WARNING - pattern not found, patch skipped')
        return False

    with open(path, 'w', encoding='utf-8') as f:
        f.write(patched)
    print('patch_triton: validate_head_size bypassed successfully')
    return True


def main() -> None:
    """Patch the file named by the first CLI argument, or DEFAULT_PATH."""
    target = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PATH
    patch_file(target)


if __name__ == "__main__":
    main()