diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index aea02c5e5..d20365722 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
 
+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 3cd2be26b..777ecd343 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -468,9 +468,6 @@ class PrefillAdder:
             return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 0e7cb29a8..b9758465b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -719,7 +719,7 @@ class Scheduler(
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)
 
-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
                         self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]
 
-                # carry the outputs to the next stage
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
 
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )
-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
                     dp_offset = self.dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4ff92ec86..235869b35 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -404,7 +405,10 @@ class ModelRunner:
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ class ModelRunner:
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 8d5a03f0a..90a12f12f 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 
-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
 
 
 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = MixtralForCausalLM
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 32b29bd9f..a0c9bba1d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -347,6 +347,12 @@ class ServerArgs:
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 469ef4d75..1ab14c819 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -282,7 +282,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
 
     return free_gpu_memory / (1 << 30)
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index 7aaa7ab7c..c45c7ccf6 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -272,6 +272,50 @@ class TestBenchServing(CustomTestCase):
         else:
             self.assertGreater(res["output_throughput"], 2200)
 
+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                "2",
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_prefill\n"
+                f'input_throughput: {res["input_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["input_throughput"], 4000)
+
 
 if __name__ == "__main__":
     unittest.main()