Add chunked prefill support and memory optimizations

This commit is contained in:
2026-06-05 16:03:34 +08:00
parent 8c047a70ea
commit 2d1ef50992
4 changed files with 166 additions and 86 deletions

View File

@@ -334,6 +334,11 @@ class GatedDeltaNet(nn.Module):
.transpose(0, 1).unsqueeze(0)
.to(weight_2d.dtype))
# Load prev conv state BEFORE overwriting (needed for causal conv padding).
# First prefill of a request: mamba_cache is all zeros, so this is equivalent to zero padding.
# Chunked prefill, chunk 2 onward: it carries the last state_len tokens of the previous chunk.
prev_conv = conv_state[si:si + 1].clone().to(weight_2d.dtype) # [1, local_conv_dim, state_len]
# Save conv state (last state_len positions)
if seq_len >= state_len:
conv_state[si].copy_(mixed_qkv[0, :, -state_len:])
@@ -342,8 +347,8 @@ class GatedDeltaNet(nn.Module):
mixed_qkv[0])
conv_state[si, :, :state_len - seq_len] = 0
# Causal conv (left-pad with zeros, then convolve)
padded = F.pad(mixed_qkv, (state_len, 0))
# Causal conv: left-pad with previous conv state (not zeros).
padded = torch.cat([prev_conv, mixed_qkv], dim=2)
mixed_qkv_conv = F.conv1d(
padded, self.conv1d_weight,
bias=None, padding=0, groups=local_conv_dim)
@@ -850,12 +855,11 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
# TP behavior); they must still call logits_processor to participate in
# the NCCL gather inside lm_head. logits_processor returns None for
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
# Rank 0 (driver) always has seq_groups != None given
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
# All TP ranks must call logits_processor to participate in the NCCL
# gather inside lm_head. Non-driver ranks return None after the gather.
# With chunked prefill, intermediate chunks have seq_groups=None on all
# ranks; _apply_logits_processors is guarded against this in
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
return self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)