[Bugfix] pipeline parallelism and Eagle Qwen2 (#6910)
This commit is contained in:
@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_pp_group
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
|
|||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
self.model = Qwen2Model(
|
self.model = Qwen2Model(
|
||||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user