Initial version of chunked (Q-tiled) attention to support a 20K-token context
This commit is contained in:
@@ -41,62 +41,90 @@ FALLBACK_METHOD = '''
|
||||
value: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
) -> torch.Tensor:
|
||||
"""顺序纯数学 attention fallback。
|
||||
"""纯数学 causal attention fallback,带 Q-tiling 内存优化。
|
||||
|
||||
完全绕开 ixformer / cudnnFlashAttnForward,用 matmul + softmax
|
||||
手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径
|
||||
存在静默数值错误(输出全为"!"),纯数学路径结果正确。
|
||||
调用时机:kv_cache.numel()==0(profiling 阶段)。
|
||||
此路径无 KV 缓存前缀,KV 长度 == query 长度。
|
||||
|
||||
内存优化(Q-tiling,与 Flash Attention 同思路):
|
||||
将 Q 分成 _Q_CHUNK 大小的子块逐块计算,每块峰值内存
|
||||
O(_Q_CHUNK × q_len) 而非 O(q_len²)。
|
||||
profiling 阶段序列可能达到 max_model_len(如 20K tokens),
|
||||
不加 Q-tiling 会产生 9.6 GB 矩阵直接 OOM。
|
||||
|
||||
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
|
||||
|
||||
Args:
|
||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
query : [1, total_query_tokens, num_heads, head_dim]
|
||||
key : [1, total_query_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_query_tokens, num_kv_heads, head_dim]
|
||||
Returns:
|
||||
[1, total_prefill_tokens, num_heads, head_dim]
|
||||
[1, total_query_tokens, num_heads, head_dim]
|
||||
"""
|
||||
_Q_CHUNK = 256 # 与 _forward_prefix_pytorch 的 _ATTN_Q_CHUNK 保持一致
|
||||
|
||||
assert attn_metadata.seq_lens is not None
|
||||
orig_dtype = query.dtype
|
||||
num_seqs = len(attn_metadata.seq_lens)
|
||||
|
||||
# 推导每条序列的实际 query 长度。
|
||||
# 正常 prefill 时 q_len == seq_len;如果将来遇到 chunked 场景,
|
||||
# query_start_loc 记录的是真实 query token 数(非全序列长度)。
|
||||
if (attn_metadata.query_start_loc is not None
|
||||
and len(attn_metadata.query_start_loc) == num_seqs + 1):
|
||||
q_lens = [
|
||||
int(attn_metadata.query_start_loc[i + 1].item()) -
|
||||
int(attn_metadata.query_start_loc[i].item())
|
||||
for i in range(num_seqs)
|
||||
]
|
||||
else:
|
||||
q_lens = list(attn_metadata.seq_lens)
|
||||
|
||||
q_flat = query.squeeze(0) # [T, H, D]
|
||||
k_flat = key.squeeze(0) # [T, Hkv, D]
|
||||
v_flat = value.squeeze(0)
|
||||
|
||||
output = torch.empty_like(q_flat)
|
||||
start = 0
|
||||
for seq_len in attn_metadata.seq_lens:
|
||||
end = start + seq_len
|
||||
# [1, H, L, D]
|
||||
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
seq_start = 0
|
||||
for q_len in q_lens:
|
||||
seq_end = seq_start + q_len
|
||||
|
||||
# 当前序列的完整 K/V(此路径无前缀,KV == Q)
|
||||
k_s = k_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
|
||||
v_s = v_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
|
||||
|
||||
# GQA:展开 KV heads 至与 query heads 一致
|
||||
if k_s.shape[1] != q_s.shape[1]:
|
||||
n = q_s.shape[1] // k_s.shape[1]
|
||||
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
|
||||
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
|
||||
if k_s.shape[0] != self.num_heads:
|
||||
n = self.num_heads // k_s.shape[0]
|
||||
k_s = k_s.repeat_interleave(n, dim=0).contiguous()
|
||||
v_s = v_s.repeat_interleave(n, dim=0).contiguous()
|
||||
|
||||
# 纯数学 attention:完全绕开硬件 flash attention kernel
|
||||
# [1, H, L, L]
|
||||
attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1))
|
||||
attn_w = attn_w * self.scale
|
||||
# k_pos 用于因果掩码
|
||||
k_pos = torch.arange(q_len, device=query.device)
|
||||
|
||||
# 上三角填 -inf(future tokens)
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_w = attn_w.masked_fill(causal_mask, float("-inf"))
|
||||
# Q-tiling:分块处理 query,峰值内存 O(_Q_CHUNK × q_len)
|
||||
for qc_start in range(0, q_len, _Q_CHUNK):
|
||||
qc_end = min(qc_start + _Q_CHUNK, q_len)
|
||||
|
||||
# float32 softmax 防止 float16 溢出
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
# [H, qc, D]
|
||||
q_c = q_flat[seq_start + qc_start:seq_start + qc_end] \
|
||||
.permute(1, 0, 2).float()
|
||||
|
||||
out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype)
|
||||
# [1, H, L, D] → [L, H, D]
|
||||
output[start:end] = out_s.squeeze(0).permute(1, 0, 2)
|
||||
start = end
|
||||
# [H, qc, q_len]
|
||||
attn_w = torch.matmul(q_c, k_s.transpose(-2, -1)) * self.scale
|
||||
|
||||
# 因果掩码:q_c 里位置 j 只能看 k_pos <= j(相对位置)
|
||||
qc_q_pos = torch.arange(qc_start, qc_end, device=query.device)
|
||||
mask = k_pos.unsqueeze(0) > qc_q_pos.unsqueeze(1)
|
||||
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float("-inf"))
|
||||
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
out_c = torch.matmul(attn_w, v_s).to(orig_dtype) # [H, qc, D]
|
||||
|
||||
output[seq_start + qc_start:seq_start + qc_end] = (
|
||||
out_c.permute(1, 0, 2))
|
||||
|
||||
seq_start = seq_end
|
||||
|
||||
return output.unsqueeze(0) # [1, T, H, D]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user