Add some flags to allow sync token ids across TP ranks (#3060)
This commit is contained in:
@@ -2,12 +2,18 @@ import logging
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_group
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||
from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
get_bool_env_var,
|
||||
is_flashinfer_available,
|
||||
)
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.sampling import (
|
||||
@@ -20,6 +26,8 @@ if is_flashinfer_available():
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
@@ -121,6 +129,20 @@ class Sampler(nn.Module):
|
||||
batch_next_token_ids,
|
||||
]
|
||||
|
||||
if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
|
||||
# For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
|
||||
# This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
|
||||
# the last all-reduce, the last lm_head matmul, and all sampling kernels.
|
||||
# These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
|
||||
# In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
|
||||
# When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
batch_next_token_ids,
|
||||
op=dist.ReduceOp.MIN,
|
||||
group=get_tensor_model_parallel_group().device_group,
|
||||
)
|
||||
|
||||
return batch_next_token_ids.to(torch.int32)
|
||||
|
||||
def _apply_custom_logit_processor(
|
||||
|
||||
Reference in New Issue
Block a user