some modifications to ensure 50K context input
This commit is contained in:
@@ -85,6 +85,85 @@ class PagedAttention:
|
|||||||
v_scale,
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _forward_decode_pytorch(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Pure-PyTorch decode attention for long contexts (no hardware kernel).
|
||||||
|
|
||||||
|
paged_attention_v1 hangs on BI-V100 when max_seq_len > ~32K due to
|
||||||
|
shared memory limits. For decode, q_len=1 per sequence so no Q-tiling
|
||||||
|
is needed — the attention weight tensor is [H, 1, seq_len] which is
|
||||||
|
trivially small (~5 MB at 50K).
|
||||||
|
|
||||||
|
Shapes
|
||||||
|
------
|
||||||
|
query : [num_seqs, num_heads, head_dim]
|
||||||
|
key_cache : [num_blocks, num_kv_heads, head_dim//x, block_size, x]
|
||||||
|
value_cache : [num_blocks, num_kv_heads, head_dim, block_size]
|
||||||
|
block_tables: [num_seqs, max_blocks_per_seq]
|
||||||
|
seq_lens : [num_seqs]
|
||||||
|
"""
|
||||||
|
num_seqs, num_heads, head_dim = query.shape
|
||||||
|
num_kv_heads = key_cache.shape[1]
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
gqa_ratio = num_heads // num_kv_heads
|
||||||
|
orig_dtype = query.dtype
|
||||||
|
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in range(num_seqs):
|
||||||
|
seq_len = int(seq_lens[i].item())
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
blk_ids = block_tables[i, :num_blocks]
|
||||||
|
|
||||||
|
# Gather K from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||||
|
k_seq = (key_cache[blk_ids]
|
||||||
|
.permute(0, 3, 1, 2, 4)
|
||||||
|
.contiguous()
|
||||||
|
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||||
|
|
||||||
|
# Gather V from paged cache: [seq_len, num_kv_heads, head_dim]
|
||||||
|
v_seq = (value_cache[blk_ids]
|
||||||
|
.permute(0, 3, 1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view(-1, num_kv_heads, head_dim))[:seq_len]
|
||||||
|
|
||||||
|
if gqa_ratio > 1:
|
||||||
|
k_seq = k_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||||
|
v_seq = v_seq.repeat_interleave(gqa_ratio, dim=1)
|
||||||
|
|
||||||
|
# [H, head_dim, seq_len] and [H, seq_len, head_dim]
|
||||||
|
k_t = k_seq.permute(1, 2, 0).float()
|
||||||
|
v_t = v_seq.permute(1, 0, 2).float()
|
||||||
|
|
||||||
|
# q: [H, 1, head_dim]; attn_w: [H, 1, seq_len]
|
||||||
|
q_i = query[i].float().unsqueeze(1)
|
||||||
|
attn_w = torch.matmul(q_i * scale, k_t)
|
||||||
|
attn_w = torch.softmax(attn_w, dim=-1)
|
||||||
|
|
||||||
|
out_i = torch.matmul(attn_w, v_t) # [H, 1, head_dim]
|
||||||
|
output[i] = out_i.squeeze(1).to(orig_dtype)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[decode_pytorch ERROR] {type(e).__name__}: {e}",
|
||||||
|
file=sys.stderr, flush=True)
|
||||||
|
traceback.print_exc(file=sys.stderr)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
# paged_attention_v1 on BI-V100 hangs when max_seq_len exceeds ~32K due to
|
||||||
|
# shared memory limits; use pure-PyTorch fallback above this threshold.
|
||||||
|
# Set to a large value to disable for now (50K decode confirmed working via
|
||||||
|
# hardware kernel); lower to 32768 if kernel hangs are observed at long contexts.
|
||||||
|
_PYTORCH_DECODE_THRESHOLD = 10_000_000
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -105,6 +184,10 @@ class PagedAttention:
|
|||||||
blocksparse_block_size: int = 64,
|
blocksparse_block_size: int = 64,
|
||||||
blocksparse_head_sliding_step: int = 0,
|
blocksparse_head_sliding_step: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if max_seq_len > PagedAttention._PYTORCH_DECODE_THRESHOLD:
|
||||||
|
return PagedAttention._forward_decode_pytorch(
|
||||||
|
query, key_cache, value_cache, block_tables, seq_lens, scale)
|
||||||
|
|
||||||
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||||
# use blocksparse paged attention
|
# use blocksparse paged attention
|
||||||
block_size = value_cache.size(-1)
|
block_size = value_cache.size(-1)
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ flash attention kernel(ixformer / cudnnFlashAttnForward)。
|
|||||||
max-model-len=8192 → 峰值 ~800 MB
|
max-model-len=8192 → 峰值 ~800 MB
|
||||||
max-model-len=16384 → 峰值 ~3.2 GB
|
max-model-len=16384 → 峰值 ~3.2 GB
|
||||||
|
|
||||||
|
额外 patch(arg_utils.py):
|
||||||
|
vllm 0.6.3 在 max_model_len > 32K 时会自动开启 chunked prefill(无命令行
|
||||||
|
关闭选项),原意是防止 profiling OOM。但 _run_sdpa_fallback 已通过 Q-tiling
|
||||||
|
解决了该问题,chunked prefill 反而会把推理路径从 _run_sdpa_fallback 切换到
|
||||||
|
_forward_prefix_pytorch,属于不必要的行为变更,因此一并禁用该自动逻辑。
|
||||||
|
|
||||||
Deploy:
|
Deploy:
|
||||||
python3 modified_scripts/patch_xformers_sdpa_seq.py
|
python3 modified_scripts/patch_xformers_sdpa_seq.py
|
||||||
"""
|
"""
|
||||||
@@ -33,6 +39,33 @@ XFORMERS_PATH = (
|
|||||||
"vllm/attention/backends/xformers.py"
|
"vllm/attention/backends/xformers.py"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ARG_UTILS_PATH = (
|
||||||
|
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||||
|
"vllm/engine/arg_utils.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
# vllm 0.6.3 自动开启 chunked prefill 的原始块
|
||||||
|
_ARG_OLD_BLOCK = """\
|
||||||
|
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||||
|
and not self.enable_lora
|
||||||
|
and not self.enable_prompt_adapter):
|
||||||
|
self.enable_chunked_prefill = True
|
||||||
|
logger.warning(
|
||||||
|
"Chunked prefill is enabled by default for models with "
|
||||||
|
"max_model_len > 32K. Currently, chunked prefill might "
|
||||||
|
"not work with some features or models. If you "
|
||||||
|
"encounter any issues, please disable chunked prefill "
|
||||||
|
"by setting --enable-chunked-prefill=False.")\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ARG_NEW_BLOCK = """\
|
||||||
|
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||||
|
and not self.enable_lora
|
||||||
|
and not self.enable_prompt_adapter):
|
||||||
|
pass # skip auto-enable: Q-tiling in _run_sdpa_fallback
|
||||||
|
# handles long-context memory without chunked prefill\
|
||||||
|
"""
|
||||||
|
|
||||||
FALLBACK_METHOD = '''
|
FALLBACK_METHOD = '''
|
||||||
def _run_sdpa_fallback(
|
def _run_sdpa_fallback(
|
||||||
self,
|
self,
|
||||||
@@ -203,10 +236,35 @@ def patch_file(path):
|
|||||||
print(f" Written: {path}")
|
print(f" Written: {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def patch_arg_utils(path):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
if "skip auto-enable: Q-tiling" in content:
|
||||||
|
print(" [skip] chunked-prefill auto-enable already disabled")
|
||||||
|
elif _ARG_OLD_BLOCK in content:
|
||||||
|
content = content.replace(_ARG_OLD_BLOCK, _ARG_NEW_BLOCK, 1)
|
||||||
|
print(" [ok] disabled chunked-prefill auto-enable for 32K+")
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
print(" [warn] target block not found — check arg_utils.py version")
|
||||||
|
|
||||||
|
if changed:
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(content)
|
||||||
|
print(f" Written: {path}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
|
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
|
||||||
print(f"Target: {XFORMERS_PATH}")
|
print(f"Target: {XFORMERS_PATH}")
|
||||||
patch_file(XFORMERS_PATH)
|
patch_file(XFORMERS_PATH)
|
||||||
|
|
||||||
|
print("\n=== patch_arg_utils (disable chunked-prefill auto-enable) ===")
|
||||||
|
print(f"Target: {ARG_UTILS_PATH}")
|
||||||
|
patch_arg_utils(ARG_UTILS_PATH)
|
||||||
|
|
||||||
print("\nDone.")
|
print("\nDone.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -412,6 +412,9 @@ class GatedDeltaNet(nn.Module):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# Decode: one token per sequence
|
# Decode: one token per sequence
|
||||||
|
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
||||||
|
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
||||||
|
_f.flush()
|
||||||
num_seqs = hidden_states.shape[0]
|
num_seqs = hidden_states.shape[0]
|
||||||
weight_2d = self.conv1d_weight.squeeze(1)
|
weight_2d = self.conv1d_weight.squeeze(1)
|
||||||
|
|
||||||
@@ -847,6 +850,12 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
|
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
|
||||||
|
# TP behavior); they must still call logits_processor to participate in
|
||||||
|
# the NCCL gather inside lm_head. logits_processor returns None for
|
||||||
|
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
|
||||||
|
# Rank 0 (driver) always has seq_groups != None given
|
||||||
|
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
|
||||||
return self.logits_processor(self.lm_head, hidden_states,
|
return self.logits_processor(self.lm_head, hidden_states,
|
||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user