Fix tp token sync for dp attention (#3062)
This commit is contained in:
@@ -6,6 +6,7 @@ import torch.distributed as dist
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_group
|
from sglang.srt.distributed import get_tensor_model_parallel_group
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
@@ -33,6 +34,10 @@ class Sampler(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
|
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
|
||||||
|
self.tp_sync_group = get_tensor_model_parallel_group().device_group
|
||||||
|
|
||||||
|
if global_server_args_dict["enable_dp_attention"]:
|
||||||
|
self.tp_sync_group = get_attention_tp_group().device_group
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -140,7 +145,7 @@ class Sampler(nn.Module):
|
|||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
batch_next_token_ids,
|
batch_next_token_ids,
|
||||||
op=dist.ReduceOp.MIN,
|
op=dist.ReduceOp.MIN,
|
||||||
group=get_tensor_model_parallel_group().device_group,
|
group=self.tp_sync_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch_next_token_ids.to(torch.int32)
|
return batch_next_token_ids.to(torch.int32)
|
||||||
|
|||||||
Reference in New Issue
Block a user