Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py

192 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略：批量 block-diagonal — F.scaled_dot_product_attention（可走硬件 kernel）
=============================================================================
构建块对角 causal mask，对整批序列只调用一次 F.scaled_dot_product_attention。

与 patch_xformers_sdpa_batch.py（纯 matmul）的区别：
    SDPA 会根据 PyTorch/驱动能力分发到最优 kernel（Flash Attention /
    mem-efficient attention / math fallback），而不是固定走 cuBLAS matmul。

历史说明：
    该方案最早因输出全 "!" 而被弃用；后续排查确认 "!" 由 mamba_cache.py 的 bug
    引起，与 attention 实现无关。当前恢复此方案用于性能对比测试。

已知硬件限制（BI-V100）：
    cudnnFlashAttnForward 不支持 is_causal=True（调用会报错）。
    本实现使用 is_causal=False + 显式块对角 additive mask 规避此限制。
    若 SDPA 仍分发到有问题的 kernel，请回退到 patch_xformers_sdpa_batch.py。

优点（vs 纯 matmul）：
    SDPA 可分发到 Flash Attention kernel → O(L) 显存、更快的 CUDA kernel。
    注意：显式 [T, T] mask 本身仍占 O(T²) 显存；且在上游 PyTorch 中，传入任意
    attn_mask 时通常无法使用 Flash 后端（多走 mem-efficient 或 math 路径），
    实际分发结果需在本硬件上确认。

缺点：
    依赖硬件 kernel 行为，若 kernel 有 bug 则数值错误（需与 matmul 版对比验证）。

Deploy:
    python3 modified_scripts/patch_xformers_sdpa_batch_kernel.py
"""
# Installed vLLM xFormers attention backend inside the corex Python tree;
# this is the file main() hands to patch_file() for in-place patching.
_COREX_DIST_PACKAGES = "/usr/local/corex/lib64/python3/dist-packages"
XFORMERS_PATH = _COREX_DIST_PACKAGES + "/vllm/attention/backends/xformers.py"
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "XFormersMetadata",
) -> torch.Tensor:
"""批量 F.scaled_dot_product_attention fallback可走硬件 kernel
构建块对角 causal mask对整批序列一次 SDPA 调用。
SDPA 可分发到 Flash Attention / mem-efficient attention kernel。
is_causal=False + 显式 additive mask规避 cudnnFlashAttnForward
不支持 is_causal=True 的限制。
块对角 maskseq1 len=3seq2 len=2
s1,0 s1,1 s1,2 s2,0 s2,1
s1,0 [ 0 -inf -inf -inf -inf ]
s1,1 [ 0 0 -inf -inf -inf ]
s1,2 [ 0 0 0 -inf -inf ]
s2,0 [-inf -inf -inf 0 -inf ]
s2,1 [-inf -inf -inf 0 0 ]
Args:
query : [1, total_prefill_tokens, num_heads, head_dim]
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
Returns:
[1, total_prefill_tokens, num_heads, head_dim]
"""
import torch.nn.functional as F
assert attn_metadata.seq_lens is not None
orig_dtype = query.dtype
total_tokens = query.shape[1]
# ── 块对角 causal mask [T, T] ─────────────────────────────────────
mask = torch.full(
(total_tokens, total_tokens),
float("-inf"),
dtype=orig_dtype,
device=query.device,
)
start = 0
for seq_len in attn_metadata.seq_lens:
end = start + seq_len
mask[start:end, start:end] = torch.tril(
torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=query.device)
)
start = end
# ── [1, H, T, D] ──────────────────────────────────────────────────
q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
# ── GQA展开 KV heads ────────────────────────────────────────────
if k_all.shape[1] != q_all.shape[1]:
n = q_all.shape[1] // k_all.shape[1]
k_all = k_all.repeat_interleave(n, dim=1).contiguous()
v_all = v_all.repeat_interleave(n, dim=1).contiguous()
# ── F.scaled_dot_product_attention可走硬件 kernel─────────────
# is_causal=False避免 cudnnFlashAttnForward "not support causal mode"
# attn_mask 传 additive float mask非 boolSDPA 选择 math/kernel 路径
out = F.scaled_dot_product_attention(
q_all, k_all, v_all,
attn_mask=mask,
dropout_p=0.0,
is_causal=False,
scale=self.scale,
)
# [1, H, T, D] → [1, T, H, D]
return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
'''
# Prefill dispatch as shipped in the corex xformers.py backend, and its
# replacement. The leading whitespace must mirror the target source: it
# assumes vLLM's 4-space layout with this code at 8 spaces inside
# _run_memory_efficient_xformers_forward (TODO confirm against the installed
# corex xformers.py). The NEW block keeps the first six lines and the final
# return of the OLD block unchanged, and routes head_size > 128 to
# _run_sdpa_fallback; presumably the xFormers flash FwOp on this stack cannot
# handle larger heads (NOTE(review): confirm the exact limit).
OLD_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale,
                op = self.attn_op
            )
            return out.view_as(original_query)\
"""
NEW_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            if self.head_size > 128:
                out = self._run_sdpa_fallback(query, key, value, attn_metadata)
            else:
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=attn_bias[0],
                    p=0.0,
                    scale=self.scale,
                    op=self.attn_op,
                )
            return out.view_as(original_query)\
