Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_xformers_sdpa_seq.py

187 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
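Strategy: sequential (per-sequence) fallback -- pure PyTorch math implementation
==========================================================================
Attention is written by hand with matmul + softmax, one sequence at a time,
completely bypassing every hardware flash-attention kernel
(ixformer / cudnnFlashAttnForward).

Background:
    Iluvatar cudnnFlashAttnForward has two known problems:
    1. is_causal=True is not supported (it raises an error).
    2. The attn_mask path produces numerically wrong results (a silent
       error: the generated output is all "!").
    The principle is the same as fixing the equivalent problem on Huawei
    Ascend 910B4 with llama.cpp --flash-attn off.
    The pure-math path (matmul + softmax) gives correct results on any
    PyTorch backend.

Pros:
    Numerically correct; does not depend on any hardware-specific attention
    kernel.
    Peak device memory = max(seq_len)² × H × dtype_size, bounded by
    --max-model-len.

Cons:
    Prefill attention of concurrent requests is executed serially.
    O(L²) memory (no O(L) flash-attention optimisation).

Memory reference (fp16, H_local=6):
    max-model-len=4096  → peak ~200 MB
    max-model-len=8192  → peak ~800 MB
    max-model-len=16384 → peak ~3.2 GB
    NOTE(review): the injected method builds the score matrix in float32,
    so the real peak may be higher than these fp16-based figures -- confirm
    by measurement.

Deploy:
    python3 modified_scripts/patch_xformers_sdpa_seq.py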
"""
# Path of the file this script patches: main() passes it to patch_file(),
# which reads it and, when something changed, rewrites it in place.
XFORMERS_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/attention/backends/xformers.py"
)
# Source text of the `_run_sdpa_fallback` method. patch_file() inserts this
# string verbatim, immediately in front of the first occurrence of
# INJECT_ANCHOR in the target file.
#
# What the injected method does, per the text below:
#   * walks attn_metadata.seq_lens and slices query/key/value per sequence;
#   * repeats the KV heads when their count differs from the query heads;
#   * computes Q @ K^T in float32, multiplies by self.scale, masks the strict
#     upper triangle with -inf (causal), applies softmax, multiplies by V;
#   * casts the result back to the query dtype and returns it with a leading
#     batch dimension of 1.
#
# NOTE(review): the method takes no attn_bias argument, so only plain causal
# masking is applied -- confirm that is sufficient for every caller.
# NOTE(review): this text becomes part of another Python file, so its leading
# whitespace must match the method indentation of the target class. As shown
# here the lines carry no indentation -- verify against the target file
# before deploying. Do not edit the literal cosmetically; every character is
# written to disk.
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "XFormersMetadata",
) -> torch.Tensor:
"""顺序纯数学 attention fallback。
完全绕开 ixformer / cudnnFlashAttnForward用 matmul + softmax
手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径
存在静默数值错误(输出全为"!"),纯数学路径结果正确。
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
Args:
query : [1, total_prefill_tokens, num_heads, head_dim]
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
Returns:
[1, total_prefill_tokens, num_heads, head_dim]
"""
assert attn_metadata.seq_lens is not None
orig_dtype = query.dtype
q_flat = query.squeeze(0) # [T, H, D]
k_flat = key.squeeze(0) # [T, Hkv, D]
v_flat = value.squeeze(0)
output = torch.empty_like(q_flat)
start = 0
for seq_len in attn_metadata.seq_lens:
end = start + seq_len
# [1, H, L, D]
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
# GQA展开 KV heads 至与 query heads 一致
if k_s.shape[1] != q_s.shape[1]:
n = q_s.shape[1] // k_s.shape[1]
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
# 纯数学 attention完全绕开硬件 flash attention kernel
# [1, H, L, L]
attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1))
attn_w = attn_w * self.scale
# 上三角填 -inffuture tokens
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device),
diagonal=1,
)
attn_w = attn_w.masked_fill(causal_mask, float("-inf"))
# float32 softmax 防止 float16 溢出
attn_w = torch.softmax(attn_w, dim=-1)
out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype)
# [1, H, L, D] → [L, H, D]
output[start:end] = out_s.squeeze(0).permute(1, 0, 2)
start = end
return output.unsqueeze(0) # [1, T, H, D]
'''
# Exact text of the unpatched dispatch code that patch_file() searches for in
# the target file. The match is a plain substring test, so every character
# (including whitespace and the `op = self.attn_op` spelling without a
# trailing comma) has to agree with the target file. The backslashes after
# the opening quotes and before the closing quotes keep the literal free of a
# leading and a trailing newline.
# NOTE(review): the lines carry no indentation as shown here -- verify the
# literal against the real target file, otherwise the anchor is never found.
OLD_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op = self.attn_op
)
return out.view_as(original_query)\
"""
# Replacement text that patch_file() substitutes for the first occurrence of
# OLD_XFORMER_BLOCK. Compared with the old text it adds one branch: when
# self.head_size > 128 the output comes from self._run_sdpa_fallback(...),
# otherwise the original xops.memory_efficient_attention_forward call is kept
# (with the `op=self.attn_op,` argument re-spelled).
# patch_file() also uses this literal to detect an already patched file, so
# it must stay byte-for-byte stable between runs.
# NOTE(review): the lines carry no indentation as shown here -- verify the
# literal against the real target file before deploying.
NEW_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if self.head_size > 128:
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
else:
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
return out.view_as(original_query)\
"""
# Marker text in the target file: patch_file() inserts FALLBACK_METHOD
# directly in front of the first occurrence of this string.
# NOTE(review): the leading whitespace inside the literal must equal the
# indentation of that `def` line in the target file -- confirm it.
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
def patch_file(path):
    """Apply the fallback-attention patch to the file at *path*, in place.

    Two independent, idempotent steps are performed on the file's text:

    1. Insert ``FALLBACK_METHOD`` in front of the first ``INJECT_ANCHOR``
       unless a ``_run_sdpa_fallback`` definition already exists.
    2. Replace the first ``OLD_XFORMER_BLOCK`` with ``NEW_XFORMER_BLOCK``
       unless the new block is already there.

    Step 2 is skipped (with a warning) when the fallback method is neither
    present nor injectable, because the new dispatch block calls
    ``self._run_sdpa_fallback`` and would otherwise fail at run time.

    The file is only rewritten when at least one step changed the text.

    Args:
        path: Path of the Python source file to patch.

    Returns:
        True if the file was modified and written back, False otherwise.

    Raises:
        OSError: If the file cannot be read or written.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit UTF-8: the injected method contains non-ASCII comments, so
    # relying on the locale's default encoding can fail or corrupt the file.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    changed = False

    # Look for the definition, not just the name: the patched dispatch block
    # also mentions `_run_sdpa_fallback`, and a file that only has the call
    # still needs the method injected.
    has_fallback = "def _run_sdpa_fallback" in content
    if has_fallback:
        print(" [skip] _run_sdpa_fallback already present")
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
    else:
        content = content.replace(
            INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1
        )
        print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)")
        changed = True
        has_fallback = True

    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif OLD_XFORMER_BLOCK not in content:
        print(" [warn] dispatch block anchor not found")
    elif not has_fallback:
        # Never route calls to a method that does not exist in the file.
        print(" [warn] dispatch block not patched: _run_sdpa_fallback is missing")
    else:
        content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
        print(" [ok] patched dispatch block")
        changed = True

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
    return changed
def main():
    """Script entry point: announce the target, patch it, report completion."""
    header_lines = (
        "=== patch_xformers_sdpa_seq (sequential, pure-math) ===",
        f"Target: {XFORMERS_PATH}",
    )
    for line in header_lines:
        print(line)

    patch_file(XFORMERS_PATH)

    # Leading newline separates the patch log from the closing message.
    print("\nDone.")


if __name__ == "__main__":
    main()