"""Bypass the head-size validation in vLLM's Triton attention backend.

Script location in the patched tree: K100-vLLM-Patched-v2.0/patch_triton.py.

The script rewrites ``validate_head_size`` in the installed
``triton_attn.py`` so that it returns immediately instead of raising
``ValueError`` for head sizes outside the supported list. The rationale
recorded in the patch itself is: allow all head sizes on ROCm, because
FlexAttention has vmap issues.

The replacement is an exact, whitespace-sensitive text match. If the
installed file does not contain the expected method body, the script prints
a warning and leaves the file untouched (best effort, exit status 0).
"""

# Installed vLLM module that gets edited in place.
TARGET_PATH = (
    '/usr/local/lib/python3.10/dist-packages/'
    'vllm/v1/attention/backends/triton_attn.py'
)

# NOTE(review): this file was recovered from a whitespace-collapsed diff, so
# the indentation inside the two literals below was reconstructed as standard
# 4/8/12/16-space class-method indentation. The match is whitespace-exact;
# confirm against the installed triton_attn.py before relying on it.
ORIGINAL_METHOD = '''    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")'''

PATCHED_METHOD = '''    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        # PATCH: allow all head sizes on ROCm (FlexAttention has vmap issues)
        return'''


def apply_patch(content: str) -> tuple[str, str]:
    """Return ``(new_content, status)`` for the given module source text.

    ``status`` is one of:

    * ``'patched'`` - the original method was found and replaced.
    * ``'already_patched'`` - only the replacement text is present, so the
      content is returned unchanged.
    * ``'not_found'`` - neither form is present; content is unchanged.
    """
    if ORIGINAL_METHOD in content:
        return content.replace(ORIGINAL_METHOD, PATCHED_METHOD), 'patched'
    if PATCHED_METHOD in content:
        return content, 'already_patched'
    return content, 'not_found'


def main(path: str = TARGET_PATH) -> str:
    """Patch the file at *path* in place and report the outcome on stdout.

    Returns the status string produced by :func:`apply_patch`. The file is
    only rewritten when the status is ``'patched'``. A missing or unreadable
    file raises the usual ``OSError`` from ``open``.
    """
    # Explicit encoding so the result does not depend on the build locale.
    with open(path, 'r', encoding='utf-8') as src:
        content = src.read()

    patched, status = apply_patch(content)

    if status == 'patched':
        with open(path, 'w', encoding='utf-8') as dst:
            dst.write(patched)
        print('patch_triton: validate_head_size bypassed successfully')
    elif status == 'already_patched':
        # Re-running the script must not look like a failure.
        print('patch_triton: validate_head_size already bypassed, '
              'nothing to do')
    else:
        print('patch_triton: WARNING - pattern not found, patch skipped')
    return status


if __name__ == '__main__':
    main()