"""
# Method header that FALLBACK_METHOD is inserted in front of. It carries the
# full 4-space method indentation. A shorter prefix (e.g. a single space) would
# match part-way into the indent, and the anchor line would be left
# under-indented after insertion.
INJECT_ANCHOR = "    def _run_memory_efficient_xformers_forward("
def patch_file(path):
    """Apply both SDPA-fallback edits to the xFormers backend file at *path*.

    1. Insert FALLBACK_METHOD directly above the method header named by
       INJECT_ANCHOR.
    2. Replace OLD_XFORMER_BLOCK with NEW_XFORMER_BLOCK, so the dispatch calls
       the injected method when head_size > 128.

    Lines are compared with leading/trailing whitespace stripped. Inserted
    code is re-indented to the indentation found in the target, so matching
    does not depend on the exact whitespace stored in the constants. Each step
    is skipped when already applied. The file is written only if both edits
    are present afterwards; a dispatch that calls a missing
    ``_run_sdpa_fallback`` would fail at runtime.

    Args:
        path: path of the xformers.py file to patch in place.

    Returns:
        True if the file is fully patched (now or previously), else False.
    """
    import textwrap

    def block_lines(block):
        # Constants are triple-quoted; drop surrounding blank lines.
        return block.strip("\n").splitlines()

    def leading_ws(line):
        return line[: len(line) - len(line.lstrip())]

    def find_block(lines, block):
        # Index where `block` starts in `lines`, ignoring per-line
        # leading/trailing whitespace; -1 if it is not present.
        wanted = [ln.strip() for ln in block_lines(block)]
        have = [ln.strip() for ln in lines]
        for i in range(len(have) - len(wanted) + 1):
            if have[i:i + len(wanted)] == wanted:
                return i
        return -1

    def reindent(block, indent):
        # Shift `block` so its least-indented lines start at `indent`.
        body = textwrap.dedent("\n".join(block_lines(block)))
        return textwrap.indent(body, indent).splitlines()

    # The payload contains non-ASCII text; do not rely on the locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    lines = content.splitlines()
    changed = False

    # Step 1: inject the fallback method. Check for the definition itself;
    # the patched dispatch also mentions the name.
    has_method = any(
        ln.lstrip().startswith("def _run_sdpa_fallback(") for ln in lines
    )
    if has_method:
        print(" [skip] _run_sdpa_fallback already present")
    else:
        anchor_text = INJECT_ANCHOR.strip()
        anchor = next(
            (i for i, ln in enumerate(lines) if ln.lstrip().startswith(anchor_text)),
            -1,
        )
        if anchor < 0:
            print(" [warn] inject anchor not found")
        else:
            method = reindent(FALLBACK_METHOD, leading_ws(lines[anchor]))
            lines[anchor:anchor] = method + [""]
            print(" [ok] injected _run_sdpa_fallback (batch, F.sdpa kernel)")
            has_method = changed = True

    # Step 2: rewrite the prefill dispatch block.
    if find_block(lines, NEW_XFORMER_BLOCK) >= 0:
        print(" [skip] dispatch block already patched")
        has_dispatch = True
    else:
        start = find_block(lines, OLD_XFORMER_BLOCK)
        if start < 0:
            print(" [warn] dispatch block anchor not found")
            has_dispatch = False
        else:
            count = len(block_lines(OLD_XFORMER_BLOCK))
            lines[start:start + count] = reindent(
                NEW_XFORMER_BLOCK, leading_ws(lines[start])
            )
            print(" [ok] patched dispatch block")
            has_dispatch = changed = True

    complete = has_method and has_dispatch
    if not changed:
        return complete
    if not complete:
        # Never write a half-patched backend.
        print(f" [error] incomplete patch, leaving {path} unchanged")
        return False

    new_content = "\n".join(lines)
    if content.endswith("\n"):
        new_content += "\n"
    with open(path, "w", encoding="utf-8") as f:
        f.write(new_content)
    print(f" Written: {path}")
    return True
def main():
    """Patch XFORMERS_PATH; exit with status 1 if patch_file reports failure."""
    import sys

    print("=== patch_xformers_sdpa_batch_kernel (batch, F.sdpa + kernel dispatch) ===")
    print(f"Target: {XFORMERS_PATH}")
    ok = patch_file(XFORMERS_PATH)
    # Only an explicit False counts as failure. Any other result (e.g. None
    # from a patch_file that reports nothing) is treated as success.
    if ok is False:
        print("\nPatch incomplete; see warnings above.")
        sys.exit(1)
    print("\nDone.")


if __name__ == "__main__":
    main()