diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f3c376ed1..24f951f2b 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -2,12 +2,18 @@ import logging
 from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import (
+    crash_on_warnings,
+    get_bool_env_var,
+    is_flashinfer_available,
+)
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -20,6 +26,8 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
@@ -121,6 +129,20 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids.to(torch.int32)
 
     def _apply_custom_logit_processor(
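For context, here is a minimal, self-contained sketch (not part of the patch) of why the `ReduceOp.MIN` all-reduce keeps TP ranks in lockstep: if a rare nondeterministic kernel causes ranks to sample different token IDs, the reduce forces every rank to adopt the same (minimum) ID instead of silently diverging and hanging. The two-process setup, the gloo backend, and the port number below are illustrative assumptions, not code from the patch.

```python
# Illustrative only: two "TP ranks" that disagree on a sampled token ID
# converge to the same ID after an all-reduce with ReduceOp.MIN.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend a rare nondeterministic kernel made the ranks sample different IDs.
    batch_next_token_ids = torch.tensor([42 if rank == 0 else 43], dtype=torch.int64)

    # Same idea as the patch: take the minimum ID so every rank agrees.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: {batch_next_token_ids.tolist()}")  # both ranks print [42]

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

In the patch itself the extra all-reduce only runs when the `SYNC_TOKEN_IDS_ACROSS_TP` environment variable is set (e.g. `SYNC_TOKEN_IDS_ACROSS_TP=1`, assuming `get_bool_env_var` treats `1`/`true` as enabled) or when a grammar backend is active, so the default fast path is unchanged.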