Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_xformers_sdpa_seq.py

215 lines
7.7 KiB
Python
Raw Normal View History

"""
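Strategy: sequential (per-sequence) fallback to a pure-PyTorch math implementation
==========================================================
Attention is hand-written with matmul + softmax, one sequence at a time,
completely bypassing every hardware flash-attention kernel
(ixformer / cudnnFlashAttnForward).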
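Background:
Iluvatar's cudnnFlashAttnForward has two known problems:
1. It does not support is_causal=True (it raises an error).
2. On the attn_mask path its numerical results are wrong (a silent error: the output is all "!").
This follows the same principle as the fix for the same class of problem on
Huawei Ascend 910B4 (llama.cpp --flash-attn off):
the pure-math path (matmul + softmax) gives correct results on any PyTorch backend.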
Pros:
Numerically correct; does not depend on any hardware-specific attention kernel.
Peak memory = max(seq_len)² × H × dtype_size (controlled by --max-model-len).
Cons:
Prefill attention for concurrent requests is executed serially.
O(L²) memory (no O(L) optimization as in flash attention).
Memory reference (fp16, H_local=6):
max-model-len=4096: peak ~200 MB
max-model-len=8192: peak ~800 MB
max-model-len=16384: peak ~3.2 GB
Deploy:
python3 modified_scripts/patch_xformers_sdpa_seq.py
"""
# Location of the xformers attention backend source file that this script
# patches in place (passed to patch_file() by main()).
XFORMERS_PATH = "/".join((
    "/usr/local/corex/lib64/python3/dist-packages",
    "vllm/attention/backends/xformers.py",
))
# Source text of the `_run_sdpa_fallback` method. patch_file() inserts this
# text immediately before the first occurrence of INJECT_ANCHOR in the target
# file. Everything between the triple quotes is runtime data written verbatim
# into that file (including the embedded docstring and comments), so it must
# not be edited for cosmetic reasons.
# The method it defines computes causal attention one sequence at a time with
# matmul + softmax in float32, processing the query in chunks of _Q_CHUNK rows.
# NOTE(review): the literal is not a raw string, so the backslash-newline
# after `q_flat[...]` is consumed by the string literal and the two lines are
# joined in the injected text — confirm this is intended.
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "XFormersMetadata",
) -> torch.Tensor:
"""纯数学 causal attention fallback带 Q-tiling 内存优化。
调用时机kv_cache.numel()==0profiling 阶段
此路径无 KV 缓存前缀KV 长度 == query 长度
内存优化Q-tiling Flash Attention 同思路
Q 分成 _Q_CHUNK 大小的子块逐块计算每块峰值内存
O(_Q_CHUNK × q_len) 而非 O(q_len²)
profiling 阶段序列可能达到 max_model_len 20K tokens
不加 Q-tiling 会产生 9.6 GB 矩阵直接 OOM
softmax float32 下计算以防止 float16 溢出结果转回原始 dtype
Args:
query : [1, total_query_tokens, num_heads, head_dim]
key : [1, total_query_tokens, num_kv_heads, head_dim]
value : [1, total_query_tokens, num_kv_heads, head_dim]
Returns:
[1, total_query_tokens, num_heads, head_dim]
"""
_Q_CHUNK = 256 # 与 _forward_prefix_pytorch 的 _ATTN_Q_CHUNK 保持一致
assert attn_metadata.seq_lens is not None
orig_dtype = query.dtype
num_seqs = len(attn_metadata.seq_lens)
# 推导每条序列的实际 query 长度。
# 正常 prefill 时 q_len == seq_len如果将来遇到 chunked 场景,
# query_start_loc 记录的是真实 query token 数(非全序列长度)。
if (attn_metadata.query_start_loc is not None
and len(attn_metadata.query_start_loc) == num_seqs + 1):
q_lens = [
int(attn_metadata.query_start_loc[i + 1].item()) -
int(attn_metadata.query_start_loc[i].item())
for i in range(num_seqs)
]
else:
q_lens = list(attn_metadata.seq_lens)
q_flat = query.squeeze(0) # [T, H, D]
k_flat = key.squeeze(0) # [T, Hkv, D]
v_flat = value.squeeze(0)
output = torch.empty_like(q_flat)
seq_start = 0
for q_len in q_lens:
seq_end = seq_start + q_len
# 当前序列的完整 K/V此路径无前缀KV == Q
k_s = k_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
v_s = v_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
# GQA展开 KV heads 至与 query heads 一致
if k_s.shape[0] != self.num_heads:
n = self.num_heads // k_s.shape[0]
k_s = k_s.repeat_interleave(n, dim=0).contiguous()
v_s = v_s.repeat_interleave(n, dim=0).contiguous()
# k_pos 用于因果掩码
k_pos = torch.arange(q_len, device=query.device)
# Q-tiling分块处理 query峰值内存 O(_Q_CHUNK × q_len)
for qc_start in range(0, q_len, _Q_CHUNK):
qc_end = min(qc_start + _Q_CHUNK, q_len)
# [H, qc, D]
q_c = q_flat[seq_start + qc_start:seq_start + qc_end] \
.permute(1, 0, 2).float()
# [H, qc, q_len]
attn_w = torch.matmul(q_c, k_s.transpose(-2, -1)) * self.scale
# 因果掩码q_c 里位置 j 只能看 k_pos <= j相对位置
qc_q_pos = torch.arange(qc_start, qc_end, device=query.device)
mask = k_pos.unsqueeze(0) > qc_q_pos.unsqueeze(1)
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float("-inf"))
attn_w = torch.softmax(attn_w, dim=-1)
out_c = torch.matmul(attn_w, v_s).to(orig_dtype) # [H, qc, D]
output[seq_start + qc_start:seq_start + qc_end] = (
out_c.permute(1, 0, 2))
seq_start = seq_end
return output.unsqueeze(0) # [1, T, H, D]
'''
# Exact text of the unpatched dispatch block that patch_file() searches for in
# the target file; its first occurrence is replaced with NEW_XFORMER_BLOCK.
# The match is a plain substring test, so every character (whitespace
# included) must equal the target file's text.
# The backslash right after the opening quotes and the one before the closing
# quotes suppress the leading and trailing newline of the literal.
OLD_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op = self.attn_op
)
return out.view_as(original_query)\
"""
# Replacement text for OLD_XFORMER_BLOCK. It keeps the original xformers call
# but first checks `self.head_size > 128`; in that case the output comes from
# `self._run_sdpa_fallback(query, key, value, attn_metadata)` instead.
# patch_file() also uses the presence of this text in the target file to
# detect that the dispatch block was already patched.
# The backslash right after the opening quotes and the one before the closing
# quotes suppress the leading and trailing newline of the literal.
NEW_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if self.head_size > 128:
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
else:
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
return out.view_as(original_query)\
"""
# Text that marks the insertion point in the target file: patch_file() places
# FALLBACK_METHOD immediately before the first occurrence of this anchor.
# It is matched as a plain substring, leading whitespace included.
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
def patch_file(path):
    """Apply the SDPA-fallback patch to the source file at *path*, in place.

    Two idempotent edits are attempted, each reported with a one-line status:

    1. Insert ``FALLBACK_METHOD`` immediately before the first occurrence of
       ``INJECT_ANCHOR`` (skipped when the method definition already exists).
    2. Replace the first occurrence of ``OLD_XFORMER_BLOCK`` with
       ``NEW_XFORMER_BLOCK`` (skipped when already replaced). This edit is
       refused when the fallback method is absent from the content, because
       the new block calls ``self._run_sdpa_fallback``.

    The file is rewritten only if at least one edit changed the content.

    Args:
        path: Path of the Python source file to patch.

    Raises:
        OSError: If the file cannot be opened for reading or writing.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # The injected text contains non-ASCII characters, so the encoding is
    # pinned rather than left to the locale default.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    changed = False

    # Test for the method *definition*. The bare name also appears in the call
    # inside NEW_XFORMER_BLOCK, so matching on the name alone would wrongly
    # report the method as present in a file where only the dispatch block
    # had been patched.
    method_marker = "def _run_sdpa_fallback("
    if method_marker in content:
        print(" [skip] _run_sdpa_fallback already present")
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
    else:
        content = content.replace(
            INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
        print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)")
        changed = True

    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif OLD_XFORMER_BLOCK in content:
        if method_marker in content:
            content = content.replace(
                OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
            print(" [ok] patched dispatch block")
            changed = True
        else:
            # Redirecting the dispatch to a method that does not exist would
            # make the target fail at call time; leave the block untouched.
            print(" [warn] dispatch block not patched: "
                  "_run_sdpa_fallback is missing")
    else:
        print(" [warn] dispatch block anchor not found")

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
def main():
    """Print a banner, patch the file at ``XFORMERS_PATH``, then print "Done."."""
    target = XFORMERS_PATH
    header_lines = (
        "=== patch_xformers_sdpa_seq (sequential, pure-math) ===",
        f"Target: {target}",
    )
    for line in header_lines:
        print(line)
    patch_file(target)
    print("\nDone.")
# Run the patch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()