diff --git a/python/sglang/srt/models/qwen2_eagle.py b/python/sglang/srt/models/qwen2_eagle.py
index 793d91560..4b4c0ec41 100644
--- a/python/sglang/srt/models/qwen2_eagle.py
+++ b/python/sglang/srt/models/qwen2_eagle.py
@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
 
 Qwen2Config = None
@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
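
For context, the new `pp_proxy_tensors` argument and `self.pp_group = get_pp_group()` follow the pipeline-parallel hand-off pattern sglang uses in its other models: the first PP rank embeds the token ids, intermediate ranks receive activations bundled in a `PPProxyTensors` instead of input ids, and only the last rank returns final hidden states. The hunks above cut off before the body that consumes the argument, so below is a minimal, self-contained sketch of that pattern; `ProxyTensors` and `TinyStage` are hypothetical stand-ins for the real `PPProxyTensors` and model classes, not code from this PR.

```python
# Hypothetical illustration of the proxy-tensor hand-off between
# pipeline-parallel stages. Runs in a single process; a real PP setup
# would send the bundle over the wire between ranks.
from typing import Dict, Optional

import torch
from torch import nn


class ProxyTensors:
    """Stand-in for PPProxyTensors: a named bundle of tensors passed
    between pipeline stages instead of token ids / embeddings."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


class TinyStage(nn.Module):
    """One pipeline stage of a toy model."""

    def __init__(self, hidden: int, is_first: bool, is_last: bool):
        super().__init__()
        self.is_first, self.is_last = is_first, is_last
        # Only the first stage owns the embedding table.
        self.embed = nn.Embedding(100, hidden) if is_first else None
        self.layer = nn.Linear(hidden, hidden)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[ProxyTensors] = None,
    ):
        if self.is_first:
            hidden_states = self.embed(input_ids)  # first rank embeds tokens
        else:
            # Later ranks resume from the activations sent upstream.
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
        hidden_states = self.layer(hidden_states)
        if not self.is_last:
            # Intermediate ranks hand their activations to the next stage.
            return ProxyTensors({"hidden_states": hidden_states})
        return hidden_states  # last rank produces the final hidden states


# Two stages chained in-process to mimic the rank-to-rank hand-off.
first = TinyStage(16, is_first=True, is_last=False)
last = TinyStage(16, is_first=False, is_last=True)
proxy = first(input_ids=torch.randint(0, 100, (1, 4)))
out = last(pp_proxy_tensors=proxy)
print(out.shape)  # torch.Size([1, 4, 16])
```

Threading the bundle through `forward` as an optional keyword argument keeps the signature backward compatible: single-rank callers pass nothing and the model behaves exactly as before the change.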