Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -439,6 +439,15 @@ class GroupCoordinator:
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
) -> None:
|
||||
# TODO(ch-wan): support other backends
|
||||
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
|
||||
return output
|
||||
|
||||
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
@@ -456,11 +465,23 @@ class GroupCoordinator:
|
||||
output, input, group_name=self.unique_name
|
||||
)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
def all_gather(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1,
|
||||
tensor_list: List[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if tensor_list is not None:
|
||||
# TODO(ch-wan): support other backends
|
||||
return torch.distributed.all_gather(
|
||||
tensor_list, input_, group=self.device_group
|
||||
)
|
||||
|
||||
assert (
|
||||
-input_.dim() <= dim < input_.dim()
|
||||
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -249,3 +249,14 @@ def dp_scatter(
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def tp_reduce_scatter(
|
||||
output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
):
|
||||
return get_attention_tp_group().reduce_scatter(output, input_list)
|
||||
|
||||
|
||||
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
|
||||
|
||||
@@ -1186,7 +1186,7 @@ class Scheduler(
|
||||
ret = None
|
||||
|
||||
# Handle DP attention
|
||||
if self.server_args.enable_dp_attention:
|
||||
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
|
||||
ret, _ = self.prepare_dp_attn_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -174,6 +174,7 @@ class CudaGraphRunner:
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||
self.tp_size = model_runner.server_args.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
@@ -245,8 +246,8 @@ class CudaGraphRunner:
|
||||
)
|
||||
else:
|
||||
self.encoder_lens = None
|
||||
|
||||
if self.enable_dp_attention:
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.dp_size * self.num_tokens_per_bs,
|
||||
@@ -288,7 +289,7 @@ class CudaGraphRunner:
|
||||
self.model_runner.token_to_kv_pool.capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
@@ -369,7 +370,7 @@ class CudaGraphRunner:
|
||||
encoder_lens = None
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
if self.enable_dp_attention:
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -471,7 +472,7 @@ class CudaGraphRunner:
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
# Pad
|
||||
if self.enable_dp_attention:
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
index = bisect.bisect_left(
|
||||
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
@@ -497,7 +498,7 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention:
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
|
||||
if hasattr(forward_batch.spec_info, "hidden_states"):
|
||||
|
||||
@@ -281,9 +281,6 @@ class ModelRunner:
|
||||
|
||||
if server_args.enable_deepep_moe:
|
||||
logger.info("DeepEP is turned on.")
|
||||
assert (
|
||||
server_args.enable_dp_attention == True
|
||||
), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
|
||||
|
||||
def init_torch_distributed(self):
|
||||
logger.info("Init torch distributed begin.")
|
||||
|
||||
@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
tp_all_gather,
|
||||
tp_reduce_scatter,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
if forward_mode is not None and not forward_mode.is_idle():
|
||||
if (
|
||||
forward_mode is not None
|
||||
and not forward_mode.is_idle()
|
||||
and hidden_states.shape[0] > 0
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
if self.n_shared_experts is not None:
|
||||
@@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
is_nextn: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
def is_sparse_layer(l: int):
|
||||
return (
|
||||
config.n_routed_experts is not None
|
||||
and l >= config.first_k_dense_replace
|
||||
and l % config.moe_layer_freq == 0
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||
self.layer_id = layer_id
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
@@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
if is_nextn or (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
if is_nextn or is_sparse_layer(layer_id):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = True
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = False
|
||||
|
||||
self.input_is_scattered = (
|
||||
is_sparse_layer(layer_id - 1)
|
||||
and global_server_args_dict["enable_deepep_moe"]
|
||||
)
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
|
||||
return self.forward_deepep(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
@@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
residual,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
||||
)
|
||||
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.dp_size != 1:
|
||||
if global_server_args_dict["enable_deepep_moe"] and isinstance(
|
||||
self.mlp, DeepseekV2MoE
|
||||
):
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
return hidden_states, residual
|
||||
else:
|
||||
if get_attention_tp_rank() == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
@@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
# Fully Connected
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
# TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
@@ -1113,6 +1158,82 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_deepep(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1:
|
||||
if self.input_is_scattered:
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
else:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
tp_reduce_scatter(hidden_states, tensor_list)
|
||||
residual = hidden_states
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
|
||||
if self.is_last_layer and self.attn_tp_size != 1:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
residual,
|
||||
)
|
||||
tp_all_gather(
|
||||
list(residual.tensor_split(self.attn_tp_size)), local_residual
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
|
||||
|
||||
@@ -290,12 +290,17 @@ class ServerArgs:
|
||||
logger.warning(
|
||||
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
|
||||
)
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
self.ep_size = self.dp_size
|
||||
logger.info(
|
||||
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
|
||||
)
|
||||
|
||||
self.enable_sp_layernorm = False
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
self.ep_size = self.tp_size
|
||||
self.enable_sp_layernorm = (
|
||||
self.dp_size < self.tp_size if self.enable_dp_attention else True
|
||||
)
|
||||
logger.info(
|
||||
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||
)
|
||||
|
||||
# Speculative Decoding
|
||||
if self.speculative_algorithm == "NEXTN":
|
||||
|
||||
Reference in New Issue
Block a user