Files
enginex-vllm-bi100-qwen36/qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py

182 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
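Strategy: sequential (per-sequence) F.scaled_dot_product_attention (can use a hardware kernel)
=============================================================================
Calls F.scaled_dot_product_attention once per sequence (is_causal=False + explicit causal mask).

Difference from patch_xformers_sdpa_seq.py (pure matmul):
    SDPA can dispatch to Flash Attention / mem-efficient attention kernels,
    whereas pure matmul always goes through cublas.

Hardware limitation (BI-V100):
    cudnnFlashAttnForward does not support is_causal=True (it raises an error),
    so is_causal=False with an explicit additive causal mask must be used.
    An upper-triangular -inf mask is built separately for each sequence;
    peak memory = max(seq_len)² × dtype size,
    far smaller than the batch version's total_tokens².

Comparison with batch_kernel:
    seq_kernel:   small memory (peak = max_single_seq²); concurrent prefills are queued and run serially
    batch_kernel: large memory (peak = total_tokens²); concurrent prefills are processed in one parallel pass,
                  with the total_tokens upper bound controlled via --max-num-batched-tokens

Deploy:
    python3 modified_scripts/patch_xformers_sdpa_seq_kernel.py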
"""
# Absolute path of the vLLM xformers attention backend source file that
# main() hands to patch_file() for in-place patching.
XFORMERS_PATH = (
"/usr/local/corex/lib64/python3/dist-packages/"
"vllm/attention/backends/xformers.py"
)
# Source text of the `_run_sdpa_fallback` method that patch_file() inserts
# immediately before the first occurrence of INJECT_ANCHOR in the target file.
#
# What the injected method does, per the text below:
#   * squeezes the leading batch dimension off query/key/value;
#   * walks attn_metadata.seq_lens, slicing one sequence at a time;
#   * permutes each slice to [1, heads, len, head_dim] and, when the key head
#     count differs from the query head count, expands K/V heads with
#     repeat_interleave;
#   * builds a per-sequence additive mask (zeros, with -inf strictly above the
#     diagonal) and calls F.scaled_dot_product_attention with is_causal=False
#     and scale=self.scale;
#   * writes each result back into a preallocated output and returns it with
#     the batch dimension restored.
#
# NOTE(review): this literal is injected verbatim, so its leading whitespace
# must match the method indentation of the class in the target file --
# confirm against the target before deploying.
# NOTE(review): torch.tril() over an all-zeros tensor is a no-op; the mask is
# fully determined by the masked_fill() that follows.
FALLBACK_METHOD = '''
def _run_sdpa_fallback(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: "XFormersMetadata",
) -> torch.Tensor:
"""顺序 F.scaled_dot_product_attention fallback可走硬件 kernel
逐条序列调用 SDPAis_causal=False + 显式上三角 additive mask。
cudnnFlashAttnForward 不支持 is_causal=True必须用显式 mask。
逐序列构造 maskpeak 显存 = max(seq_len)² × dtype远小于 batch 版)。
Args:
query : [1, total_prefill_tokens, num_heads, head_dim]
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
Returns:
[1, total_prefill_tokens, num_heads, head_dim]
"""
import torch.nn.functional as F
assert attn_metadata.seq_lens is not None
orig_dtype = query.dtype
q_flat = query.squeeze(0) # [T, H, D]
k_flat = key.squeeze(0) # [T, Hkv, D]
v_flat = value.squeeze(0)
output = torch.empty_like(q_flat)
start = 0
for seq_len in attn_metadata.seq_lens:
end = start + seq_len
# [1, H, L, D]
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
# GQA展开 KV heads
if k_s.shape[1] != q_s.shape[1]:
n = q_s.shape[1] // k_s.shape[1]
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
# 逐序列因果 mask [L, L],上三角 -inf
causal_mask = torch.tril(
torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=q_s.device)
)
causal_mask = causal_mask.masked_fill(
torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
device=q_s.device), diagonal=1),
float("-inf"),
)
# is_causal=False + 显式 mask规避 cudnnFlashAttnForward 不支持 is_causal=True
out_s = F.scaled_dot_product_attention(
q_s, k_s, v_s,
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=False,
scale=self.scale,
)
# [1, H, L, D] → [L, H, D]
output[start:end] = out_s.squeeze(0).permute(1, 0, 2).to(orig_dtype)
start = end
return output.unsqueeze(0) # [1, T, H, D]
'''
# Exact text of the attention dispatch code that patch_file() looks for in the
# unpatched target; its first occurrence is replaced by NEW_XFORMER_BLOCK.
# The lookup is a plain substring match, so every character here (including
# the `op = self.attn_op` spacing and the absence of a trailing comma) has to
# mirror the target file. The backslashes right after the opening quotes and
# right before the closing quotes keep the literal free of a leading and a
# trailing newline.
# NOTE(review): leading whitespace of each line must equal the target file's
# indentation for the match to succeed -- confirm against the target.
OLD_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op = self.attn_op
)
return out.view_as(original_query)\
"""
# Replacement text for OLD_XFORMER_BLOCK. It keeps the original
# xops.memory_efficient_attention_forward call for the normal case and routes
# to self._run_sdpa_fallback(...) when self.head_size > 128.
# patch_file() also uses the presence of this exact text in the target as the
# "dispatch already patched" check.
# NOTE(review): leading whitespace of each line is written into the target
# verbatim -- confirm it matches the surrounding indentation there.
NEW_XFORMER_BLOCK = """\
self.attn_op = xops.fmha.flash.FwOp()
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if self.head_size > 128:
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
else:
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
return out.view_as(original_query)\
"""
# Text marker in the target file; patch_file() inserts FALLBACK_METHOD directly
# in front of its first occurrence (exact substring match, whitespace included).
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
def patch_file(path):
    """Patch the xformers attention backend source at *path* in place.

    Two idempotent edits are attempted, in this order:

    1. Insert ``FALLBACK_METHOD`` in front of the first ``INJECT_ANCHOR``,
       unless a ``_run_sdpa_fallback`` definition is already in the file.
    2. Replace the first ``OLD_XFORMER_BLOCK`` with ``NEW_XFORMER_BLOCK``,
       unless ``NEW_XFORMER_BLOCK`` is already in the file.

    The dispatch edit (2) is applied only if the fallback method is defined
    in the resulting text. Without that guard a missing injection anchor
    would leave the target calling ``self._run_sdpa_fallback`` with no such
    method defined.

    The file is read and written explicitly as UTF-8: the injected method
    text contains non-ASCII characters, so relying on the locale's default
    encoding could fail or corrupt the file.

    Args:
        path: Filesystem path of the source file to patch.

    Returns:
        ``True`` if the file was rewritten, ``False`` if nothing was changed.

    Raises:
        OSError: If the file cannot be read or written.
    """
    # Match the definition, not the bare name: the bare name also occurs in
    # the patched dispatch call and would hide a missing method.
    method_marker = "def _run_sdpa_fallback"

    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    changed = False

    # Step 1: make sure the fallback method exists.
    if method_marker in content:
        print(" [skip] _run_sdpa_fallback already present")
    elif INJECT_ANCHOR not in content:
        print(" [warn] inject anchor not found")
    else:
        content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
        print(" [ok] injected _run_sdpa_fallback (seq, F.sdpa kernel)")
        changed = True

    # Step 2: route large-head attention to the fallback method.
    if NEW_XFORMER_BLOCK in content:
        print(" [skip] dispatch block already patched")
    elif method_marker not in content:
        # Never emit a call to a method that is not defined in the file.
        print(" [warn] dispatch block not patched: _run_sdpa_fallback is missing")
    elif OLD_XFORMER_BLOCK in content:
        content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
        print(" [ok] patched dispatch block")
        changed = True
    else:
        print(" [warn] dispatch block anchor not found")

    if changed:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f" Written: {path}")
    return changed
def main():
    """Script entry point: announce the target, patch it, report completion."""
    print("=== patch_xformers_sdpa_seq_kernel (seq, F.sdpa + kernel dispatch) ===")
    target = XFORMERS_PATH
    print(f"Target: {target}")
    patch_file(target)
    print("\nDone.")


if __name__ == "__main__":
    # Apply the patch only when run as a script, not when imported.
    main()