Initial version of chunked attention, to support a 20K-token context

This commit is contained in:
2026-05-29 16:49:33 +08:00
parent 0e89906481
commit 3ef8227384
3 changed files with 142 additions and 47 deletions

View File

@@ -31,9 +31,12 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
@@ -364,14 +367,31 @@ class GatedDeltaNet(nn.Module):
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
core_out, last_state = _torch_chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=temporal_state[si:si + 1],
output_final_state=True,
use_qk_l2norm_in_kernel=True,
)
if last_state is not None:
temporal_state[si].copy_(last_state[0])
# Sub-sequence chunking: call _torch_chunk_gated_delta_rule
# on _DNN_CHUNK tokens at a time to cap peak memory.
# Full 18K: tensors [1,6,282,64,64]=220 MB each → ~990 MB/call.
# With _DNN_CHUNK=4096: [1,6,64,64,64]=6 MB each → ~137 MB/call.
# State is chained via initial_state / output_final_state.
_DNN_CHUNK = 4096
cur_state = temporal_state[si:si + 1].clone()
core_out_parts = []
for sc_start in range(0, seq_len, _DNN_CHUNK):
sc_end = min(sc_start + _DNN_CHUNK, seq_len)
c_out, cur_state = _torch_chunk_gated_delta_rule(
q[:, sc_start:sc_end],
k[:, sc_start:sc_end],
v[:, sc_start:sc_end],
g[:, sc_start:sc_end],
beta[:, sc_start:sc_end],
initial_state=cur_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
)
core_out_parts.append(c_out)
if cur_state is not None:
temporal_state[si].copy_(cur_state[0])
# [1, seq_len, num_v_heads, head_v_dim]
core_out = torch.cat(core_out_parts, dim=1)
# Gate + norm + output proj
z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim)
@@ -384,7 +404,10 @@ class GatedDeltaNet(nn.Module):
outputs.append(out)
result = torch.cat(outputs, dim=0)
assert not torch.isnan(result).any(), f"NaN in prefill layer {self.layer_idx}"
if torch.isnan(result).any():
logger.warning("NaN in prefill GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
self.layer_idx, torch.isnan(result).float().mean().item())
result = torch.nan_to_num(result, nan=0.0)
return result
else:
@@ -434,7 +457,10 @@ class GatedDeltaNet(nn.Module):
z.reshape(-1, self.head_v_dim))
normed = normed.reshape(num_seqs, -1)
out, _ = self.out_proj(normed)
assert not torch.isnan(out).any(), f"NaN in layer {self.layer_idx}"
if torch.isnan(out).any():
logger.warning("NaN in decode GatedDeltaNet layer %d (frac=%.4f), replacing with zeros",
self.layer_idx, torch.isnan(out).float().mean().item())
out = torch.nan_to_num(out, nan=0.0)
return out