Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
This commit is contained in:
@@ -221,3 +221,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
||||
| `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
|
||||
| `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
|
||||
| `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
|
||||
| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` |
|
||||
|
||||
@@ -252,12 +252,12 @@ def dp_scatter(
|
||||
)
|
||||
|
||||
|
||||
def tp_reduce_scatter(
|
||||
def attn_tp_reduce_scatter(
|
||||
output: torch.Tensor,
|
||||
input_list: List[torch.Tensor],
|
||||
):
|
||||
return get_attention_tp_group().reduce_scatter(output, input_list)
|
||||
|
||||
|
||||
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
|
||||
|
||||
@@ -23,15 +23,16 @@ import triton.language as tl
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
dp_gather_replicate,
|
||||
dp_scatter,
|
||||
get_attention_dp_rank,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_size,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logit_scale = logit_scale
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||
)
|
||||
self.do_tensor_parallel_all_gather_dp_attn = (
|
||||
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
|
||||
)
|
||||
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
||||
if self.use_attn_tp_group:
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
not skip_all_gather and self.attn_tp_size > 1
|
||||
)
|
||||
self.do_tensor_parallel_all_gather_dp_attn = False
|
||||
else:
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||
)
|
||||
self.do_tensor_parallel_all_gather_dp_attn = (
|
||||
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
|
||||
)
|
||||
self.final_logit_softcapping = getattr(
|
||||
self.config, "final_logit_softcapping", None
|
||||
)
|
||||
@@ -442,7 +451,19 @@ class LogitsProcessor(nn.Module):
|
||||
logits.mul_(self.logit_scale)
|
||||
|
||||
if self.do_tensor_parallel_all_gather:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
if self.use_attn_tp_group:
|
||||
global_logits = torch.empty(
|
||||
(self.config.vocab_size, logits.shape[0]),
|
||||
device=logits.device,
|
||||
dtype=logits.dtype,
|
||||
)
|
||||
global_logits = global_logits.T
|
||||
attn_tp_all_gather(
|
||||
list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
|
||||
)
|
||||
logits = global_logits
|
||||
else:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
|
||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||
logits, global_logits = (
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||
from sglang.srt.layers.parameter import BasevLLMParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
*,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_tp: bool = True,
|
||||
use_attn_tp_group: bool = False,
|
||||
use_presharded_weights: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
|
||||
self.enable_tp = enable_tp
|
||||
if self.enable_tp:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if use_attn_tp_group:
|
||||
tp_rank = get_attention_tp_rank()
|
||||
self.tp_size = get_attention_tp_size()
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
assert use_attn_tp_group is False
|
||||
tp_rank = 0
|
||||
self.tp_size = 1
|
||||
|
||||
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
*,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_attn_tp_group: bool = False,
|
||||
use_presharded_weights: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
params_dtype,
|
||||
org_num_embeddings,
|
||||
padding_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
params_dtype=params_dtype,
|
||||
org_num_embeddings=org_num_embeddings,
|
||||
padding_size=padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
use_attn_tp_group=use_attn_tp_group,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -74,6 +74,7 @@ global_server_args_dict = {
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
|
||||
@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
tp_all_gather,
|
||||
tp_reduce_scatter,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
if self.input_is_scattered:
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
tp_reduce_scatter(hidden_states, tensor_list)
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states += residual
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
tp_reduce_scatter(hidden_states, tensor_list)
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
residual = hidden_states
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
tp_all_gather(
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.dp_size = get_attention_dp_size()
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@@ -159,6 +159,7 @@ class ServerArgs:
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
enable_dp_lm_head: bool = False
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||
@@ -323,6 +324,11 @@ class ServerArgs:
|
||||
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
|
||||
)
|
||||
|
||||
if self.enable_dp_lm_head:
|
||||
assert (
|
||||
self.enable_dp_attention
|
||||
), "Please enable dp attention when setting enable_dp_attention. "
|
||||
|
||||
# DeepEP MoE
|
||||
self.enable_sp_layernorm = False
|
||||
if self.enable_deepep_moe:
|
||||
@@ -1055,6 +1061,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-dp-lm-head",
|
||||
action="store_true",
|
||||
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-ep-moe",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user