Reduce computation and communication in DP attention (#4521)
@@ -189,6 +189,9 @@ class GroupCoordinator:
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster
    )
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator

@@ -241,6 +244,7 @@ class GroupCoordinator:
        self.use_custom_allreduce = use_custom_allreduce
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator
        self.use_message_queue_broadcaster = use_message_queue_broadcaster

        # lazy import to avoid documentation build error
        from sglang.srt.distributed.device_communicators.custom_all_reduce import (

@@ -269,7 +273,7 @@ class GroupCoordinator:
            HpuCommunicator,
        )

        self.hpu_communicator: Optional[HpuCommunicator]
        self.hpu_communicator: Optional[HpuCommunicator] = None
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)

@@ -277,7 +281,7 @@ class GroupCoordinator:
            XpuCommunicator,
        )

        self.xpu_communicator: Optional[XpuCommunicator]
        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

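Note on the two hunks above: a bare annotation such as self.hpu_communicator: Optional[HpuCommunicator] only declares a type and never creates the attribute, so any path that skips the world_size > 1 branch leaves it unset. The = None default guarantees the attribute exists before it is checked. A minimal sketch of the pattern; FakeCommunicator and Coordinator are stand-ins invented here, not sglang classes:

from typing import Optional


class FakeCommunicator:  # stand-in for HpuCommunicator / XpuCommunicator
    def __init__(self, group):
        self.group = group


class Coordinator:
    def __init__(self, world_size: int, use_comm: bool, device_group=None):
        # A bare annotation (self.comm: Optional[FakeCommunicator]) would leave
        # the attribute unset when the branch below is skipped, and a later
        # "if self.comm is not None" check would raise AttributeError.
        self.comm: Optional[FakeCommunicator] = None
        if use_comm and world_size > 1:
            self.comm = FakeCommunicator(group=device_group)


# Single-process coordinator: the attribute still exists and reads as None.
c = Coordinator(world_size=1, use_comm=True)
assert c.comm is None
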
@@ -53,10 +53,8 @@ def initialize_dp_attention(
    )

    if enable_dp_attention:
        local_rank = tp_rank % (tp_size // dp_size)
        _DP_SIZE = dp_size
    else:
        local_rank = tp_rank
        _DP_SIZE = 1

    tp_group = get_tp_group()

@@ -65,7 +63,7 @@ def initialize_dp_attention(
            list(range(head, head + _ATTN_TP_SIZE))
            for head in range(0, tp_size, _ATTN_TP_SIZE)
        ],
        local_rank,
        tp_group.local_rank,
        torch.distributed.get_backend(tp_group.device_group),
        SYNC_TOKEN_IDS_ACROSS_TP,
        False,

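For reference, the group layout built by the expression above: the tp_size ranks are cut into dp_size attention-TP groups of _ATTN_TP_SIZE = tp_size // dp_size consecutive ranks each, and the deleted local_rank line computed a rank's position inside its group. A small sketch with assumed sizes that evaluates the same list comprehension and the same rank arithmetic:

# Example sizes, assumed for illustration.
tp_size, dp_size = 8, 2
attn_tp_size = tp_size // dp_size  # mirrors _ATTN_TP_SIZE = 4

# Same expression as the group_ranks argument in the hunk above.
group_ranks = [
    list(range(head, head + attn_tp_size))
    for head in range(0, tp_size, attn_tp_size)
]
print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

for tp_rank in range(tp_size):
    attn_dp_rank = tp_rank // attn_tp_size  # which DP replica this rank serves
    attn_tp_rank = tp_rank % attn_tp_size   # same value the removed local_rank line produced
    print(tp_rank, attn_dp_rank, attn_tp_rank)
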
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


def dp_gather(
def _dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    layer_id: Union[str, int],
    is_partial: bool,
):
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    global_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
    if local_tokens.shape[0] > 0 and (
        layer_id != "embedding" or get_attention_tp_rank() == 0
    ):
    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
        assert (
            global_tokens.untyped_storage().data_ptr()
            != local_tokens.untyped_storage().data_ptr()

@@ -216,6 +213,22 @@ def dp_gather(
    global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


def dp_gather_partial(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)


def dp_gather_replicate(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)


def dp_scatter(
    local_tokens: torch.Tensor,  # output
    global_tokens: torch.Tensor,  # input

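The is_partial flag replaces the old layer_id != "embedding" test. _dp_gather zero-fills the global buffer, each rank writes its local rows at its DP offset, and tensor_model_parallel_all_reduce sums the buffers across the whole TP group. With is_partial=True every attention-TP rank contributes its partial output, so the same all-reduce also finishes the TP reduction; with is_partial=False the local tensor is already replicated inside the attention-TP group, so only rank 0 may write or the sum would over-count by the group size. A single-process sketch of that arithmetic; sizes and values are made up, and the loop plus sum stands in for the real all-reduce:

import torch

attn_tp_size, dp_size, hidden = 2, 2, 4
offsets = [0, 1]  # DP rank r owns row r of the global buffer (assumed layout)
true_rows = [torch.ones(hidden), 2 * torch.ones(hidden)]


def gather(is_partial: bool) -> torch.Tensor:
    buffers = []
    for dp_rank in range(dp_size):
        for tp_rank in range(attn_tp_size):
            buf = torch.zeros(dp_size, hidden)  # mirrors global_tokens.fill_(0)
            if is_partial:
                # Partial TP outputs: each rank holds 1/attn_tp_size of the sum.
                buf[offsets[dp_rank]] = true_rows[dp_rank] / attn_tp_size
            elif tp_rank == 0:
                # Replicated input: only attention-TP rank 0 contributes,
                # otherwise the sum below would over-count by attn_tp_size.
                buf[offsets[dp_rank]] = true_rows[dp_rank]
            buffers.append(buf)
    return torch.stack(buffers).sum(dim=0)  # stand-in for the all-reduce


assert torch.equal(gather(True), torch.stack(true_rows))
assert torch.equal(gather(False), torch.stack(true_rows))
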
@@ -236,16 +249,3 @@ def dp_scatter(
        memcpy_triton(
            local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
        )


def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
    def do_logits_dp_scatter(logits: torch.Tensor):
        local_logits = torch.empty(
            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
            dtype=logits.dtype,
            device=logits.device,
        )
        dp_scatter(local_logits, logits, forward_batch)
        return local_logits

    return do_logits_dp_scatter

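For orientation, dp_scatter is the inverse copy: it takes this DP rank's slice of the gathered tokens, using the offsets that get_dp_local_info reports for the current forward batch. A plain-PyTorch sketch of the same copy, with the offsets passed in directly instead of read from the forward batch (an assumption made only to keep the example self-contained):

import torch


def dp_scatter_reference(local_tokens, global_tokens, local_start_pos, local_num_tokens):
    # Copy this rank's rows out of the gathered (global) buffer,
    # the same slice the memcpy_triton call above moves on device.
    local_tokens.copy_(
        global_tokens[local_start_pos : local_start_pos + local_num_tokens]
    )


global_tokens = torch.arange(12.0).reshape(6, 2)  # 6 global tokens, hidden size 2
local_tokens = torch.empty(2, 2)                  # this rank owns 2 of them
dp_scatter_reference(local_tokens, global_tokens, local_start_pos=4, local_num_tokens=2)
print(local_tokens)  # rows 4 and 5 of the global buffer
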
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
    dp_gather,
    dp_gather_replicate,
    dp_scatter,
    get_attention_dp_rank,
    get_attention_dp_size,

@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
                logits_metadata.gathered_buffer,
                hidden_states.clone(),
            )
            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

        if hasattr(lm_head, "weight"):
            logits = torch.matmul(

@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.dp_attention import (
    dp_gather,
    dp_gather_partial,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,

@@ -939,11 +939,47 @@ class DeepseekV2DecoderLayer(nn.Module):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if residual is None:
        if hidden_states.shape[0] == 0:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                if get_attention_tp_rank() == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # Scatter
        if self.dp_size != 1:

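The subtle point in the new gather path above is that the residual is added only on attention-TP rank 0 before dp_gather_partial, so the single all-reduce both completes the attention TP reduction and folds the residual in exactly once; dp_scatter then refreshes each rank's residual slice from the gathered buffer before post_attention_layernorm runs on the full batch. A toy check of the "residual exactly once" property; two ranks and made-up values, with stack-and-sum standing in for the all-reduce:

import torch

attn_tp_size = 2
residual = torch.tensor([1.0, 1.0])
attn_partials = [torch.tensor([0.5, 0.5]), torch.tensor([0.25, 0.25])]

contributions = []
for tp_rank in range(attn_tp_size):
    h = attn_partials[tp_rank].clone()
    if tp_rank == 0:   # mirrors: if get_attention_tp_rank() == 0:
        h += residual  #              hidden_states += residual
    contributions.append(h)

gathered = torch.stack(contributions).sum(dim=0)  # stand-in for the all-reduce
assert torch.equal(gathered, sum(attn_partials) + residual)
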
@@ -955,31 +991,6 @@ class DeepseekV2DecoderLayer(nn.Module):
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather(
                    hidden_states, local_hidden_states, forward_batch, self.layer_id
                )
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:

        # Gather
        if self.dp_size != 1:
            input_ids, local_input_ids = (
                torch.empty(
                    (forward_batch.gathered_buffer.shape[0],),
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                ),
                input_ids,
            )
            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:

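The removed block above gathered input_ids across DP ranks so every rank ran embed_tokens on the whole global batch; without it, embed_tokens sees only the rank-local ids, which is part of the computation saving the commit title refers to. A back-of-the-envelope illustration with assumed per-rank batch sizes:

# Assumed per-rank batch sizes, for illustration only.
local_tokens_per_rank = [3, 5, 4, 4]
global_tokens = sum(local_tokens_per_rank)  # size of the gathered buffer: 16

for rank, n in enumerate(local_tokens_per_rank):
    # before: every rank embedded the full gathered batch; after: only its own rows
    print(f"rank {rank}: embed_tokens rows {global_tokens} -> {n}")
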
@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        if self.dp_size != 1:
            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

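The deleted comment warns that forward_batch.gathered_buffer was reused as both the dp_scatter destination and the dp_gather source; that kind of aliasing is exactly what the storage-pointer assert in _dp_gather (and the hidden_states.clone() in the LogitsProcessor hunk) guards against. A toy version of that check:

import torch

gathered_buffer = torch.zeros(6, 4)
local_view = gathered_buffer[:2]  # a view: shares storage with the buffer


def storages_differ(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Same comparison the assert in _dp_gather performs on its two arguments.
    return a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr()


print(storages_differ(gathered_buffer, local_view))          # False: aliased, unsafe source
print(storages_differ(gathered_buffer, local_view.clone()))  # True: safe to gather into
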