Improve DP attention (#4390)

Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-03-13 08:23:56 -07:00
committed by GitHub
parent f141298a3c
commit 8e66fbecee
9 changed files with 345 additions and 226 deletions

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union
import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -86,6 +90,27 @@ def get_attention_dp_size():
return _DP_SIZE
@contextmanager
def disable_dp_size():
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
old_dp_size = _DP_SIZE
_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
@@ -159,7 +184,8 @@ def dp_gather(
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
global_tokens.untyped_storage().data_ptr()
!= local_tokens.untyped_storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()