some modifications to ensure 50K context input
This commit is contained in:
@@ -85,6 +85,85 @@ class PagedAttention:
|
||||
v_scale,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _forward_decode_pytorch(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Pure-PyTorch decode attention for long contexts (no hardware kernel).
|
||||
|
||||
paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to
|
||||
shared memory limits. For decode, q_len=1 per sequence so no Q-tiling
|
||||
is needed — the attention weight tensor is [H, 1, seq_len] which is
|
||||
trivially small (~5 MB at 50K).
|
||||
|
||||
Shapes
|
||||
------
|
||||
query : [num_seqs, num_heads, head_dim]
|
||||
key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
|
||||
value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
block_tables: [num_seqs, max_blocks_per_seq]
|
||||
seq_lens : [num_seqs]
|
||||
"""
|
||||
num_seqs, num_heads, head_dim = query.shape
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
block_size = value_cache.shape[3]
|
||||
gqa_ratio = num_heads // num_kv_heads
|
||||
orig_dtype = query.dtype
|
||||
|
||||
output = torch.empty_like(query)
|
||||
|
||||
try:
|
||||
for i in range(num_seqs):
|
||||
seq_len = int(seq_lens[i].item())
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
blk_ids = block_tables[i, :num_blocks]
|
||||
|
||||
# Gather K from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||
k_seq = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||
|
||||
# Gather V from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||
v_seq = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||
|
||||
if gqa_ratio > 1:
|
||||
k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||
v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||
|
||||
# [H, head_dim, seq_len] and [H, seq_len, head_dim]
|
||||
k_t = k_seq.permute(1, 2, 0).float()
|
||||
v_t = v_seq.permute(1, 0, 2).float()
|
||||
|
||||
# q: [H, 1, head_dim]; attn_w: [H, 1, seq_len]
|
||||
q_i = query[i].float().unsqueeze(1)
|
||||
attn_w = torch.matmul(q_i * scale, k_t)
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
|
||||
out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim]
|
||||
output[i] = out_i.squeeze(1).to(orig_dtype)
|
||||
except Exception as e:
|
||||
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
|
||||
file=sys.stderr, flush=True)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
raise
|
||||
|
||||
return output
|
||||
|
||||
# paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to
|
||||
# shared memory limits; use pure-PyTorch fallback above this threshold.
|
||||
# Set to a large value to disable for now (50K decode confirmed working via
|
||||
# hardware kernel); lower to 32768 if kernel hangs are observed at long contexts.
|
||||
_PYTORCH_DECODE_THRESHOLD = 10_000_000
|
||||
|
||||
@staticmethod
|
||||
def forward_decode(
|
||||
query: torch.Tensor,
|
||||
@@ -105,6 +184,10 @@ class PagedAttention:
|
||||
blocksparse_block_size: int = 64,
|
||||
blocksparse_head_sliding_step: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
|
||||
return PagedAttention._forward_decode_pytorch(
|
||||
query, key_cache, value_cache, block_tables, seq_lens, scale)
|
||||
|
||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||
# use blocksparse paged attention
|
||||
block_size = value_cache.size(-1)
|
||||
|
||||
Reference in New Issue
Block a user