fix issues
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
|
||||
# Text-only (no VL, no MTP).
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -1033,6 +1034,15 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
# Lazy initialised in first forward call
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
# GDN prefix state cache (align mode): stores (conv_states, temporal_states) snapshots
|
||||
# at KV-block boundaries so that prefix-cache-hit requests can restore correct GDN state.
|
||||
# Key: tuple of physical block IDs covering the cached prefix
|
||||
# Value: (conv_states_cpu, temporal_states_cpu) each of shape (num_gdn_layers, ...)
|
||||
self._gdn_prefix_cache: OrderedDict = OrderedDict()
|
||||
self._gdn_prefix_cache_max: int = 16 # ~16 × 16 MB ≈ 256 MB CPU RAM
|
||||
self._block_size: int = (cache_config.block_size
|
||||
if cache_config is not None else 16)
|
||||
|
||||
def _get_mamba_cache_shape(self):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
# Each sequence's state is stored in float32
|
||||
@@ -1069,9 +1079,69 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
# temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
|
||||
conv_states, temporal_states = mamba_tensors
|
||||
|
||||
# ── GDN prefix-cache align mode: inject saved state on prefix hit ─────
|
||||
# Conditions: prefill pass, batch=1, context_len > 0 (prefix cached or
|
||||
# previous chunk already processed), block_tables available.
|
||||
# We always attempt a lookup: for subsequent chunked-prefill chunks the
|
||||
# key matches our own saved state (the slot already holds this data, so the copy is effectively a no-op).
|
||||
# For a true cross-request prefix hit the key matches a previous request.
|
||||
_is_single_seq_prefill = (
|
||||
attn_metadata is not None
|
||||
and attn_metadata.num_prefill_tokens > 0
|
||||
and conv_states.shape[1] == 1 # batch == 1
|
||||
and getattr(attn_metadata, 'context_lens_tensor', None) is not None
|
||||
and getattr(attn_metadata, 'block_tables', None) is not None
|
||||
and attn_metadata.block_tables.numel() > 0
|
||||
)
|
||||
if _is_single_seq_prefill:
|
||||
context_len = int(attn_metadata.context_lens_tensor[0].item())
|
||||
if context_len > 0:
|
||||
num_prefix_blocks = context_len // self._block_size
|
||||
if (num_prefix_blocks > 0
|
||||
and attn_metadata.block_tables.shape[1] >= num_prefix_blocks):
|
||||
lookup_key = tuple(
|
||||
attn_metadata.block_tables[0, :num_prefix_blocks]
|
||||
.cpu().tolist())
|
||||
if lookup_key in self._gdn_prefix_cache:
|
||||
saved_conv, saved_temporal = self._gdn_prefix_cache[lookup_key]
|
||||
conv_states[:, 0].copy_(
|
||||
saved_conv.to(conv_states.device), non_blocking=True)
|
||||
temporal_states[:, 0].copy_(
|
||||
saved_temporal.to(temporal_states.device), non_blocking=True)
|
||||
self._gdn_prefix_cache.move_to_end(lookup_key)
|
||||
logger.debug("GDN prefix cache hit: prefix_len=%d blocks=%d",
|
||||
context_len, num_prefix_blocks)
|
||||
# ── End inject ──────────────────────────────────────────────────────────
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, attn_metadata,
|
||||
conv_states, temporal_states)
|
||||
|
||||
# ── GDN prefix-cache align mode: save state after this prefill chunk ───
|
||||
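# Save the state reached after this chunk, keyed by ALL complete KV blocks processed so far (NOTE(review): if context_len + query_len is not a multiple of the block size, the snapshot presumably also covers the trailing partial block — confirm this is intended).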
|
||||
# Next requests reusing this prefix will restore from here.
|
||||
if _is_single_seq_prefill:
|
||||
context_len = int(attn_metadata.context_lens_tensor[0].item())
|
||||
query_len = attn_metadata.num_prefill_tokens
|
||||
total_processed = context_len + query_len
|
||||
num_complete_blocks = total_processed // self._block_size
|
||||
if (num_complete_blocks > 0
|
||||
and attn_metadata.block_tables.shape[1] >= num_complete_blocks):
|
||||
save_key = tuple(
|
||||
attn_metadata.block_tables[0, :num_complete_blocks]
|
||||
.cpu().tolist())
|
||||
# Move to end (LRU: most recent = last) and update value
|
||||
if save_key in self._gdn_prefix_cache:
|
||||
self._gdn_prefix_cache.move_to_end(save_key)
|
||||
self._gdn_prefix_cache[save_key] = (
|
||||
conv_states[:, 0].cpu().clone(),
|
||||
temporal_states[:, 0].cpu().clone(),
|
||||
)
|
||||
# Evict oldest entries beyond max
|
||||
while len(self._gdn_prefix_cache) > self._gdn_prefix_cache_max:
|
||||
self._gdn_prefix_cache.popitem(last=False)
|
||||
# ── End save ────────────────────────────────────────────────────────────
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user