[PP] Add pipeline parallelism (#5724)

This commit is contained in:
Ying Sheng
2025-04-30 18:18:07 -07:00
committed by GitHub
parent e97e57e699
commit 11383cec3c
25 changed files with 1150 additions and 308 deletions

View File

@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if input_embeds is None:
embeds = self.embed_tokens(input_ids)
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.pp_group = get_pp_group()
if self.config.num_hidden_layers != 1:
raise ValueError("EAGLE3 currently only supports 1 layer")