Fix tp token sync for dp attention (#3062)
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -33,6 +34,10 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -140,7 +145,7 @@ class Sampler(nn.Module):
             torch.distributed.all_reduce(
                 batch_next_token_ids,
                 op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
+                group=self.tp_sync_group,
             )
 
         return batch_next_token_ids.to(torch.int32)
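Background for the change above: the sampler syncs the sampled token IDs across ranks with a MIN all-reduce so every rank continues decoding with identical tokens. Before this commit the reduce always ran over the full tensor-parallel group; with enable_dp_attention, each attention TP subgroup samples its own batch, so the sync group is narrowed to get_attention_tp_group(). The sketch below is a minimal, standalone illustration of that reduce-over-a-group pattern, not SGLang's actual code; sync_token_ids and sync_group are hypothetical names.

# Minimal sketch (assumption: standalone illustration, not SGLang's code) of
# syncing sampled token IDs across the ranks that share a batch.
import os

import torch
import torch.distributed as dist


def sync_token_ids(batch_next_token_ids: torch.Tensor, sync_group) -> torch.Tensor:
    # Ranks in `sync_group` may have sampled different token IDs for the same
    # batch (e.g. via non-deterministic kernels). An element-wise MIN makes the
    # result identical on every rank in the group.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=sync_group)
    return batch_next_token_ids.to(torch.int32)


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs standalone; under dp
    # attention the group would instead be the attention TP device group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    ids = torch.tensor([17, 42, 7], dtype=torch.int64)
    print(sync_token_ids(ids, dist.group.WORLD))
    dist.destroy_process_group()

MIN here acts as a cheap, deterministic tie-breaker: when ranks disagree, every rank keeps the smallest candidate, so the group converges on one token ID without any extra coordination.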