diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py
index d9d4b0ab2..d32790fe0 100644
--- a/benchmark/gsm8k/bench_sglang.py
+++ b/benchmark/gsm8k/bench_sglang.py
@@ -64,7 +64,9 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )

     #####################################
     ########## SGL Program End ##########
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 541fa0f15..2e0ce6d5c 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -67,10 +67,12 @@ class LogitsMetadata:


 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )

     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata
@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
                 )
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
+            if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 9e86c9b18..f6706781d 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union

 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
@@ -724,7 +726,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -779,6 +781,16 @@ class ScheduleBatch:

         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
+            # that can cause the TP workers to generate different tokens, so we need to
+            # sync here
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids


diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 4d869c591..945a4c95e 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -444,7 +444,9 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                next_token_ids = batch.sample(
+                    output.next_token_logits, self.model_runner.is_multi_node_tp
+                )

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
@@ -603,7 +605,9 @@ class ModelTpServer:

         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 3d4e5d4c6..af39065cf 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -142,7 +142,7 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
@@ -239,12 +239,23 @@
             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()

+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()

+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

@@ -278,7 +289,9 @@
         )

         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]

         # Unpad
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7af4ec2dd..2de432144 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )

         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index eff746f1d..75b086fd6 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
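
Note on the sampling sync in schedule_batch.py: when tensor parallelism spans
nodes, floating-point non-determinism can let TP ranks sample different token
ids from nominally identical logits, after which decoding diverges. The patch
resolves this with an all_reduce over the sampled ids using ReduceOp.MIN, so
every rank keeps the same (element-wise smallest) id. Below is a minimal
standalone sketch of the same idea; the gloo backend, the toy token ids, and
launching via `torchrun --nproc_per_node=2 sync_demo.py` are illustrative
assumptions, not part of the patch.

    import torch
    import torch.distributed as dist

    def main():
        # One process per "TP rank"; gloo keeps the demo CPU-only.
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()

        # Pretend each rank sampled slightly different token ids due to
        # cross-node floating-point indeterminism (toy values).
        batch_next_token_ids = torch.tensor([42 + rank, 7, 1001 + rank])

        # Same trick as the patch: element-wise MIN across ranks, so all
        # ranks continue decoding with identical token ids.
        dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN)

        print(f"rank {rank}: {batch_next_token_ids.tolist()}")
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

MIN is arbitrary but deterministic; any reduction that makes every rank adopt
the same value (MIN, MAX) would do, at the cost of one small collective per
decode step.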
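
Note on skip_all_gather in logits_processor.py and the lm_head swap in
grok.py: a vocab-parallel head (ParallelLMHead) shards the vocabulary across
TP ranks, so each rank computes only a vocab_size/tp_size slice of the logits
and tensor_model_parallel_all_gather is needed to reassemble the full row. A
ReplicatedLinear head keeps the full hidden_size x vocab_size weight on every
rank, so each rank already produces complete logits and the all-gather would
be redundant traffic (notably expensive across nodes). A toy single-process
sketch of the shape difference; the dimensions are illustrative assumptions.

    import torch

    hidden_size, vocab_size, tp_size = 8, 32, 4
    hidden = torch.randn(1, hidden_size)

    # Vocab-parallel head: each rank holds one vocab shard, so its matmul
    # yields partial logits; an all-gather over ranks restores vocab_size.
    shard = torch.randn(vocab_size // tp_size, hidden_size)  # one rank's slice
    partial_logits = hidden @ shard.T
    assert partial_logits.shape == (1, vocab_size // tp_size)

    # Replicated head: every rank holds the full weight and gets full
    # logits, so LogitsProcessor(config, skip_all_gather=True) can skip the
    # tensor-parallel all-gather entirely.
    full_weight = torch.randn(vocab_size, hidden_size)
    full_logits = hidden @ full_weight.T
    assert full_logits.shape == (1, vocab_size)

The trade-off is memory: a replicated head stores the whole vocabulary
projection on every rank, which this diff appears to accept for Grok-1 in
exchange for removing a cross-node collective from every forward pass.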