Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_xformers_sdpa_seq.py

215 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
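Strategy: sequential (per-sequence) fallback — a pure-PyTorch math implementation
==========================================================
Attention is hand-written with matmul + softmax, one sequence at a time,
completely bypassing every hardware flash-attention kernel
(ixformer / cudnnFlashAttnForward).

Background:
Iluvatar's cudnnFlashAttnForward has two known problems:
1. It does not support is_causal=True (it raises an error).
2. On the attn_mask path its numerical results are wrong (a silent error:
   the output is all "!").
The principle is the same as fixing the same class of problem with
llama.cpp --flash-attn off on Huawei Ascend 910B4.
The pure-math path (matmul + softmax) gives correct results on any PyTorch
backend.

Advantages:
- Numerically correct; does not depend on any hardware-specific attention
  kernel.
- Peak memory = max(seq_len)² × H × dtype_size, bounded by --max-model-len.

Disadvantages:
- Prefill attention for concurrent requests runs serially.
- O(L²) memory (no O(L) flash-attention optimization).

Memory reference (fp16, H_local=6):
  max-model-len=4096  → peak ~200 MB
  max-model-len=8192  → peak ~800 MB
  max-model-len=16384 → peak ~3.2 GB

Deploy:
  python3 modified_scripts/patch_xformers_sdpa_seq.py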
"""
# Path of the vLLM xformers attention backend source file that main() hands
# to patch_file(); the file is read and, if anything changed, rewritten in
# place.
XFORMERS_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/attention/backends/xformers.py"
)
# Source text of the `_run_sdpa_fallback` method. patch_file() inserts this
# string verbatim directly before the first occurrence of INJECT_ANCHOR in
# the target file, so its indentation must match the target class body.
#
# The injected method computes causal self-attention one sequence at a time
# using plain matmul + softmax: per-sequence query lengths come from
# query_start_loc when it has num_seqs + 1 entries (otherwise from seq_lens),
# KV heads are repeated to match self.num_heads, queries are processed in
# chunks of 256 rows, the math runs in float32, and the result is cast back
# to the query dtype.
#
# NOTE: this is not a raw string, so the backslash-newline after the q_flat
# slice is consumed when the literal is parsed; the two lines are joined in
# the injected text.
# NOTE(review): the embedded docstring says the method is invoked when
# kv_cache.numel()==0 (profiling), whereas NEW_XFORMER_BLOCK dispatches to it
# whenever self.head_size > 128 — confirm which description is intended.
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "XFormersMetadata",
) -> torch.Tensor:
"""纯数学 causal attention fallback带 Q-tiling 内存优化。
调用时机kv_cache.numel()==0profiling 阶段)。
此路径无 KV 缓存前缀KV 长度 == query 长度。
内存优化Q-tiling与 Flash Attention 同思路):
将 Q 分成 _Q_CHUNK 大小的子块逐块计算,每块峰值内存
O(_Q_CHUNK × q_len) 而非 O(q_len²)。
profiling 阶段序列可能达到 max_model_len如 20K tokens
不加 Q-tiling 会产生 9.6 GB 矩阵直接 OOM。
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
Args:
query : [1, total_query_tokens, num_heads, head_dim]
key : [1, total_query_tokens, num_kv_heads, head_dim]
value : [1, total_query_tokens, num_kv_heads, head_dim]
Returns:
[1, total_query_tokens, num_heads, head_dim]
"""
_Q_CHUNK = 256 # 与 _forward_prefix_pytorch 的 _ATTN_Q_CHUNK 保持一致
assert attn_metadata.seq_lens is not None
orig_dtype = query.dtype
num_seqs = len(attn_metadata.seq_lens)
# 推导每条序列的实际 query 长度。
# 正常 prefill 时 q_len == seq_len如果将来遇到 chunked 场景,
# query_start_loc 记录的是真实 query token 数(非全序列长度)。
if (attn_metadata.query_start_loc is not None
and len(attn_metadata.query_start_loc) == num_seqs + 1):
q_lens = [
int(attn_metadata.query_start_loc[i + 1].item()) -
int(attn_metadata.query_start_loc[i].item())
for i in range(num_seqs)
]
else:
q_lens = list(attn_metadata.seq_lens)
q_flat = query.squeeze(0) # [T, H, D]
k_flat = key.squeeze(0) # [T, Hkv, D]
v_flat = value.squeeze(0)
output = torch.empty_like(q_flat)
seq_start = 0
for q_len in q_lens:
seq_end = seq_start + q_len
# 当前序列的完整 K/V此路径无前缀KV == Q
k_s = k_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
v_s = v_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
# GQA展开 KV heads 至与 query heads 一致
if k_s.shape[0] != self.num_heads:
n = self.num_heads // k_s.shape[0]
k_s = k_s.repeat_interleave(n, dim=0).contiguous()
v_s = v_s.repeat_interleave(n, dim=0).contiguous()
# k_pos 用于因果掩码
k_pos = torch.arange(q_len, device=query.device)
# Q-tiling分块处理 query峰值内存 O(_Q_CHUNK × q_len)
for qc_start in range(0, q_len, _Q_CHUNK):
qc_end = min(qc_start + _Q_CHUNK, q_len)
# [H, qc, D]
q_c = q_flat[seq_start + qc_start:seq_start + qc_end] \
.permute(1, 0, 2).float()
# [H, qc, q_len]
attn_w = torch.matmul(q_c, k_s.transpose(-2, -1)) * self.scale
# 因果掩码q_c 里位置 j 只能看 k_pos <= j相对位置
qc_q_pos = torch.arange(qc_start, qc_end, device=query.device)
mask = k_pos.unsqueeze(0) > qc_q_pos.unsqueeze(1)
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float("-inf"))
attn_w = torch.softmax(attn_w, dim=-1)
out_c = torch.matmul(attn_w, v_s).to(orig_dtype) # [H, qc, D]
output[seq_start + qc_start:seq_start + qc_end] = (
out_c.permute(1, 0, 2))
seq_start = seq_end
return output.unsqueeze(0) # [1, T, H, D]
'''
# Exact text of the unpatched dispatch code that patch_file() searches for in
# the target file; its first occurrence is replaced by NEW_XFORMER_BLOCK.
# The match is a plain substring test, so this text (whitespace included)
# must equal the target's text exactly, otherwise patch_file() only prints a
# warning. The backslashes after the opening quotes and at the end of the
# last line keep a leading/trailing newline out of the string.
OLD_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op = self.attn_op
)
return out.view_as(original_query)\
"""
# Replacement text for OLD_XFORMER_BLOCK. It keeps the original
# xops.memory_efficient_attention_forward call but routes to
# self._run_sdpa_fallback(query, key, value, attn_metadata) when
# self.head_size > 128. patch_file() also uses the presence of this exact
# text in the target file to detect that the dispatch is already patched.
NEW_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if self.head_size > 128:
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
else:
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
return out.view_as(original_query)\
"""
# Insertion point for the fallback method: patch_file() places
# FALLBACK_METHOD directly before the first occurrence of this text in the
# target file (the leading whitespace is part of the match).
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
def patch_file(path):
    """Patch the xformers backend source file at ``path`` in place.

    Two idempotent steps are applied to the file's text:

    1. Insert ``FALLBACK_METHOD`` directly before the first occurrence of
       ``INJECT_ANCHOR``, unless a ``_run_sdpa_fallback`` definition is
       already present.
    2. Replace the first occurrence of ``OLD_XFORMER_BLOCK`` with
       ``NEW_XFORMER_BLOCK``, unless the new block is already present.

    Step 2 is refused when the file neither contained nor just received the
    fallback method: the new dispatch block calls ``self._run_sdpa_fallback``,
    so patching it alone would leave the target raising ``AttributeError``
    at runtime.

    The file is rewritten only when at least one step changed its content.
    Progress is reported on stdout.

    Args:
        path: Path of the Python source file to patch.

    Raises:
        OSError: If the file cannot be opened for reading or writing.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit UTF-8: FALLBACK_METHOD contains non-ASCII text, so relying on
    # the locale's default encoding could fail or corrupt the file on write.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    changed = False

    # Look for the definition rather than the bare name: the patched dispatch
    # block also mentions `_run_sdpa_fallback`, which must not be mistaken
    # for the method itself.
    has_fallback = "def _run_sdpa_fallback(" in content
    if has_fallback:
        print(" [skip] _run_sdpa_fallback already present")
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
    else:
        content = content.replace(
            INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
        print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)")
        has_fallback = True
        changed = True

    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif OLD_XFORMER_BLOCK not in content:
        print(" [warn] dispatch block anchor not found")
    elif not has_fallback:
        # Never route calls to a method that does not exist in the file.
        print(" [warn] dispatch block left unpatched: "
              "_run_sdpa_fallback is missing")
    else:
        content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
        print(" [ok] patched dispatch block")
        changed = True

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
def main():
    """Script entry point: print a banner, patch the target file, report done."""
    banner = "=== patch_xformers_sdpa_seq (sequential, pure-math) ==="
    target = XFORMERS_PATH

    print(banner)
    print(f"Target: {target}")

    # All patching (and its per-step status output) happens in patch_file().
    patch_file(target)

    print("\nDone.")
# Run the patch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()