Add some flags to allow sync token ids across TP ranks (#3060)
@@ -2,12 +2,18 @@ import logging
 from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import (
+    crash_on_warnings,
+    get_bool_env_var,
+    is_flashinfer_available,
+)
 
 if is_flashinfer_available():
     from flashinfer.sampling import (
@@ -20,6 +26,8 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
@@ -121,6 +129,20 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids.to(torch.int32)
 
     def _apply_custom_logit_processor(