chunked prefill support and memory opts
This commit is contained in:
@@ -334,6 +334,11 @@ class GatedDeltaNet(nn.Module):
|
||||
.transpose(0, 1).unsqueeze(0)
|
||||
.to(weight_2d.dtype))
|
||||
|
||||
# Load prev conv state BEFORE overwriting (needed for causal conv padding).
|
||||
# For the first prefill of a request, mamba_cache is all zeros, so this is equivalent to zero left-padding.
|
||||
# For the second and later chunks of a chunked prefill, it carries the last state_len tokens from the previous chunk.
|
||||
prev_conv = conv_state[si:si + 1].clone().to(weight_2d.dtype) # [1, local_conv_dim, state_len]
|
||||
|
||||
# Save conv state (last state_len positions)
|
||||
if seq_len >= state_len:
|
||||
conv_state[si].copy_(mixed_qkv[0, :, -state_len:])
|
||||
@@ -342,8 +347,8 @@ class GatedDeltaNet(nn.Module):
|
||||
mixed_qkv[0])
|
||||
conv_state[si, :, :state_len - seq_len] = 0
|
||||
|
||||
# Causal conv (left-pad with zeros, then convolve)
|
||||
padded = F.pad(mixed_qkv, (state_len, 0))
|
||||
# Causal conv: left-pad with previous conv state (not zeros).
|
||||
padded = torch.cat([prev_conv, mixed_qkv], dim=2)
|
||||
mixed_qkv_conv = F.conv1d(
|
||||
padded, self.conv1d_weight,
|
||||
bias=None, padding=0, groups=local_conv_dim)
|
||||
@@ -850,12 +855,11 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
|
||||
# TP behavior); they must still call logits_processor to participate in
|
||||
# the NCCL gather inside lm_head. logits_processor returns None for
|
||||
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
|
||||
# Rank 0 (driver) always has seq_groups != None given
|
||||
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
|
||||
# All TP ranks must call logits_processor to participate in the NCCL
|
||||
# gather inside lm_head. Non-driver ranks return None after the gather.
|
||||
# With chunked prefill, intermediate chunks have seq_groups=None on all
|
||||
# ranks; _apply_logits_processors is guarded against this in
|
||||
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user