From 8b84e69f25929c8de0286c6e0e0c2ce4686b561c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 22 Jan 2025 18:51:40 -0800 Subject: [PATCH] Fix tp token sync for dp attention (#3062) --- python/sglang/srt/layers/sampler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 24f951f2b..3173d533d 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -6,6 +6,7 @@ import torch.distributed as dist from torch import nn from sglang.srt.distributed import get_tensor_model_parallel_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -33,6 +34,10 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.tp_sync_group = get_tensor_model_parallel_group().device_group + + if global_server_args_dict["enable_dp_attention"]: + self.tp_sync_group = get_attention_tp_group().device_group def forward( self, @@ -140,7 +145,7 @@ class Sampler(nn.Module): torch.distributed.all_reduce( batch_next_token_ids, op=dist.ReduceOp.MIN, - group=get_tensor_model_parallel_group().device_group, + group=self.tp_sync_group, ) return batch_next_token_ids.to(torch.int32)