[PP] Add pipeline parallelism (#5724)
@@ -25,6 +25,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             embeds = self.embed_tokens(input_ids)
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
 
         if self.config.num_hidden_layers != 1:
            raise ValueError("EAGLE3 currently only supports 1 layer")
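
For context, the new `pp_proxy_tensors` argument follows the convention sglang's dense Llama model uses for pipeline parallelism: the first pipeline stage embeds `input_ids`, later stages skip embedding and instead consume the hidden states the previous stage handed over in a `PPProxyTensors`, and every non-last stage returns a `PPProxyTensors` for the next stage rather than final hidden states. The sketch below illustrates only that control flow; it is a simplified stand-in, not this commit's implementation, and the attribute names (`self.embed_tokens`, `self.layers`, `self.norm`) and layer call signatures are abbreviated from the real model.

    from typing import Optional, Union

    import torch
    from torch import nn

    from sglang.srt.distributed import get_pp_group
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors


    class PPForwardSketch(nn.Module):
        """Minimal sketch of a pipeline-parallel forward pass (not the real model)."""

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            forward_batch: ForwardBatch,
            input_embeds: Optional[torch.Tensor] = None,
            pp_proxy_tensors: Optional[PPProxyTensors] = None,
        ) -> Union[torch.Tensor, PPProxyTensors]:
            pp_group = get_pp_group()
            if pp_group.is_first_rank:
                # First stage: build hidden states from token ids (or given embeddings).
                hidden_states = (
                    self.embed_tokens(input_ids) if input_embeds is None else input_embeds
                )
                residual = None
            else:
                # Later stages: consume what the previous stage forwarded.
                assert pp_proxy_tensors is not None
                hidden_states = pp_proxy_tensors["hidden_states"]
                residual = pp_proxy_tensors["residual"]

            # Run only the decoder layers assigned to this stage.
            for layer in self.layers:
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual
                )

            if not pp_group.is_last_rank:
                # Non-last stage: hand intermediate activations to the next stage.
                return PPProxyTensors(
                    {"hidden_states": hidden_states, "residual": residual}
                )
            # Last stage: finish normalization and return final hidden states.
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states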
Reference in New Issue
Block a user
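On the EAGLE3 side, `self.pp_group = get_pp_group()` gives the draft wrapper a handle on the pipeline-parallel process group (for `is_first_rank` / `is_last_rank` checks and stage-to-stage communication). Since EAGLE3 draft models are restricted to a single decoder layer, the draft model itself is never split across stages; the group handle is presumably what lets it cooperate with the target model's pipeline schedule.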