some modifications to ensure 50K context input
This commit is contained in:
@@ -412,6 +412,9 @@ class GatedDeltaNet(nn.Module):
|
||||
|
||||
else:
|
||||
# Decode: one token per sequence
|
||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
||||
_f.flush()
|
||||
num_seqs = hidden_states.shape[0]
|
||||
weight_2d = self.conv1d_weight.squeeze(1)
|
||||
|
||||
@@ -847,6 +850,12 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Non-driver TP ranks have seq_groups=None in sampling_metadata (normal
|
||||
# TP behavior); they must still call logits_processor to participate in
|
||||
# the NCCL gather inside lm_head. logits_processor returns None for
|
||||
# non-driver ranks after the gather, safely skipping _apply_logits_processors.
|
||||
# Rank 0 (driver) always has seq_groups != None given
|
||||
# --max-num-batched-tokens >= --max-model-len (no chunked-prefill splits).
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user