[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)
.github/workflows/pr-test.yml
@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
@@ -468,9 +468,6 @@ class PrefillAdder:
             return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
@@ -719,7 +719,7 @@ class Scheduler(
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)

-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
                         self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]

-                # carry the outputs to the next stage
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )

-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
                     dp_offset = self.dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -404,7 +405,10 @@ class ModelRunner:
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ class ModelRunner:

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers

+logger = logging.getLogger(__name__)
+

 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue

-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM
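Note on the layer split above: make_layers hands each pipeline rank a contiguous slice of the decoder layers (start_layer to end_layer) and leaves PPMissingLayer placeholders elsewhere, which is also why load_weights now skips weights whose layer id falls outside that slice, and why non-last ranks return a PPProxyTensors dict ("hidden_states", "residual") instead of logits. A rough standalone sketch of the contiguous partitioning idea (a hypothetical helper for illustration, not sglang's actual make_layers):

    # Hypothetical sketch: split num_layers into contiguous chunks, one per PP rank.
    def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
        base, extra = divmod(num_layers, pp_size)
        start = pp_rank * base + min(pp_rank, extra)
        end = start + base + (1 if pp_rank < extra else 0)
        return start, end  # this rank owns layers [start, end)

    # e.g. 32 Mixtral layers on pp_size=2 -> rank 0 gets (0, 16), rank 1 gets (16, 32)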
@@ -347,6 +347,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()

     return free_gpu_memory / (1 << 30)
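The changes above are the core of the init_memory_pool desync fix: the reduction is now enabled whenever the world size is greater than 1 (so pipeline-parallel runs without tensor parallelism still synchronize), and the MIN is taken over the supplied CPU process group rather than on a CUDA tensor, so every rank sizes its memory pool from the same number. A minimal standalone sketch of the idea (assumes torch.distributed is already initialized and a gloo/CPU group is available; not sglang's actual helper):

    import torch
    import torch.distributed as dist

    def min_free_gpu_memory_gib(gpu_id: int, cpu_group=None) -> float:
        # Each rank reads its own free GPU memory, then all ranks take the MIN
        # over the CPU group, so downstream pool sizing is identical everywhere.
        free_bytes, _total = torch.cuda.mem_get_info(gpu_id)
        t = torch.tensor(float(free_bytes), dtype=torch.float32)  # stays on CPU
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=cpu_group)
        return t.item() / (1 << 30)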
@@ -272,6 +272,50 @@ class TestBenchServing(CustomTestCase):
         else:
             self.assertGreater(res["output_throughput"], 2200)

+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+        self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                2,
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_latency_prefill\n"
+                f'input_throughput: {res["input_throughput"]:.2f} ms\n'
+            )
+        self.assertGreater(res["input_throughput"], 4000)
+

 if __name__ == "__main__":
     unittest.main()