[Feature] Comprehensive Hybrid Parallelism Support (#6389)
@@ -165,7 +165,8 @@ def disable_dp_size():
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
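For context, a minimal sketch (not the repository's actual implementation) of how a rank's local token slice could be derived from per-rank token counts during a global DP gather/scatter; the helper name `local_token_slice` and its arguments are hypothetical, while `torch.cumsum` over the per-rank counts mirrors the `global_num_tokens_gpu` usage in the diff above:

import torch

def local_token_slice(global_num_tokens: torch.Tensor, dp_rank: int):
    # Hypothetical sketch: compute this rank's start offset and token count
    # from a 1-D tensor of per-rank token counts, indexed by the *global*
    # attention DP rank (as the changed line above uses).
    cumtokens = torch.cumsum(global_num_tokens, dim=0)
    start = cumtokens[dp_rank - 1] if dp_rank > 0 else torch.zeros_like(cumtokens[0])
    num_tokens = global_num_tokens[dp_rank]
    return start, num_tokens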