diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8c0b0201c..2548ea59e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1260,6 +1260,16 @@ class ModelRunner: // self.server_args.page_size * self.server_args.page_size ) + # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens + if self.pp_size > 1: + tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64) + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MIN, + group=get_world_group().cpu_group, + ) + self.max_total_num_tokens = tensor.item() + # create token size for hybrid cache if self.is_hybrid: self.set_num_token_hybrid()