"""
Strategy: sequential (per-sequence) fallback - pure PyTorch math implementation
===============================================================================

Attention is hand-written with matmul + softmax, one sequence at a time,
completely bypassing every hardware flash attention kernel
(ixformer / cudnnFlashAttnForward).

Background:
  Iluvatar cudnnFlashAttnForward has two known problems:
    1. is_causal=True is not supported (raises an error)
    2. the attn_mask path yields numerically wrong results
       (silent error, the output is all "!")
  This follows the same principle as fixing the same class of problem on
  Huawei Ascend 910B4 with llama.cpp --flash-attn off.
  The pure-math path (matmul + softmax) is correct on any PyTorch backend.

Pros:
  Numerically correct, no dependency on any hardware-specific attention
  kernel.
  Peak memory = max(seq_len)^2 x H x dtype_size, bounded by --max-model-len.

Cons:
  Prefill attention of concurrent requests runs serially.
  O(L^2) memory (no O(L) flash-attention optimisation).

Memory reference (fp16, H_local=6):
  max-model-len=4096   -> peak ~200 MB
  max-model-len=8192   -> peak ~800 MB
  max-model-len=16384  -> peak ~3.2 GB
  NOTE(review): the injected method upcasts the score tensor to float32
  before the softmax, so the real peak is likely higher than these fp16
  figures - confirm by measurement.

Deploy:
  python3 modified_scripts/patch_xformers_sdpa_seq.py
"""

import re
import sys
import textwrap

# Installed vLLM xformers attention backend that gets patched in place.
XFORMERS_PATH = (
    "/usr/local/corex/lib64/python3/dist-packages/"
    "vllm/attention/backends/xformers.py"
)

# Source of the method injected into the backend class. It is written as a
# class member (4-space base indent) and is re-indented to the indentation of
# the anchor method when it is inserted.
FALLBACK_METHOD = '''
    def _run_sdpa_fallback(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: "XFormersMetadata",
    ) -> torch.Tensor:
        """顺序纯数学 attention fallback。

        完全绕开 ixformer / cudnnFlashAttnForward,用 matmul + softmax
        手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径
        存在静默数值错误(输出全为"!"),纯数学路径结果正确。

        softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。

        Args:
            query : [1, total_prefill_tokens, num_heads, head_dim]
            key   : [1, total_prefill_tokens, num_kv_heads, head_dim]
            value : [1, total_prefill_tokens, num_kv_heads, head_dim]

        Returns:
            [1, total_prefill_tokens, num_heads, head_dim]
        """
        assert attn_metadata.seq_lens is not None
        orig_dtype = query.dtype

        q_flat = query.squeeze(0)  # [T, H, D]
        k_flat = key.squeeze(0)    # [T, Hkv, D]
        v_flat = value.squeeze(0)
        output = torch.empty_like(q_flat)

        start = 0
        for seq_len in attn_metadata.seq_lens:
            end = start + seq_len

            # [1, H, L, D]
            q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
            k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
            v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)

            # GQA:展开 KV heads 至与 query heads 一致
            if k_s.shape[1] != q_s.shape[1]:
                n = q_s.shape[1] // k_s.shape[1]
                k_s = k_s.repeat_interleave(n, dim=1).contiguous()
                v_s = v_s.repeat_interleave(n, dim=1).contiguous()

            # 纯数学 attention:完全绕开硬件 flash attention kernel
            # [1, H, L, L]
            attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1))
            attn_w = attn_w * self.scale

            # 上三角填 -inf(future tokens)
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device),
                diagonal=1,
            )
            attn_w = attn_w.masked_fill(causal_mask, float("-inf"))

            # float32 softmax 防止 float16 溢出
            attn_w = torch.softmax(attn_w, dim=-1)

            out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype)

            # [1, H, L, D] → [L, H, D]
            output[start:end] = out_s.squeeze(0).permute(1, 0, 2)
            start = end

        return output.unsqueeze(0)  # [1, T, H, D]

'''

# Dispatch code expected in the unpatched backend. It is located token-wise
# (whitespace layout is ignored), so only the token sequence has to agree
# with the installed file.
OLD_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale,
                op = self.attn_op
            )
            return out.view_as(original_query)\
