[Bugfix] pipeline parallelism and Eagle Qwen2 (#6910)
@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
 
 Qwen2Config = None
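
The import changes carry the whole fix: the Eagle draft model now pulls in get_pp_group and PPProxyTensors so it can take part in pipeline parallelism. As a rough mental model (a sketch, not the sglang definition, which lives in sglang.srt.model_executor.forward_batch_info), PPProxyTensors behaves like a small container of named activations that one pipeline stage produces and the next stage consumes:

from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class ProxyTensors:
    # Hypothetical stand-in for PPProxyTensors: named activations handed
    # from one pipeline stage to the next instead of token embeddings.
    tensors: Dict[str, torch.Tensor] = field(default_factory=dict)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self.tensors[key] = value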
@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
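
With pp_proxy_tensors threaded into Qwen2Model.forward, a stage that is not first in the pipeline can resume from upstream hidden states instead of embedding tokens. A minimal, self-contained sketch of that branch, with is_first_rank standing in for the rank check a real pp_group would provide (names here are illustrative, not sglang's):

import torch
from torch import nn


class StageSketch(nn.Module):
    # Hypothetical single pipeline stage showing the two entry paths.
    def __init__(self, vocab_size=32, hidden=8, is_first_rank=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.is_first_rank = is_first_rank

    def forward(self, input_ids, input_embeds=None, pp_proxy_tensors=None):
        if self.is_first_rank:
            # First stage: build hidden states from token embeddings.
            if input_embeds is None:
                return self.embed_tokens(input_ids)
            return input_embeds
        # Later stages: resume from the hidden states shipped by the
        # previous stage rather than re-embedding the tokens.
        assert pp_proxy_tensors is not None
        return pp_proxy_tensors["hidden_states"]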
@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
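
The one-line constructor change is the likely crux of the bug: Qwen2ForCausalLMEagle calls nn.Module.__init__ directly rather than Qwen2ForCausalLM.__init__, so an attribute the parent constructor would otherwise set, such as pp_group, never appears on the draft model, and pipeline-parallel code paths that read self.pp_group fail. A reduced illustration of the pitfall (Base and Draft are hypothetical):

from torch import nn


class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.pp_group = "set by base"  # stands in for get_pp_group()


class Draft(Base):
    def __init__(self):
        nn.Module.__init__(self)  # deliberately skips Base.__init__
        # Without the next line, self.pp_group would not exist and any
        # code path touching it would raise AttributeError.
        self.pp_group = "set by draft"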