[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)
@@ -468,9 +468,6 @@ class PrefillAdder:
return AddReqResult.OTHER

with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN

if (
enable_hierarchical_cache
and req.last_node_global is not None

@@ -719,7 +719,7 @@ class Scheduler(
server_is_idle = False
result = self.run_batch(self.cur_batch)

# send the outputs to the next step
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]

# carry the outputs to the next stage
# (not last rank)
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)

if not self.pp_group.is_last_rank:
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:

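The scheduler hunks above split the micro-batch loop into a "(last rank)" branch, which post-processes outputs, and a "(not last rank)" branch, which relays its outputs to the next pipeline stage before picking up new work. A minimal sketch of that relay pattern, with toy stand-ins (Stage, run_batch, the deque) rather than sglang's Scheduler or PP-group API:

from collections import deque
from typing import Optional


class Stage:
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.is_last_rank = rank == world_size - 1

    def run_batch(self, batch: dict) -> dict:
        # Pretend each stage applies its own transformation.
        return {"hidden": batch["hidden"] + 1, "bid": batch["bid"]}


def step(stage: Stage, cur_batch: Optional[dict], outbox: deque) -> Optional[dict]:
    """One loop iteration: run the current micro-batch, then either
    post-process it (last rank) or queue it for the next stage."""
    if cur_batch is None:
        return None
    result = stage.run_batch(cur_batch)
    if stage.is_last_rank:
        # (last rank) post-process the outputs of this micro-batch.
        return result
    # (not last rank) carry the outputs to the next stage.
    outbox.append(result)
    return None


if __name__ == "__main__":
    stages = [Stage(r, 2) for r in range(2)]
    outbox: deque = deque()
    step(stages[0], {"hidden": 0, "bid": 7}, outbox)    # first stage queues its result
    final = step(stages[1], outbox.popleft(), deque())  # last stage post-processes
    print(final)  # {'hidden': 2, 'bid': 7}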
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
get_tp_group,
get_world_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
@@ -404,7 +405,10 @@ class ModelRunner:
)

min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
self.device,
self.gpu_id,
distributed=get_world_group().world_size > 1,
cpu_group=get_world_group().cpu_group,
)
self.tp_group = get_tp_group()
self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ class ModelRunner:

def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
self.device,
self.gpu_id,
distributed=get_world_group().world_size > 1,
cpu_group=get_world_group().cpu_group,
)
if self.use_mla_backend:
num_layers = (
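Both ModelRunner hunks switch the free-memory reduction from the tensor-parallel group to the whole world group, so every pipeline stage computes the same minimum and sizes its KV-cache pool identically. A small illustration of the desync this avoids (the numbers are made up):

# Illustrative numbers only: before the fix, each rank reduced free memory with
# MIN over its tensor-parallel group; with pipeline parallelism the PP stages
# could therefore size their memory pools differently (the "desync").
free_gib = {0: 70.1, 1: 69.8,   # PP stage 0 (TP ranks 0-1)
            2: 62.4, 3: 61.9}   # PP stage 1 (TP ranks 2-3)

stage0_tp, stage1_tp = [0, 1], [2, 3]
world = [0, 1, 2, 3]

per_stage_min = (
    min(free_gib[r] for r in stage0_tp),
    min(free_gib[r] for r in stage1_tp),
)
world_min = min(free_gib[r] for r in world)
print(per_stage_min)  # (69.8, 61.9) -> the two stages disagree on the pool size
print(world_min)      # 61.9         -> reducing over the world group keeps them in sync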
@@ -16,13 +16,15 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""

from typing import Iterable, Optional, Tuple
import logging
from typing import Iterable, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, make_layers

logger = logging.getLogger(__name__)


class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()

self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: MixtralDecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="layers",
return_tuple=True,
)
self.layers = nn.ModuleList(
[
MixtralDecoderLayer(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)

def forward(
self,
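With this constructor, each PP rank only instantiates its own contiguous slice of decoder layers (via make_layers), keeping the embedding only on the first rank and the final norm only on the last rank, with PPMissingLayer placeholders elsewhere. A sketch of the contiguous split such a helper typically uses; this is an assumption about the partitioning scheme, not make_layers' exact implementation:

def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    # Evenly split layers into contiguous chunks; earlier ranks absorb the remainder.
    base, rem = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, rem)
    end = start + base + (1 if pp_rank < rem else 0)
    return start, end


# Mixtral-8x7B has 32 decoder layers; with pp_size=2 this gives (0, 16) and (16, 32).
for rank in range(2):
    print(rank, partition_layers(32, rank, 2))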
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)

if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
hidden_states, _ = self.norm(hidden_states, residual)

return hidden_states


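The per-rank forward contract is: the first rank embeds input_ids, later ranks resume from the hidden_states/residual carried in pp_proxy_tensors, and every rank except the last returns a PPProxyTensors instead of final hidden states. A minimal sketch of that hand-off, with plain dicts standing in for PPProxyTensors and identity "layers" instead of Mixtral blocks:

import torch


def stage_forward(rank_is_first, rank_is_last, layers, input_ids=None, proxy=None):
    if rank_is_first:
        hidden = torch.nn.functional.one_hot(input_ids, 8).float()  # toy "embedding"
        residual = None
    else:
        hidden, residual = proxy["hidden_states"], proxy["residual"]
    for layer in layers:
        residual = hidden if residual is None else residual
        hidden = layer(hidden)
    if not rank_is_last:
        # Intermediate rank: hand the tensors to the next stage.
        return {"hidden_states": hidden, "residual": residual}
    # Last rank would normalize and feed the logits processor; just combine here.
    return hidden + residual


ids = torch.tensor([1, 3])
proxy = stage_forward(True, False, [torch.nn.Identity()], input_ids=ids)
out = stage_forward(False, True, [torch.nn.Identity()], proxy=proxy)
print(out.shape)  # torch.Size([2, 8])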
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)

if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states

@property
def start_layer(self):
return self.model.start_layer

@property
def end_layer(self):
return self.model.end_layer

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue

if "rotary_emb.inv_freq" in name:
continue

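load_weights now skips any checkpoint tensor whose layer id falls outside [start_layer, end_layer), so each PP rank only materializes its own slice of the model. A hedged sketch of that filter; layer_id_from_name below is a hypothetical stand-in for get_layer_id:

import re
from typing import Optional


def layer_id_from_name(name: str) -> Optional[int]:
    # Hypothetical helper: pull the index out of "...layers.<i>....".
    m = re.search(r"layers\.(\d+)\.", name)
    return int(m.group(1)) if m else None


def owned_by_this_rank(name: str, start_layer: int, end_layer: int) -> bool:
    layer_id = layer_id_from_name(name)
    if layer_id is None:
        return True  # embeddings / norm / lm_head are handled by their own ranks
    return start_layer <= layer_id < end_layer


# A rank owning layers [16, 32) keeps layer 20 but skips layer 3:
print(owned_by_this_rank("model.layers.20.self_attn.qkv_proj.weight", 16, 32))  # True
print(owned_by_this_rank("model.layers.3.self_attn.qkv_proj.weight", 16, 32))   # False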
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")


EntryClass = MixtralForCausalLM

@@ -347,6 +347,12 @@ class ServerArgs:
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

if self.pp_size > 1:
self.disable_overlap_schedule = True
logger.warning(
"Pipeline parallelism is incompatible with overlap schedule."
)

# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE

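The server-args change forces the overlap schedule off whenever the pipeline-parallel size is greater than one (on the command line this corresponds to a pp-size style argument; the exact flag spelling is assumed here). A toy dataclass sketch of that adjustment rule, standing in for sglang's ServerArgs:

from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


@dataclass
class ToyServerArgs:
    pp_size: int = 1
    disable_overlap_schedule: bool = False

    def adjust(self) -> "ToyServerArgs":
        # Pipeline parallelism is incompatible with the overlap schedule.
        if self.pp_size > 1:
            self.disable_overlap_schedule = True
            logger.warning(
                "Pipeline parallelism is incompatible with overlap schedule."
            )
        return self


print(ToyServerArgs(pp_size=2).adjust())
# ToyServerArgs(pp_size=2, disable_overlap_schedule=True)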
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper


def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
def get_available_gpu_memory(
device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()

return free_gpu_memory / (1 << 30)
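The utility now all-reduces the free-memory scalar with MIN over a caller-supplied cpu_group instead of moving it onto a specific GPU first. A runnable single-process demo of the same MIN reduction over a CPU (gloo) group; real deployments would pass the world CPU group and run one process per rank:

import os
import torch
import torch.distributed as dist

# Single-process group purely for demonstration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

free_gpu_memory = 61.9 * (1 << 30)  # pretend free bytes reported on this rank
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=dist.group.WORLD)
print(tensor.item() / (1 << 30))  # minimum free GiB across ranks (here, just ours)

dist.destroy_process_group()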