"""

# Replacement dispatch: head sizes above 128 go through the injected
# pure-math fallback, everything else keeps using the xformers kernel.
NEW_XFORMER_BLOCK = """\
        self.attn_op = xops.fmha.flash.FwOp()
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            if self.head_size > 128:
                out = self._run_sdpa_fallback(query, key, value, attn_metadata)
            else:
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=attn_bias[0],
                    p=0.0,
                    scale=self.scale,
                    op=self.attn_op,
                )
            return out.view_as(original_query)\
"""

# The fallback method is inserted immediately before this method definition.
INJECT_ANCHOR = "    def _run_memory_efficient_xformers_forward("


def _layout_insensitive(snippet):
    """Compile *snippet* into a regex that ignores its whitespace layout.

    The snippet is split on whitespace and the tokens are joined with
    ``\\s+``, so the pattern matches the same token sequence however it is
    indented or wrapped. The match starts at the beginning of a line and the
    leading indentation of that line is captured in the ``indent`` group.
    """
    tokens = r"\s+".join(re.escape(token) for token in snippet.split())
    return re.compile(r"^(?P<indent>[ \t]*)" + tokens, re.MULTILINE)


def _reindent(snippet, indent):
    """Return *snippet* with its common margin replaced by *indent*."""
    return textwrap.indent(textwrap.dedent(snippet), indent)


def patch_file(path):
    """Apply the sequential pure-math attention patch to the file at *path*.

    Two edits are made: ``FALLBACK_METHOD`` is inserted in front of the
    method named by ``INJECT_ANCHOR``, and the code matching
    ``OLD_XFORMER_BLOCK`` is replaced by ``NEW_XFORMER_BLOCK``. Each edit is
    skipped when it is already present, which makes re-running harmless.

    The patch is all-or-nothing: if either edit can neither be applied nor
    found already applied, the file is left untouched. Writing only the
    dispatch edit would leave a call to a method that does not exist.

    Args:
        path: Path of the Python source file to patch in place.

    Returns:
        True if the file is fully patched after the call (whether or not it
        had to be modified), False if it was left unchanged because an
        anchor could not be found.

    Raises:
        OSError: If the file cannot be read or written.
    """
    # Explicit UTF-8: the injected method contains non-ASCII text, so the
    # locale default encoding must not be relied upon.
    with open(path, "r", encoding="utf-8") as f:
        original = f.read()

    content = original
    complete = True

    # Look for the definition, not just the name: the name alone also occurs
    # in the patched dispatch block's call site.
    if "def _run_sdpa_fallback(" in content:
        print("  [skip] _run_sdpa_fallback already present")
    else:
        anchor = _layout_insensitive(INJECT_ANCHOR).search(content)
        if anchor is None:
            print("  [warn] inject anchor not found")
            complete = False
        else:
            method_src = _reindent(FALLBACK_METHOD, anchor.group("indent"))
            cut = anchor.start()
            content = content[:cut] + method_src + content[cut:]
            print("  [ok]   injected _run_sdpa_fallback (sequential, pure-math)")

    if _layout_insensitive(NEW_XFORMER_BLOCK).search(content):
        print("  [skip] dispatch block already patched")
    else:
        old_block = _layout_insensitive(OLD_XFORMER_BLOCK).search(content)
        if old_block is None:
            print("  [warn] dispatch block anchor not found")
            complete = False
        else:
            # Re-indent the replacement to wherever the old block lives.
            new_src = _reindent(NEW_XFORMER_BLOCK, old_block.group("indent"))
            content = (
                content[:old_block.start()] + new_src + content[old_block.end():]
            )
            print("  [ok]   patched dispatch block")

    if not complete:
        print("  [abort] patch incomplete, file left unchanged")
        return False

    if content != original:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f"  Written: {path}")
    return True


def main():
    """Patch the installed backend and return a process exit status.

    Returns:
        0 when the target file is fully patched, 1 otherwise.
    """
    print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
    print(f"Target: {XFORMERS_PATH}")
    if patch_file(XFORMERS_PATH):
        print("\nDone.")
        return 0
    print("\nFailed: patch not applied.")
    return 1


if __name__ == "__main__":
    sys.exit(main())