187 lines
6.4 KiB
Python
187 lines
6.4 KiB
Python
"""
|
||
策略:顺序(per-sequence)fallback — 纯 PyTorch 数学实现
|
||
==========================================================
|
||
逐条序列用 matmul + softmax 手写 attention,完全绕开所有硬件
|
||
flash attention kernel(ixformer / cudnnFlashAttnForward)。仅 head_size > 128 的 prefill 走此路径,其余仍使用原 xformers kernel。
|
||
|
||
背景:
|
||
Iluvatar cudnnFlashAttnForward 存在两个已知问题:
|
||
1. 不支持 is_causal=True(报错)
|
||
2. 使用 attn_mask 路径时数值结果不正确(静默错误,输出全为"!")
|
||
与华为昇腾 910B4 上 llama.cpp --flash-attn off 修复同类问题的原理相同。
|
||
纯数学路径(matmul + softmax)在任何 PyTorch 后端上结果都正确。
|
||
|
||
优点:
|
||
数值正确,不依赖任何硬件特定 attention kernel。
|
||
峰值额外显存 ≈ (1~2) × max(seq_len)² × H × 4 字节(scores 固定以 float32 计算,与模型 dtype 无关),由 --max-model-len 控制。
|
||
|
||
缺点:
|
||
并发请求的 prefill attention 串行执行。
|
||
O(L²) 显存(无 flash attention 的 O(L) 优化)。
|
||
|
||
内存参考(H_local=6;单个 [H, L, L] float32 scores 缓冲区大小,峰值约为其 1~2 倍):
max-model-len=4096  → ~0.4 GB
max-model-len=8192  → ~1.6 GB
max-model-len=16384 → ~6.4 GB
|
||
|
||
Deploy:
|
||
python3 modified_scripts/patch_xformers_sdpa_seq.py
|
||
"""
|
||
|
||
# Installed location of vLLM's xformers attention backend inside the CoreX
# Python tree; main() hands this path to patch_file(), which edits it in place.
XFORMERS_PATH = (
    "/usr/local/corex/lib64/python3/dist-packages"
    "/vllm/attention/backends"
    "/xformers.py"
)
|
||
|
||
# Source text of `_run_sdpa_fallback`. patch_file() splices it into the target
# class immediately before INJECT_ANCHOR, so it is written as a 4-space-indented
# method and begins and ends with a newline. It is kept ASCII-only so the
# patched file does not depend on the encoding used to write it.
FALLBACK_METHOD = '''
    def _run_sdpa_fallback(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: "XFormersMetadata",
    ) -> torch.Tensor:
        """Sequential pure-math causal attention fallback.

        Computes attention one sequence at a time with matmul + softmax,
        bypassing ixformer / cudnnFlashAttnForward entirely: the attn_mask
        path of Iluvatar cudnnFlashAttnForward returns silently wrong values
        (all "!" output), while the plain-math path is correct.

        Scores are computed and normalized in float32 to avoid float16
        overflow, then cast back to the input dtype. The softmax scale is
        folded into q and the softmax is done in place, so the per-sequence
        peak is one [num_heads, L, L] float32 buffer plus an [L, L] bool mask.

        Only per-sequence causal self-attention is computed; the attn_bias
        built by the caller is not consulted.

        Args:
            query: [1, total_prefill_tokens, num_heads, head_dim]
            key:   [1, total_prefill_tokens, num_kv_heads, head_dim]
            value: [1, total_prefill_tokens, num_kv_heads, head_dim]
            attn_metadata: provides seq_lens, which must sum to
                total_prefill_tokens.

        Returns:
            [1, total_prefill_tokens, num_heads, head_dim] in query's dtype.

        Raises:
            ValueError: if seq_lens does not cover exactly the prefill
                tokens, or num_heads is not a multiple of num_kv_heads.
        """
        seq_lens = attn_metadata.seq_lens
        assert seq_lens is not None

        q_all = query.squeeze(0)  # [T, H, D]
        k_all = key.squeeze(0)  # [T, Hkv, D]
        v_all = value.squeeze(0)  # [T, Hkv, D]
        num_tokens, num_heads = q_all.shape[0], q_all.shape[1]
        num_kv_heads = k_all.shape[1]

        total = sum(seq_lens)
        if total != num_tokens:
            # Otherwise rows of the torch.empty output would be returned
            # uninitialized, or the per-sequence slices would come up short.
            raise ValueError(
                f"seq_lens sum to {total} but the prefill batch has "
                f"{num_tokens} tokens"
            )
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) is not a multiple of "
                f"num_kv_heads ({num_kv_heads})"
            )
        group = num_heads // num_kv_heads

        output = torch.empty_like(q_all)
        start = 0
        for seq_len in seq_lens:
            if seq_len == 0:
                # Nothing to attend over; also keeps amax() off empty dims.
                continue
            end = start + seq_len

            # [L, H, D] -> [H, L, D]. Scaling q ([H, L, D]) instead of the
            # scores ([H, L, L]) avoids an extra L x L temporary.
            q = q_all[start:end].transpose(0, 1).float() * self.scale
            k = k_all[start:end].transpose(0, 1).float()
            v = v_all[start:end].transpose(0, 1).float()
            if group > 1:
                # GQA: query head h uses kv head h // group.
                k = k.repeat_interleave(group, dim=0)
                v = v.repeat_interleave(group, dim=0)

            scores = torch.matmul(q, k.transpose(-2, -1))  # [H, L, L] fp32
            future = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool,
                           device=scores.device),
                diagonal=1,
            )
            scores.masked_fill_(future, float("-inf"))

            # In-place softmax over the last dim (no second [H, L, L]
            # buffer). The diagonal is never masked, so every row max is
            # finite and no NaNs arise.
            scores.sub_(scores.amax(dim=-1, keepdim=True))
            scores.exp_()
            scores.div_(scores.sum(dim=-1, keepdim=True))

            out = torch.matmul(scores, v)  # [H, L, D]
            output[start:end] = out.transpose(0, 1).to(output.dtype)
            start = end

        return output.unsqueeze(0)  # [1, T, H, D]

'''
|
||
|
||
# Original prefill call in the target file's
# `_run_memory_efficient_xformers_forward`. patch_file() uses it as a literal
# search anchor (`in` / str.replace), so it must match the target file
# byte-for-byte, including indentation and the unusual `op = self.attn_op`
# spacing. The trailing backslash is a line continuation inside the string.
# NOTE(review): the indentation of this copy cannot be verified from here —
# confirm it against the installed xformers.py.
OLD_XFORMER_BLOCK = """\
|
||
self.attn_op = xops.fmha.flash.FwOp()
|
||
if self.alibi_slopes is None:
|
||
# Add the batch dimension.
|
||
query = query.unsqueeze(0)
|
||
key = key.unsqueeze(0)
|
||
value = value.unsqueeze(0)
|
||
out = xops.memory_efficient_attention_forward(
|
||
query,
|
||
key,
|
||
value,
|
||
attn_bias=attn_bias[0],
|
||
p=0.0,
|
||
scale=self.scale,
|
||
op = self.attn_op
|
||
)
|
||
return out.view_as(original_query)\
|
||
"""
|
||
|
||
# Replacement for OLD_XFORMER_BLOCK. When self.head_size > 128 the prefill
# output comes from self._run_sdpa_fallback(query, key, value, attn_metadata);
# otherwise the xops.memory_efficient_attention_forward call is kept.
# patch_file() also searches for this exact text to detect an already-patched
# file, so it must match what a previous run wrote, byte-for-byte.
# NOTE(review): the fallback is not given attn_bias[0]; it applies only a
# per-sequence causal mask, so any extra masking the caller encodes in
# attn_bias (e.g. a sliding window, if built) is dropped — confirm for the
# models served.
NEW_XFORMER_BLOCK = """\
|
||
self.attn_op = xops.fmha.flash.FwOp()
|
||
if self.alibi_slopes is None:
|
||
# Add the batch dimension.
|
||
query = query.unsqueeze(0)
|
||
key = key.unsqueeze(0)
|
||
value = value.unsqueeze(0)
|
||
if self.head_size > 128:
|
||
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
|
||
else:
|
||
out = xops.memory_efficient_attention_forward(
|
||
query,
|
||
key,
|
||
value,
|
||
attn_bias=attn_bias[0],
|
||
p=0.0,
|
||
scale=self.scale,
|
||
op=self.attn_op,
|
||
)
|
||
return out.view_as(original_query)\
|
||
"""
|
||
|
||
# Method header in the target file before which FALLBACK_METHOD is inserted
# (patch_file() replaces it with FALLBACK_METHOD + INJECT_ANCHOR). It must carry
# the target line's exact leading indentation. If it has fewer leading spaces,
# the match starts mid-indentation and the re-emitted header ends up
# under-indented after FALLBACK_METHOD.
# NOTE(review): this copy shows a single leading space, while a method in a
# standard 4-space class body needs four. The single space may be a rendering
# artifact — confirm against xformers.py.
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
|
||
|
||
|
||
def patch_file(path):
    """Inject the sequential fallback into *path* and reroute its dispatch.

    Two text edits are applied to the vLLM xformers backend source:

    1. FALLBACK_METHOD is inserted just before INJECT_ANCHOR.
    2. OLD_XFORMER_BLOCK is replaced with NEW_XFORMER_BLOCK.

    Each step is skipped when its result is already present, so re-running
    is safe. The dispatch is only rerouted when the file defines
    ``_run_sdpa_fallback``; otherwise the patched call would raise
    AttributeError at runtime. The file is rewritten only if something
    changed.

    Args:
        path: Path of the xformers backend file to patch.

    Returns:
        True if the file was rewritten, False otherwise.
    """
    # The target source and the injected text may contain non-ASCII
    # characters, so do not depend on the locale's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    changed = False

    # Test for the *definition*: NEW_XFORMER_BLOCK also names the method at
    # its call site, so a bare substring test would treat a file whose
    # dispatch was patched without the method as already complete.
    if FALLBACK_METHOD in content:
        print(" [skip] _run_sdpa_fallback already present")
        has_fallback = True
    elif "def _run_sdpa_fallback(" in content:
        print(" [warn] a different _run_sdpa_fallback is already present; "
              "restore the original file to install this version")
        has_fallback = True
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
        has_fallback = False
    else:
        content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
        print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)")
        changed = True
        has_fallback = True

    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif OLD_XFORMER_BLOCK not in content:
        print(" [warn] dispatch block anchor not found")
    elif not has_fallback:
        # Rerouting now would leave a call to a method that does not exist.
        print(" [warn] dispatch block not patched: _run_sdpa_fallback is missing")
    else:
        content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
        print(" [ok] patched dispatch block")
        changed = True

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
    return changed
|
||
|
||
|
||
def main():
    """Script entry point: patch the backend file at XFORMERS_PATH in place."""
    header = "=== patch_xformers_sdpa_seq (sequential, pure-math) ==="
    # One print emitting both lines produces the same output as two prints.
    print(f"{header}\nTarget: {XFORMERS_PATH}")
    patch_file(XFORMERS_PATH)
    print("\nDone.")


if __name__ == "__main__":
    main()
|