Improve DP attention (#4390)
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
@@ -1,6 +1,8 @@
from __future__ import annotations

import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union

import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_reduce,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
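The first two hunks adopt the `from __future__ import annotations` plus `if TYPE_CHECKING:` pattern: ForwardBatch is imported only for static type checkers, and postponed annotation evaluation keeps the `forward_batch: ForwardBatch` hints below from requiring the model-executor import at runtime, which avoids an eager (and potentially circular) import. A minimal sketch of the pattern, with the module and attribute names invented for illustration:

from __future__ import annotations  # annotations are kept as strings, not evaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, pyright); never at runtime.
    from some_heavy_module import ForwardBatch  # hypothetical module for this sketch


def describe_batch(forward_batch: ForwardBatch) -> int:
    # The ForwardBatch annotation needs no runtime import thanks to the future import.
    return forward_batch.batch_size  # hypothetical attribute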
@@ -86,6 +90,27 @@ def get_attention_dp_size():
    return _DP_SIZE


@contextmanager
def disable_dp_size():
    """Patch the attention DP size to 1 until this context manager exits.

    This is for draft workers of speculative decoding, which run the draft
    model with a tp degree different from that of the target model workers.
    """
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"

    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield
    finally:
        _DP_SIZE = old_dp_size


def get_dp_local_info(forward_batch: ForwardBatch):
    dp_rank = get_attention_dp_rank()
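One way a speculative-decoding draft worker might use the new context manager; the wrapper function and its arguments are hypothetical and only illustrate the intended pattern, and the import path is assumed to be sglang.srt.layers.dp_attention:

from sglang.srt.layers.dp_attention import disable_dp_size, get_attention_dp_size


def run_draft_forward(draft_model, forward_batch):
    # Hypothetical helper: run the draft model with DP attention disabled so
    # DP-size-dependent code paths behave as if there were a single DP rank.
    # _DP_SIZE is restored by the finally block even if forward() raises.
    with disable_dp_size():
        assert get_attention_dp_size() == 1
        return draft_model.forward(forward_batch)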
@@ -159,7 +184,8 @@ def dp_gather(
        layer_id != "embedding" or get_attention_tp_rank() == 0
    ):
        assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
        ), "aliasing between global_tokens and local_tokens not allowed"
        memcpy_triton(
            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
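The hunk above switches the aliasing guard from the deprecated Tensor.storage() accessor to Tensor.untyped_storage(); the check itself is unchanged: two tensors alias when their underlying storages start at the same address. A standalone sketch of the same check (the assert_not_aliased helper is invented for this example, not part of the PR):

import torch


def assert_not_aliased(a: torch.Tensor, b: torch.Tensor) -> None:
    # Same idea as the dp_gather guard: compare the base addresses of the
    # untyped storages backing the two tensors.
    assert (
        a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr()
    ), "aliasing between global_tokens and local_tokens not allowed"


buf = torch.zeros(8)
view = buf[2:]                # a view that shares buf's storage
other = torch.zeros(8)        # a separate allocation
assert_not_aliased(buf, other)    # passes
# assert_not_aliased(buf, view)   # would fail: same underlying storage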
@@ -174,8 +200,9 @@ def dp_gather(
        torch.ops.sglang.inplace_all_reduce(
            global_tokens, group_name=get_tp_group().unique_name
        )
    else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


def dp_scatter(
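The second dp_gather hunk changes plain assignment to slice assignment. tensor_model_parallel_all_reduce returns a tensor, and `global_tokens = ...` would only rebind the local name inside dp_gather, leaving the buffer the caller passed in with its pre-reduce contents; `global_tokens[:] = ...` copies the result back into that buffer. A small sketch of the difference, using `buf * 2` as a stand-in for the all-reduce:

import torch


def reduce_rebind(buf: torch.Tensor) -> None:
    # Rebinding: the local name now points at a new tensor; the caller's
    # buffer is left untouched.
    buf = buf * 2  # stand-in for tensor_model_parallel_all_reduce(buf)


def reduce_in_place(buf: torch.Tensor) -> None:
    # Slice assignment copies the result into the caller's buffer.
    buf[:] = buf * 2


x = torch.ones(4)
reduce_rebind(x)
print(x)       # tensor([1., 1., 1., 1.]) -- unchanged
reduce_in_place(x)
print(x)       # tensor([2., 2., 2., 2.]) -- updated in place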
@@ -186,6 +213,7 @@ def dp_scatter(
    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
    # since local_tokens may be padded for cuda graph
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    local_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
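The new asserts in dp_scatter presumably protect the Triton copy path, which would index the buffers as flat contiguous memory; that rationale is an assumption, not stated in the PR. A quick illustration of the kind of input the asserts reject:

import torch

full = torch.zeros(8, 16)
local_tokens = full[:, :8]                 # a strided (non-contiguous) view
print(local_tokens.is_contiguous())        # False -> would trip the new assert
local_tokens = local_tokens.contiguous()   # densely packed copy
print(local_tokens.is_contiguous())        # True -> passes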