添加 K100-vLLM-Patched-v2.0/patch_triton.py
This commit is contained in:
27
K100-vLLM-Patched-v2.0/patch_triton.py
Normal file
27
K100-vLLM-Patched-v2.0/patch_triton.py
Normal file
@@ -0,0 +1,27 @@
|
||||
path = '/usr/local/lib/python3.10/dist-packages/vllm/v1/attention/backends/triton_attn.py'
|
||||
with open(path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
old = ''' @classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")'''
|
||||
|
||||
new = ''' @classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
# PATCH: allow all head sizes on ROCm (FlexAttention has vmap issues)
|
||||
return'''
|
||||
|
||||
if old in content:
|
||||
content = content.replace(old, new)
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
print('patch_triton: validate_head_size bypassed successfully')
|
||||
else:
|
||||
print('patch_triton: WARNING - pattern not found, patch skipped')
|
||||
Reference in New Issue
Block a user