Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
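When DP attention splits the TP world into several attention TP groups, the LM head is still sharded across the full TP group, so producing final logits needs an all-gather that crosses DP groups. This patch adds an --enable-dp-lm-head flag that instead shards the LM head, and its logits all-gather, within each attention TP group: VocabParallelEmbedding and ParallelLMHead gain a use_attn_tp_group mode, LogitsProcessor gathers over the attention TP group only, and the dp_attention collectives are renamed to make their scope explicit.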
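sglang.srt.layers.dp_attention: both wrappers already ran on the attention TP group (note the get_attention_tp_group() calls in their bodies), so tp_reduce_scatter and tp_all_gather are renamed to attn_tp_reduce_scatter and attn_tp_all_gather to make the call sites self-describing.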
@@ -252,12 +252,12 @@ def dp_scatter(
     )


-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)


-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
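sglang.srt.layers.logits_processor: the next three hunks import the attention-TP helpers and branch the constructor on enable_dp_lm_head. With the flag set, the gather decision depends on the attention TP size rather than the global TP world size, and the separate DP-attention gather path is disabled. For example, tp-size 16 with dp-size 2 gives attention TP groups of size 8 (assuming attention TP size = tp_size / dp_size): each rank then holds 1/8 of the vocabulary instead of 1/16, and logits are gathered only among the 8 ranks of its own group.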
@@ -23,15 +23,16 @@ import triton.language as tl
 from torch import nn

 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
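In the forward path, the gather itself is routed accordingly: the plain TP mode keeps tensor_model_parallel_all_gather, while the attention-TP mode gathers each rank's vocab shard into a preallocated (vocab_size, tokens) buffer viewed through its transpose.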
@@ -442,7 +451,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)

         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)

         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
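The transpose is what makes the in-place gather work. A minimal single-process sketch (sizes are illustrative, and the copy loop stands in for attn_tp_all_gather, so no distributed setup is needed):

    import torch

    # Illustrative sizes: 4 attention-TP ranks, 3 tokens, a padded vocab of 16,
    # so each rank computes a (tokens, 16 // 4) logits shard.
    attn_tp_size, tokens, vocab = 4, 3, 16
    local = [
        torch.full((tokens, vocab // attn_tp_size), float(r))
        for r in range(attn_tp_size)
    ]

    # Allocate (vocab, tokens) and view it transposed, as the patch does. The
    # (tokens, vocab) view has strides (1, tokens), so tensor_split along the
    # vocab dim yields views that each cover one contiguous block of the
    # underlying storage -- the layout an in-place all-gather into a list wants.
    global_logits = torch.empty((vocab, tokens)).T
    chunks = list(global_logits.tensor_split(attn_tp_size, dim=-1))

    # Stand-in for attn_tp_all_gather(chunks, local_shard): rank r's shard
    # lands in the r-th vocab slice of the shared buffer.
    for chunk, shard in zip(chunks, local):
        chunk.copy_(shard)

    print(global_logits.shape)  # torch.Size([3, 16])
    print(global_logits[0])     # 0., 0., 0., 0., 1., 1., 1., 1., 2., ..., 3.

Row t of the result is the full vocabulary row for token t, with no concatenation pass afterwards.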
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):

         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1

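sglang.srt.layers.vocab_parallel_embedding: the hunks above let the embedding pick its rank and size from the attention TP group instead of the full TP group (and the non-TP path now rejects use_attn_tp_group). A rough sketch of the effect on shard boundaries, using a hypothetical vocab_range helper that assumes the vocabulary is already padded to a multiple of the shard count:

    def vocab_range(vocab_size: int, rank: int, size: int) -> tuple[int, int]:
        # Which slice of the padded vocabulary a rank owns; rank/size come either
        # from the full TP group or, with use_attn_tp_group, the attention TP group.
        per_partition = vocab_size // size
        return rank * per_partition, (rank + 1) * per_partition

    # 128k padded vocab, tp-size 16 split as dp 2 x attn-tp 8, rank 3:
    print(vocab_range(131072, 3, 16))  # (24576, 32768)  full-TP sharding
    print(vocab_range(131072, 3, 8))   # (49152, 57344)  attention-TP sharding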
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
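The ParallelLMHead hunk also switches the super().__init__ call from positional to keyword arguments and makes everything after embedding_dim keyword-only, so the new use_attn_tp_group parameter cannot be misbound positionally.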
@@ -74,6 +74,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1339,7 +1339,7 @@
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
         if hidden_states.shape[0] != 0:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
@@ -1349,7 +1349,7 @@
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

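The four DeepseekV2DecoderLayer hunks above are mechanical call-site updates for the attn_tp_all_gather / attn_tp_reduce_scatter renames; behavior is unchanged.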
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
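Model opt-in is uniform: DeepSeek-V2 and Llama pass use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"] to ParallelLMHead, so the head is resharded only when the flag is set.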
@@ -159,6 +159,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -323,6 +324,11 @@
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )

+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_lm_head. "
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -1055,6 +1061,11 @@
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-dp-lm-head",
+            action="store_true",
+            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+        )
         parser.add_argument(
             "--enable-ep-moe",
             action="store_true",
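For reference, a plausible launch line (model path and parallel sizes are illustrative; --enable-dp-lm-head is the flag added here, and per the assertion above it requires --enable-dp-attention):

    python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 \
        --tp-size 8 --dp-size 8 --enable-dp-attention --enable-dp-lm-head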