fix issues

This commit is contained in:
2026-06-26 12:55:02 +08:00
parent 3d62430fd7
commit c84151eef9
9 changed files with 1879 additions and 5 deletions

View File

@@ -2,7 +2,8 @@
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
# Text-only (no VL, no MTP).
from typing import Iterable, List, Optional, Tuple
from collections import OrderedDict
from typing import Dict, Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -1033,6 +1034,15 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# Lazy initialised in first forward call
self.mamba_cache: Optional[MambaCacheManager] = None
# GDN prefix state cache (align mode): stores (conv_states, temporal_states) snapshots
# at KV-block boundaries so that prefix-cache-hit requests can restore correct GDN state.
# Key: tuple of physical block IDs covering the cached prefix
# Value: (conv_states_cpu, temporal_states_cpu) each of shape (num_gdn_layers, ...)
self._gdn_prefix_cache: OrderedDict = OrderedDict()
self._gdn_prefix_cache_max: int = 16 # ~16 × 16 MB ≈ 256 MB CPU RAM
self._block_size: int = (cache_config.block_size
if cache_config is not None else 16)
def _get_mamba_cache_shape(self):
tp_size = get_tensor_model_parallel_world_size()
# Each sequence's state is stored in float32
@@ -1069,9 +1079,69 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
# temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
conv_states, temporal_states = mamba_tensors
# ── GDN prefix-cache align mode: inject saved state on prefix hit ─────
# Conditions: prefill pass, batch=1, context_len > 0 (prefix cached or
# previous chunk already processed), block_tables available.
# We always attempt a lookup: for subsequent chunked-prefill chunks the
# key matches our own saved state (same data already in slot → no-op).
# For a true cross-request prefix hit the key matches a previous request.
_is_single_seq_prefill = (
attn_metadata is not None
and attn_metadata.num_prefill_tokens > 0
and conv_states.shape[1] == 1 # batch == 1
and getattr(attn_metadata, 'context_lens_tensor', None) is not None
and getattr(attn_metadata, 'block_tables', None) is not None
and attn_metadata.block_tables.numel() > 0
)
if _is_single_seq_prefill:
context_len = int(attn_metadata.context_lens_tensor[0].item())
if context_len > 0:
num_prefix_blocks = context_len // self._block_size
if (num_prefix_blocks > 0
and attn_metadata.block_tables.shape[1] >= num_prefix_blocks):
lookup_key = tuple(
attn_metadata.block_tables[0, :num_prefix_blocks]
.cpu().tolist())
if lookup_key in self._gdn_prefix_cache:
saved_conv, saved_temporal = self._gdn_prefix_cache[lookup_key]
conv_states[:, 0].copy_(
saved_conv.to(conv_states.device), non_blocking=True)
temporal_states[:, 0].copy_(
saved_temporal.to(temporal_states.device), non_blocking=True)
self._gdn_prefix_cache.move_to_end(lookup_key)
logger.debug("GDN prefix cache hit: prefix_len=%d blocks=%d",
context_len, num_prefix_blocks)
# ── End inject ──────────────────────────────────────────────────────────
hidden_states = self.model(
input_ids, positions, kv_caches, attn_metadata,
conv_states, temporal_states)
# ── GDN prefix-cache align mode: save state after this prefill chunk ───
# Save state keyed by ALL complete KV blocks processed so far.
# Next requests reusing this prefix will restore from here.
if _is_single_seq_prefill:
context_len = int(attn_metadata.context_lens_tensor[0].item())
query_len = attn_metadata.num_prefill_tokens
total_processed = context_len + query_len
num_complete_blocks = total_processed // self._block_size
if (num_complete_blocks > 0
and attn_metadata.block_tables.shape[1] >= num_complete_blocks):
save_key = tuple(
attn_metadata.block_tables[0, :num_complete_blocks]
.cpu().tolist())
# Move to end (LRU: most recent = last) and update value
if save_key in self._gdn_prefix_cache:
self._gdn_prefix_cache.move_to_end(save_key)
self._gdn_prefix_cache[save_key] = (
conv_states[:, 0].cpu().clone(),
temporal_states[:, 0].cpu().clone(),
)
# Evict oldest entries beyond max
while len(self._gdn_prefix_cache) > self._gdn_prefix_cache_max:
self._gdn_prefix_cache.popitem(last=False)
# ── End save ────────────────────────────────────────────────────────────
return hidden_states
def compute_logits(