initial version of adding chunked attention, ensuring 20K context
This commit is contained in:
@@ -1,10 +1,51 @@
|
|||||||
|
# BI-V100 patch script for Qwen3.6-27B (Qwen3_5 architecture)
|
||||||
|
#
|
||||||
|
# Triton situation on BI-V100:
|
||||||
|
# - Standard Triton 2.3.1 is already present in the image.
|
||||||
|
# - HAS_TRITON = False (hardcoded in vendor vllm), but Triton is still used
|
||||||
|
# for TP-mode cache management (custom_cache_manager / libentry).
|
||||||
|
# - The vendor's triton_utils/__init__.py, custom_cache_manager.py, libentry.py
|
||||||
|
# are already correct for standard Triton 2.3.1 — do NOT overwrite them.
|
||||||
|
# - DO NOT install BI-V150 corex Triton 2.1.0 (pkgs/triton): that causes
|
||||||
|
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
|
||||||
|
#
|
||||||
|
# Chunked prefill note:
|
||||||
|
# --enable-chunked-prefill is NOT supported by the vendor's vllm 0.6.3 for
|
||||||
|
# has_inner_state=True models on BI-V100. It causes "Engine loop has died"
|
||||||
|
# immediately on first request. Do NOT use that flag.
|
||||||
|
# Long-context memory is instead handled by query-chunking inside
|
||||||
|
# _forward_prefix_pytorch (see paged_attn.py, _ATTN_Q_CHUNK=256).
|
||||||
|
#
|
||||||
|
# Recommended server start command:
|
||||||
|
# python3 -m vllm.entrypoints.openai.api_server \
|
||||||
|
# --model /workspace/models/Qwen3.6-27B --port 1111 \
|
||||||
|
# --served-model-name llm --max-model-len 20000 \
|
||||||
|
# --enforce-eager --trust-remote-code -tp 4 \
|
||||||
|
# --gpu-memory-utilization 0.95
|
||||||
|
# (No --enable-chunked-prefill, no --max-num-batched-tokens)
|
||||||
|
|
||||||
|
# --- paged_attn.py: replace forward_prefix with pure-PyTorch fallback -------
|
||||||
|
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
|
||||||
|
# (standard Triton 2.3.1 PTX is not supported by the corex runtime either).
|
||||||
|
# Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which
|
||||||
|
# also implements query-chunking (_ATTN_Q_CHUNK=256) to keep peak attention
|
||||||
|
# memory at O(256 × kv_len) instead of O(q_len × kv_len).
|
||||||
|
cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
||||||
|
|
||||||
|
# --- transformers: Qwen3_5 tokenizer / model files --------------------------
|
||||||
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||||
python3 ./patch_transformers_qwen3_5.py
|
python3 ./patch_transformers_qwen3_5.py
|
||||||
|
|
||||||
|
# --- vllm model: Qwen3.6-27B (Qwen3_5 arch) --------------------------------
|
||||||
cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/
|
cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/
|
||||||
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/
|
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/qwen3_5.py
|
||||||
python3 ./patch_vllm_qwen3_5.py
|
python3 ./patch_vllm_qwen3_5.py
|
||||||
|
|
||||||
# 此步骤脚本四选一(默认 matmul+seq策略)
|
# --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------
|
||||||
|
# Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py.
|
||||||
|
# Required because head_dim=256 > 128 and ixformer flash attention either
|
||||||
|
# crashes (is_causal=True) or produces wrong output (attn_mask path).
|
||||||
|
# The fallback uses query_start_loc to derive actual query lengths, so it
|
||||||
|
# works correctly during profiling runs with chunked-prefill-style batches.
|
||||||
python3 ./patch_xformers_sdpa_seq.py
|
python3 ./patch_xformers_sdpa_seq.py
|
||||||
|
|||||||
@@ -41,62 +41,90 @@ FALLBACK_METHOD = '''
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
attn_metadata: "XFormersMetadata",
|
attn_metadata: "XFormersMetadata",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""顺序纯数学 attention fallback。
|
"""纯数学 causal attention fallback,带 Q-tiling 内存优化。
|
||||||
|
|
||||||
完全绕开 ixformer / cudnnFlashAttnForward,用 matmul + softmax
|
调用时机:kv_cache.numel()==0(profiling 阶段)。
|
||||||
手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径
|
此路径无 KV 缓存前缀,KV 长度 == query 长度。
|
||||||
存在静默数值错误(输出全为"!"),纯数学路径结果正确。
|
|
||||||
|
内存优化(Q-tiling,与 Flash Attention 同思路):
|
||||||
|
将 Q 分成 _Q_CHUNK 大小的子块逐块计算,每块峰值内存
|
||||||
|
O(_Q_CHUNK × q_len) 而非 O(q_len²)。
|
||||||
|
profiling 阶段序列可能达到 max_model_len(如 20K tokens),
|
||||||
|
不加 Q-tiling 会产生 9.6 GB 矩阵直接 OOM。
|
||||||
|
|
||||||
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
|
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
query : [1, total_query_tokens, num_heads, head_dim]
|
||||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
key : [1, total_query_tokens, num_kv_heads, head_dim]
|
||||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
value : [1, total_query_tokens, num_kv_heads, head_dim]
|
||||||
Returns:
|
Returns:
|
||||||
[1, total_prefill_tokens, num_heads, head_dim]
|
[1, total_query_tokens, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
|
_Q_CHUNK = 256 # 与 _forward_prefix_pytorch 的 _ATTN_Q_CHUNK 保持一致
|
||||||
|
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
orig_dtype = query.dtype
|
orig_dtype = query.dtype
|
||||||
|
num_seqs = len(attn_metadata.seq_lens)
|
||||||
|
|
||||||
|
# 推导每条序列的实际 query 长度。
|
||||||
|
# 正常 prefill 时 q_len == seq_len;如果将来遇到 chunked 场景,
|
||||||
|
# query_start_loc 记录的是真实 query token 数(非全序列长度)。
|
||||||
|
if (attn_metadata.query_start_loc is not None
|
||||||
|
and len(attn_metadata.query_start_loc) == num_seqs + 1):
|
||||||
|
q_lens = [
|
||||||
|
int(attn_metadata.query_start_loc[i + 1].item()) -
|
||||||
|
int(attn_metadata.query_start_loc[i].item())
|
||||||
|
for i in range(num_seqs)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
q_lens = list(attn_metadata.seq_lens)
|
||||||
|
|
||||||
q_flat = query.squeeze(0) # [T, H, D]
|
q_flat = query.squeeze(0) # [T, H, D]
|
||||||
k_flat = key.squeeze(0) # [T, Hkv, D]
|
k_flat = key.squeeze(0) # [T, Hkv, D]
|
||||||
v_flat = value.squeeze(0)
|
v_flat = value.squeeze(0)
|
||||||
|
|
||||||
output = torch.empty_like(q_flat)
|
output = torch.empty_like(q_flat)
|
||||||
start = 0
|
seq_start = 0
|
||||||
for seq_len in attn_metadata.seq_lens:
|
for q_len in q_lens:
|
||||||
end = start + seq_len
|
seq_end = seq_start + q_len
|
||||||
# [1, H, L, D]
|
|
||||||
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
# 当前序列的完整 K/V(此路径无前缀,KV == Q)
|
||||||
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
k_s = k_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
|
||||||
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
v_s = v_flat[seq_start:seq_end].permute(1, 0, 2).float() # [Hkv, q_len, D]
|
||||||
|
|
||||||
# GQA:展开 KV heads 至与 query heads 一致
|
# GQA:展开 KV heads 至与 query heads 一致
|
||||||
if k_s.shape[1] != q_s.shape[1]:
|
if k_s.shape[0] != self.num_heads:
|
||||||
n = q_s.shape[1] // k_s.shape[1]
|
n = self.num_heads // k_s.shape[0]
|
||||||
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
|
k_s = k_s.repeat_interleave(n, dim=0).contiguous()
|
||||||
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
|
v_s = v_s.repeat_interleave(n, dim=0).contiguous()
|
||||||
|
|
||||||
# 纯数学 attention:完全绕开硬件 flash attention kernel
|
# k_pos 用于因果掩码
|
||||||
# [1, H, L, L]
|
k_pos = torch.arange(q_len, device=query.device)
|
||||||
attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1))
|
|
||||||
attn_w = attn_w * self.scale
|
|
||||||
|
|
||||||
# 上三角填 -inf(future tokens)
|
# Q-tiling:分块处理 query,峰值内存 O(_Q_CHUNK × q_len)
|
||||||
causal_mask = torch.triu(
|
for qc_start in range(0, q_len, _Q_CHUNK):
|
||||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device),
|
qc_end = min(qc_start + _Q_CHUNK, q_len)
|
||||||
diagonal=1,
|
|
||||||
)
|
# [H, qc, D]
|
||||||
attn_w = attn_w.masked_fill(causal_mask, float("-inf"))
|
q_c = q_flat[seq_start + qc_start:seq_start + qc_end] \
|
||||||
|
.permute(1, 0, 2).float()
|
||||||
|
|
||||||
|
# [H, qc, q_len]
|
||||||
|
attn_w = torch.matmul(q_c, k_s.transpose(-2, -1)) * self.scale
|
||||||
|
|
||||||
|
# 因果掩码:q_c 里位置 j 只能看 k_pos <= j(相对位置)
|
||||||
|
qc_q_pos = torch.arange(qc_start, qc_end, device=query.device)
|
||||||
|
mask = k_pos.unsqueeze(0) > qc_q_pos.unsqueeze(1)
|
||||||
|
attn_w = attn_w.masked_fill(mask.unsqueeze(0), float("-inf"))
|
||||||
|
|
||||||
# float32 softmax 防止 float16 溢出
|
|
||||||
attn_w = torch.softmax(attn_w, dim=-1)
|
attn_w = torch.softmax(attn_w, dim=-1)
|
||||||
|
out_c = torch.matmul(attn_w, v_s).to(orig_dtype) # [H, qc, D]
|
||||||
|
|
||||||
out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype)
|
output[seq_start + qc_start:seq_start + qc_end] = (
|
||||||
# [1, H, L, D] → [L, H, D]
|
out_c.permute(1, 0, 2))
|
||||||
output[start:end] = out_s.squeeze(0).permute(1, 0, 2)
|
|
||||||
start = end
|
seq_start = seq_end
|
||||||
|
|
||||||
return output.unsqueeze(0) # [1, T, H, D]
|
return output.unsqueeze(0) # [1, T, H, D]
|
||||||
|
|
||||||
|
|||||||
@@ -31,9 +31,12 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||||
_get_graph_batch_size)
|
_get_graph_batch_size)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
|
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
|
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
|
||||||
@@ -364,14 +367,31 @@ class GatedDeltaNet(nn.Module):
|
|||||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
|
|
||||||
core_out, last_state = _torch_chunk_gated_delta_rule(
|
# Sub-sequence chunking: call _torch_chunk_gated_delta_rule
|
||||||
q, k, v, g, beta,
|
# on _DNN_CHUNK tokens at a time to cap peak memory.
|
||||||
initial_state=temporal_state[si:si + 1],
|
# Full 18K: tensors [1,6,282,64,64]=220 MB each → ~990 MB/call.
|
||||||
|
# With _DNN_CHUNK=4096: [1,6,64,64,64]=6 MB each → ~137 MB/call.
|
||||||
|
# State is chained via initial_state / output_final_state.
|
||||||
|
_DNN_CHUNK = 4096
|
||||||
|
cur_state = temporal_state[si:si + 1].clone()
|
||||||
|
core_out_parts = []
|
||||||
|
for sc_start in range(0, seq_len, _DNN_CHUNK):
|
||||||
|
sc_end = min(sc_start + _DNN_CHUNK, seq_len)
|
||||||
|
c_out, cur_state = _torch_chunk_gated_delta_rule(
|
||||||
|
q[:, sc_start:sc_end],
|
||||||
|
k[:, sc_start:sc_end],
|
||||||
|
v[:, sc_start:sc_end],
|
||||||
|
g[:, sc_start:sc_end],
|
||||||
|
beta[:, sc_start:sc_end],
|
||||||
|
initial_state=cur_state,
|
||||||
output_final_state=True,
|
output_final_state=True,
|
||||||
use_qk_l2norm_in_kernel=True,
|
use_qk_l2norm_in_kernel=True,
|
||||||
)
|
)
|
||||||
if last_state is not None:
|
core_out_parts.append(c_out)
|
||||||
temporal_state[si].copy_(last_state[0])
|
if cur_state is not None:
|
||||||
|
temporal_state[si].copy_(cur_state[0])
|
||||||
|
# [1, seq_len, num_v_heads, head_v_dim]
|
||||||
|
core_out = torch.cat(core_out_parts, dim=1)
|
||||||
|
|
||||||
# Gate + norm + output proj
|
# Gate + norm + output proj
|
||||||
z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim)
|
z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim)
|
||||||
@@ -384,7 +404,10 @@ class GatedDeltaNet(nn.Module):
|
|||||||
outputs.append(out)
|
outputs.append(out)
|
||||||
|
|
||||||
result = torch.cat(outputs, dim=0)
|
result = torch.cat(outputs, dim=0)
|
||||||
assert not torch.isnan(result).any(), f"NaN in prefill layer {self.layer_idx}"
|
if torch.isnan(result).any():
|
||||||
|
logger.warning("NaN in prefill GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
|
||||||
|
self.layer_idx, torch.isnan(result).float().mean().item())
|
||||||
|
result = torch.nan_to_num(result, nan=0.0)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -434,7 +457,10 @@ class GatedDeltaNet(nn.Module):
|
|||||||
z.reshape(-1, self.head_v_dim))
|
z.reshape(-1, self.head_v_dim))
|
||||||
normed = normed.reshape(num_seqs, -1)
|
normed = normed.reshape(num_seqs, -1)
|
||||||
out, _ = self.out_proj(normed)
|
out, _ = self.out_proj(normed)
|
||||||
assert not torch.isnan(out).any(), f"NaN in layer {self.layer_idx}"
|
if torch.isnan(out).any():
|
||||||
|
logger.warning("NaN in decode GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
|
||||||
|
self.layer_idx, torch.isnan(out).float().mean().item())
|
||||||
|
out = torch.nan_to_num(out, nan=0.0)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user