[Bugfix] pipeline parallelism and Eagle Qwen2 (#6910)

This commit is contained in:
Swipe4057
2025-06-07 11:58:50 +03:00
committed by GitHub
parent 2f715f51cc
commit 9736cd3b7d

View File

@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
Qwen2Config = None
@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.pp_group = get_pp_group()
self.model = Qwen2Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)