Fix two issues related to --moe-dense-tp-size=1 (#5657)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com> Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
This commit is contained in:
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
|
||||
_ATTN_TP_GROUP = None
|
||||
_ATTN_TP_RANK = None
|
||||
_ATTN_TP_SIZE = None
|
||||
_DP_RANK = None
|
||||
_DP_SIZE = None
|
||||
_ATTN_DP_RANK = None
|
||||
_ATTN_DP_SIZE = None
|
||||
_LOCAL_ATTN_DP_SIZE = None
|
||||
_LOCAL_ATTN_DP_RANK = None
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
||||
return tp_rank, tp_size, 0
|
||||
|
||||
attn_tp_size = tp_size // dp_size
|
||||
dp_rank = tp_rank // attn_tp_size
|
||||
attn_dp_rank = tp_rank // attn_tp_size
|
||||
attn_tp_rank = tp_rank % attn_tp_size
|
||||
return attn_tp_rank, attn_tp_size, dp_rank
|
||||
|
||||
return attn_tp_rank, attn_tp_size, attn_dp_rank
|
||||
|
||||
|
||||
def compute_dp_attention_local_info(
|
||||
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
|
||||
):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
|
||||
local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
|
||||
local_tp_rank = tp_rank % local_tp_size
|
||||
local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
|
||||
|
||||
local_attn_tp_size = local_tp_size // local_dp_size
|
||||
local_attn_dp_rank = local_tp_rank // local_attn_tp_size
|
||||
local_attn_tp_rank = local_tp_rank % local_attn_tp_size
|
||||
|
||||
return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
|
||||
|
||||
|
||||
def initialize_dp_attention(
|
||||
@@ -43,22 +63,32 @@ def initialize_dp_attention(
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
dp_size: int,
|
||||
moe_dense_tp_size: int,
|
||||
pp_size: int,
|
||||
):
|
||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
|
||||
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
|
||||
|
||||
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||
|
||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
|
||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||
)
|
||||
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
|
||||
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
|
||||
)
|
||||
|
||||
if enable_dp_attention:
|
||||
local_rank = tp_rank % (tp_size // dp_size)
|
||||
_DP_SIZE = dp_size
|
||||
_ATTN_DP_SIZE = dp_size
|
||||
if moe_dense_tp_size is None:
|
||||
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
|
||||
else:
|
||||
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
|
||||
else:
|
||||
local_rank = tp_rank
|
||||
_DP_SIZE = 1
|
||||
_ATTN_DP_SIZE = 1
|
||||
_LOCAL_ATTN_DP_SIZE = 1
|
||||
|
||||
tp_group = get_tp_group()
|
||||
_ATTN_TP_GROUP = GroupCoordinator(
|
||||
@@ -93,13 +123,33 @@ def get_attention_tp_size():
|
||||
|
||||
|
||||
def get_attention_dp_rank():
|
||||
assert _DP_RANK is not None, "dp attention not initialized!"
|
||||
return _DP_RANK
|
||||
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_attention_dp_size():
|
||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _DP_SIZE
|
||||
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _ATTN_DP_SIZE
|
||||
|
||||
|
||||
def get_local_attention_dp_rank():
|
||||
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_local_attention_dp_size():
|
||||
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_SIZE
|
||||
|
||||
|
||||
def get_local_attention_dp_rank():
|
||||
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_local_attention_dp_size():
|
||||
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_SIZE
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -112,19 +162,19 @@ def disable_dp_size():
|
||||
Args:
|
||||
tp_group (GroupCoordinator): the tp group coordinator
|
||||
"""
|
||||
global _DP_SIZE
|
||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
||||
global _ATTN_DP_SIZE
|
||||
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
|
||||
old_dp_size = _DP_SIZE
|
||||
_DP_SIZE = 1
|
||||
old_dp_size = _ATTN_DP_SIZE
|
||||
_ATTN_DP_SIZE = 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_DP_SIZE = old_dp_size
|
||||
_ATTN_DP_SIZE = old_dp_size
|
||||
|
||||
|
||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||
dp_rank = get_attention_dp_rank()
|
||||
dp_rank = get_local_attention_dp_rank()
|
||||
|
||||
if forward_batch.dp_local_start_pos is None:
|
||||
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||
|
||||
@@ -30,9 +30,10 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
dp_gather_replicate,
|
||||
dp_scatter,
|
||||
get_attention_dp_rank,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_rank,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -46,6 +47,18 @@ from sglang.srt.utils import dump_to_file
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.utils import dump_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LogitsProcessorOutput:
|
||||
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||
@@ -170,7 +183,7 @@ class LogitsMetadata:
|
||||
return
|
||||
|
||||
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||
dp_rank = get_attention_dp_rank()
|
||||
dp_rank = get_local_attention_dp_rank()
|
||||
if dp_rank == 0:
|
||||
dp_local_start_pos = torch.zeros_like(
|
||||
self.global_num_tokens_for_logprob_gpu[0]
|
||||
@@ -324,7 +337,8 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
if self.debug_tensor_dump_output_folder:
|
||||
assert (
|
||||
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
|
||||
not self.do_tensor_parallel_all_gather
|
||||
or get_local_attention_dp_size() == 1
|
||||
), "dp attention + sharded lm_head doesn't support full logits"
|
||||
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
|
||||
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
|
||||
|
||||
@@ -207,7 +207,8 @@ class Scheduler(
|
||||
self.page_size = server_args.page_size
|
||||
|
||||
# Distributed rank info
|
||||
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
|
||||
self.dp_size = server_args.dp_size
|
||||
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
||||
compute_dp_attention_world_info(
|
||||
server_args.enable_dp_attention,
|
||||
self.tp_rank,
|
||||
@@ -768,7 +769,7 @@ class Scheduler(
|
||||
)
|
||||
|
||||
# send out reqs to the next stage
|
||||
dp_offset = self.dp_rank * self.attn_tp_size
|
||||
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
||||
if self.attn_tp_rank == 0:
|
||||
point_to_point_pyobj(
|
||||
recv_reqs,
|
||||
@@ -815,7 +816,7 @@ class Scheduler(
|
||||
recv_reqs = None
|
||||
else:
|
||||
if self.attn_tp_rank == 0:
|
||||
dp_offset = self.dp_rank * self.attn_tp_size
|
||||
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
||||
recv_reqs = point_to_point_pyobj(
|
||||
[],
|
||||
self.pp_rank * self.tp_size + dp_offset,
|
||||
@@ -1610,6 +1611,7 @@ class Scheduler(
|
||||
local_batch,
|
||||
dp_size=self.server_args.dp_size,
|
||||
attn_tp_size=self.attn_tp_size,
|
||||
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||
tp_cpu_group=self.tp_cpu_group,
|
||||
get_idle_batch=self.get_idle_batch,
|
||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||
@@ -1622,6 +1624,7 @@ class Scheduler(
|
||||
local_batch: ScheduleBatch,
|
||||
dp_size,
|
||||
attn_tp_size: int,
|
||||
moe_dense_tp_size: Optional[int],
|
||||
tp_cpu_group,
|
||||
get_idle_batch,
|
||||
disable_cuda_graph: bool,
|
||||
@@ -1631,15 +1634,15 @@ class Scheduler(
|
||||
# Check if other DP workers have running batches
|
||||
if local_batch is None:
|
||||
num_tokens = 0
|
||||
global_num_tokens_for_logprob = 0
|
||||
num_tokens_for_logprob = 0
|
||||
elif local_batch.forward_mode.is_decode():
|
||||
num_tokens = local_batch.batch_size()
|
||||
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
|
||||
num_tokens = num_tokens * speculative_num_draft_tokens
|
||||
global_num_tokens_for_logprob = num_tokens
|
||||
num_tokens_for_logprob = num_tokens
|
||||
else:
|
||||
num_tokens = local_batch.extend_num_tokens
|
||||
global_num_tokens_for_logprob = sum(
|
||||
num_tokens_for_logprob = sum(
|
||||
[
|
||||
# We should have at least 1 token for sample in every case.
|
||||
max(extend_len - logprob_start_len, 1)
|
||||
@@ -1666,7 +1669,7 @@ class Scheduler(
|
||||
[
|
||||
num_tokens,
|
||||
can_cuda_graph,
|
||||
global_num_tokens_for_logprob,
|
||||
num_tokens_for_logprob,
|
||||
is_extend_in_batch,
|
||||
],
|
||||
dtype=torch.int64,
|
||||
@@ -1689,8 +1692,15 @@ class Scheduler(
|
||||
local_batch = get_idle_batch()
|
||||
|
||||
if local_batch is not None:
|
||||
local_batch.global_num_tokens = global_num_tokens
|
||||
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
|
||||
# TODO: handle the case when moe_dense_tp_size != 1
|
||||
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
|
||||
local_batch.global_num_tokens = [num_tokens]
|
||||
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
|
||||
else:
|
||||
local_batch.global_num_tokens = global_num_tokens
|
||||
local_batch.global_num_tokens_for_logprob = (
|
||||
global_num_tokens_for_logprob
|
||||
)
|
||||
|
||||
# Check forward mode for cuda graph
|
||||
if not disable_cuda_graph:
|
||||
@@ -2177,8 +2187,8 @@ class Scheduler(
|
||||
|
||||
def get_print_prefix(self):
|
||||
prefix = ""
|
||||
if self.dp_rank is not None:
|
||||
prefix += f" DP{self.dp_rank}"
|
||||
if self.attn_dp_rank is not None:
|
||||
prefix += f" DP{self.attn_dp_rank}"
|
||||
if self.server_args.tp_size > 1:
|
||||
prefix += f" TP{self.tp_rank}"
|
||||
if self.pp_size > 1:
|
||||
|
||||
@@ -401,6 +401,7 @@ class ModelRunner:
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
dp_size=self.server_args.dp_size,
|
||||
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||
pp_size=self.server_args.pp_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.dp_size = get_attention_dp_size()
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||
self.layer_id = layer_id
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.input_is_scattered = (
|
||||
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
layer_id > 0
|
||||
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
)
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
def get_input_embeddings(self) -> torch.Tensor:
|
||||
return self.embed_tokens
|
||||
@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
def determine_n_share_experts_fusion(
|
||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||
|
||||
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
|
||||
self.use_rope = int((layer_id + 1) % 4 != 0)
|
||||
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
||||
|
||||
self.dp_size = get_attention_dp_size()
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
self.dp_size = get_attention_dp_size()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
|
||||
@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
@@ -430,7 +429,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
|
||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.dp_size != 1:
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
|
||||
Reference in New Issue
Block a user