Add modifications to support 50K-token context input

This commit is contained in:
2026-06-04 17:56:29 +08:00
parent 1c33ef1355
commit 8c047a70ea
3 changed files with 150 additions and 0 deletions

View File

@@ -85,6 +85,85 @@ class PagedAttention:
v_scale,
)
@staticmethod
def _forward_decode_pytorch(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Compute decode attention over the paged KV cache using only PyTorch ops.

    Intended as a fallback for long contexts: paged_attention_v1 was
    reported to hang on BI-V100 when max_seq_len > ~32K due to shared
    memory limits. During decode each sequence contributes a single query
    token, so no Q-tiling is needed and the attention-weight tensor is
    only [num_heads, seq_len] per sequence.

    Shapes
    ------
    query       : [num_seqs, num_heads, head_dim]
    key_cache   : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
    value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
    block_tables: [num_seqs, max_blocks_per_seq]
    seq_lens    : [num_seqs]

    Args:
        query: Query vectors, one token per sequence.
        key_cache: Paged key cache in the split-head_dim layout above.
        value_cache: Paged value cache.
        block_tables: Per-sequence list of cache block ids.
        seq_lens: Number of cached tokens to attend to per sequence.
        scale: Multiplier applied to the query before the QK^T product.

    Returns:
        Tensor with the same shape and dtype as ``query``. Rows whose
        sequence length is zero are filled with zeros.

    Raises:
        Exception: any error from the tensor ops is logged to stderr with
            a traceback and then re-raised unchanged.
    """
    num_seqs, num_heads, head_dim = query.shape
    num_kv_heads = key_cache.shape[1]
    block_size = value_cache.shape[3]
    gqa_ratio = num_heads // num_kv_heads
    orig_dtype = query.dtype
    output = torch.empty_like(query)
    try:
        # One device->host transfer for all lengths instead of a blocking
        # .item() synchronisation per sequence.
        lengths = seq_lens[:num_seqs].tolist()
        for i, seq_len in enumerate(lengths):
            seq_len = int(seq_len)
            if seq_len <= 0:
                # Nothing to attend to; a 0-element gather cannot be
                # reshaped with -1, so define the result as zeros.
                output[i] = 0
                continue
            num_blocks = (seq_len + block_size - 1) // block_size
            # Cast defensively in case block_tables is not int64.
            blk_ids = block_tables[i, :num_blocks].long()
            # Gather K from paged cache: [seq_len, num_kv_heads, head_dim].
            # Moving block_size next to the block axis and merging
            # (head_dim//x, x) restores the contiguous head_dim axis.
            k_seq = (key_cache[blk_ids]
                     .permute(0, 3, 1, 2, 4)
                     .reshape(-1, num_kv_heads, head_dim))[:seq_len]
            # Gather V from paged cache: [seq_len, num_kv_heads, head_dim]
            v_seq = (value_cache[blk_ids]
                     .permute(0, 3, 1, 2)
                     .reshape(-1, num_kv_heads, head_dim))[:seq_len]
            # [KVH, head_dim, seq_len] and [KVH, seq_len, head_dim], fp32.
            k_t = k_seq.permute(1, 2, 0).float()
            v_t = v_seq.permute(1, 0, 2).float()
            # Group query heads by the KV head they share (head h uses KV
            # head h // gqa_ratio) so K/V need not be replicated per
            # query head: q is [KVH, gqa_ratio, head_dim].
            q_i = query[i].float().reshape(num_kv_heads, gqa_ratio, head_dim)
            attn_w = torch.matmul(q_i * scale, k_t)   # [KVH, G, seq_len]
            attn_w = torch.softmax(attn_w, dim=-1)
            out_i = torch.matmul(attn_w, v_t)          # [KVH, G, head_dim]
            output[i] = out_i.reshape(num_heads, head_dim).to(orig_dtype)
    except Exception as e:
        # Imported here so the diagnostic path cannot itself fail on a
        # missing module-level import and mask the original error.
        import sys
        import traceback
        print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
              file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        raise
    return output
# Sequence-length threshold used by forward_decode: when max_seq_len is
# strictly greater than this value, decode is routed to the pure-PyTorch
# fallback instead of the paged_attention_v1 hardware kernel.
# Background: paged_attention_v1 on BI-V100 was reported to hang when
# max_seq_len exceeds ~32K due to shared memory limits.
# The value is currently set very large, which effectively disables the
# fallback, because 50K decode was confirmed working via the hardware
# kernel; lower it to 32768 if kernel hangs are observed at long contexts.
_PYTORCH_DECODE_THRESHOLD = 10_000_000
@staticmethod
def forward_decode(
query: torch.Tensor,
@@ -105,6 +184,10 @@ class PagedAttention:
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
return PagedAttention._forward_decode_pytorch(
query, key_cache, value_cache, block_tables, seq_lens, scale)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)