192 lines
7.2 KiB
Python
192 lines
7.2 KiB
Python
|
|
"""
|
|||
|
|
策略:批量(block-diagonal)— F.scaled_dot_product_attention,可走硬件 kernel
|
|||
|
|
=============================================================================
|
|||
|
|
构建块对角 causal mask,对整批序列一次 F.scaled_dot_product_attention。
|
|||
|
|
与 patch_xformers_sdpa_batch.py(纯 matmul)的区别:
|
|||
|
|
SDPA 会根据 PyTorch/驱动能力分发到最优 kernel(Flash Attention /
|
|||
|
|
mem-efficient attention / math fallback),而不是固定走 cublas matmul。
|
|||
|
|
|
|||
|
|
历史说明:
|
|||
|
|
该方案最早因输出全"!"而被弃用,后续排查确认"!"由 mamba_cache.py bug
|
|||
|
|
引起,与 attention 实现无关。当前恢复此方案用于性能对比测试。
|
|||
|
|
|
|||
|
|
已知硬件限制(BI-V100):
|
|||
|
|
cudnnFlashAttnForward 不支持 is_causal=True(报错)。
|
|||
|
|
本实现使用 is_causal=False + 显式块对角 additive mask 规避此限制。
|
|||
|
|
若 SDPA 仍分发到有问题的 kernel,回退到 patch_xformers_sdpa_batch.py。
|
|||
|
|
|
|||
|
|
优点(vs 纯 matmul):
|
|||
|
|
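  SDPA may dispatch to a fused kernel (Flash Attention / mem-efficient attention), which is faster than plain matmul.
  Caveat: the explicit [T, T] additive mask built by this patch is itself O(T^2) memory, and stock PyTorch does not select the Flash backend when attn_mask is given (TODO: confirm the behaviour of the corex build).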
|
|||
|
|
|
|||
|
|
缺点:
|
|||
|
|
依赖硬件 kernel 行为,若 kernel 有 bug 则数值错误(需与 matmul 版对比验证)。
|
|||
|
|
|
|||
|
|
Deploy:
|
|||
|
|
python3 modified_scripts/patch_xformers_sdpa_batch_kernel.py
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
# Absolute path of the installed vLLM xformers attention backend; main() hands
# this file to patch_file(), which rewrites it in place.
XFORMERS_PATH = (
    "/usr/local/corex/lib64/python3/dist-packages"
    "/vllm/attention/backends/xformers.py"
)
|
|||
|
|
|
|||
|
|
FALLBACK_METHOD = '''
|
|||
|
|
def _run_sdpa_fallback(
|
|||
|
|
self,
|
|||
|
|
query: torch.Tensor,
|
|||
|
|
key: torch.Tensor,
|
|||
|
|
value: torch.Tensor,
|
|||
|
|
attn_metadata: "XFormersMetadata",
|
|||
|
|
) -> torch.Tensor:
|
|||
|
|
"""批量 F.scaled_dot_product_attention fallback(可走硬件 kernel)。
|
|||
|
|
|
|||
|
|
构建块对角 causal mask,对整批序列一次 SDPA 调用。
|
|||
|
|
SDPA 可分发到 Flash Attention / mem-efficient attention kernel。
|
|||
|
|
is_causal=False + 显式 additive mask,规避 cudnnFlashAttnForward
|
|||
|
|
不支持 is_causal=True 的限制。
|
|||
|
|
|
|||
|
|
块对角 mask(seq1 len=3,seq2 len=2):
|
|||
|
|
s1,0 s1,1 s1,2 s2,0 s2,1
|
|||
|
|
s1,0 [ 0 -inf -inf -inf -inf ]
|
|||
|
|
s1,1 [ 0 0 -inf -inf -inf ]
|
|||
|
|
s1,2 [ 0 0 0 -inf -inf ]
|
|||
|
|
s2,0 [-inf -inf -inf 0 -inf ]
|
|||
|
|
s2,1 [-inf -inf -inf 0 0 ]
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
query : [1, total_prefill_tokens, num_heads, head_dim]
|
|||
|
|
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
|||
|
|
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
|||
|
|
Returns:
|
|||
|
|
[1, total_prefill_tokens, num_heads, head_dim]
|
|||
|
|
"""
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
|
|||
|
|
assert attn_metadata.seq_lens is not None
|
|||
|
|
orig_dtype = query.dtype
|
|||
|
|
total_tokens = query.shape[1]
|
|||
|
|
|
|||
|
|
# ── 块对角 causal mask [T, T] ─────────────────────────────────────
|
|||
|
|
mask = torch.full(
|
|||
|
|
(total_tokens, total_tokens),
|
|||
|
|
float("-inf"),
|
|||
|
|
dtype=orig_dtype,
|
|||
|
|
device=query.device,
|
|||
|
|
)
|
|||
|
|
start = 0
|
|||
|
|
for seq_len in attn_metadata.seq_lens:
|
|||
|
|
end = start + seq_len
|
|||
|
|
mask[start:end, start:end] = torch.tril(
|
|||
|
|
torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=query.device)
|
|||
|
|
)
|
|||
|
|
start = end
|
|||
|
|
|
|||
|
|
# ── [1, H, T, D] ──────────────────────────────────────────────────
|
|||
|
|
q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
|||
|
|
k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
|||
|
|
v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
|||
|
|
|
|||
|
|
# ── GQA:展开 KV heads ────────────────────────────────────────────
|
|||
|
|
if k_all.shape[1] != q_all.shape[1]:
|
|||
|
|
n = q_all.shape[1] // k_all.shape[1]
|
|||
|
|
k_all = k_all.repeat_interleave(n, dim=1).contiguous()
|
|||
|
|
v_all = v_all.repeat_interleave(n, dim=1).contiguous()
|
|||
|
|
|
|||
|
|
# ── F.scaled_dot_product_attention(可走硬件 kernel)─────────────
|
|||
|
|
# is_causal=False:避免 cudnnFlashAttnForward "not support causal mode"
|
|||
|
|
# attn_mask 传 additive float mask(非 bool),SDPA 选择 math/kernel 路径
|
|||
|
|
out = F.scaled_dot_product_attention(
|
|||
|
|
q_all, k_all, v_all,
|
|||
|
|
attn_mask=mask,
|
|||
|
|
dropout_p=0.0,
|
|||
|
|
is_causal=False,
|
|||
|
|
scale=self.scale,
|
|||
|
|
)
|
|||
|
|
# [1, H, T, D] → [1, T, H, D]
|
|||
|
|
return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
|||
|
|
|
|||
|
|
'''
|
|||
|
|
|
|||
|
|
# Exact text of the unpatched dispatch code that patch_file() looks for in the
# target file and swaps for NEW_XFORMER_BLOCK. The match is a plain substring
# test, so every space and newline here has to mirror the target; the trailing
# backslash keeps the literal from ending in a newline.
# NOTE(review): indentation is written for a method body nested one level
# inside a class (8 spaces) — confirm against the installed xformers.py,
# including the column of the closing parenthesis, before deploying.
OLD_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale,
                op = self.attn_op
            )
            return out.view_as(original_query)\
"""
|
|||
|
|
|
|||
|
|
# Replacement for OLD_XFORMER_BLOCK. It keeps the xformers call for the
# no-ALiBi path but routes head sizes above 128 to self._run_sdpa_fallback.
# patch_file() also uses this text to detect an already-patched file, so it
# must be written back exactly as it appears here; the trailing backslash
# keeps the literal from ending in a newline.
# NOTE(review): base indentation (8 spaces) must equal that of the code it
# replaces in the installed xformers.py — confirm before deploying.
NEW_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            if self.head_size > 128:
                out = self._run_sdpa_fallback(query, key, value, attn_metadata)
            else:
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=attn_bias[0],
                    p=0.0,
                    scale=self.scale,
                    op=self.attn_op,
                )
            return out.view_as(original_query)\
"""
|
|||
|
|
|
|||
|
|
# Start of the method definition in front of which patch_file() inserts
# FALLBACK_METHOD. The anchor carries the full four-space method indentation:
# a shorter prefix would still match, but the inserted text would then land in
# the middle of the indentation and leave the original `def` line mis-indented.
INJECT_ANCHOR = "    def _run_memory_efficient_xformers_forward("
|
|||
|
|
|
|||
|
|
|
|||
|
|
def patch_file(path):
    """Apply the SDPA-fallback patch to the file at *path*, in place.

    Two independent, idempotent steps are attempted:

    1. Insert FALLBACK_METHOD in front of the first INJECT_ANCHOR, unless a
       ``def _run_sdpa_fallback(`` is already present.
    2. Replace the first OLD_XFORMER_BLOCK with NEW_XFORMER_BLOCK, unless the
       new block is already present.

    Step 2 is refused when the method is neither present nor injected by
    step 1, so the file is never left calling a method that does not exist.
    The file is rewritten only if at least one step changed the text. Progress
    is reported on stdout.

    Args:
        path: Path of the Python source file to patch.

    Returns:
        None.

    Raises:
        OSError: If the file cannot be read or written.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Look for the definition, not the bare name: NEW_XFORMER_BLOCK contains
    # the name at its call site, so a name-only test would report the method
    # as present on a file where only the dispatch block had been patched.
    method_def = "def _run_sdpa_fallback("

    # The injected text contains non-ASCII characters; pin the encoding so the
    # result does not depend on the locale of the machine running the script.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    changed = False

    # Step 1: make sure the fallback method exists.
    if method_def in content:
        print(" [skip] _run_sdpa_fallback already present")
        have_method = True
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
        have_method = False
    else:
        content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
        print(" [ok] injected _run_sdpa_fallback (batch, F.sdpa kernel)")
        changed = True
        have_method = True

    # Step 2: route the dispatch code through the fallback.
    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif OLD_XFORMER_BLOCK not in content:
        print(" [warn] dispatch block anchor not found")
    elif not have_method:
        # Patching now would make the target call an undefined method.
        print(" [warn] dispatch block not patched: _run_sdpa_fallback is not defined")
    else:
        content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
        print(" [ok] patched dispatch block")
        changed = True

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Script entry point: announce the target, patch it, report completion.

    Prints a banner and the path held in XFORMERS_PATH, hands that path to
    patch_file(), then prints a closing line. Takes no arguments and returns
    None.
    """
    banner = "=== patch_xformers_sdpa_batch_kernel (batch, F.sdpa + kernel dispatch) ==="
    print(banner)
    print("Target:", XFORMERS_PATH)
    patch_file(XFORMERS_PATH)
    print("\nDone.")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the patch only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|