Fix two issues related to --moe-dense-tp-size=1 (#5657)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
This commit is contained in:
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
|
|||||||
_ATTN_TP_GROUP = None
|
_ATTN_TP_GROUP = None
|
||||||
_ATTN_TP_RANK = None
|
_ATTN_TP_RANK = None
|
||||||
_ATTN_TP_SIZE = None
|
_ATTN_TP_SIZE = None
|
||||||
_DP_RANK = None
|
_ATTN_DP_RANK = None
|
||||||
_DP_SIZE = None
|
_ATTN_DP_SIZE = None
|
||||||
|
_LOCAL_ATTN_DP_SIZE = None
|
||||||
|
_LOCAL_ATTN_DP_RANK = None
|
||||||
|
|
||||||
|
|
||||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||||
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
|||||||
return tp_rank, tp_size, 0
|
return tp_rank, tp_size, 0
|
||||||
|
|
||||||
attn_tp_size = tp_size // dp_size
|
attn_tp_size = tp_size // dp_size
|
||||||
dp_rank = tp_rank // attn_tp_size
|
attn_dp_rank = tp_rank // attn_tp_size
|
||||||
attn_tp_rank = tp_rank % attn_tp_size
|
attn_tp_rank = tp_rank % attn_tp_size
|
||||||
return attn_tp_rank, attn_tp_size, dp_rank
|
|
||||||
|
return attn_tp_rank, attn_tp_size, attn_dp_rank
|
||||||
|
|
||||||
|
|
||||||
|
def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    """Compute a rank's attention TP/DP coordinates inside its local TP group.

    The local group has ``moe_dense_tp_size`` ranks; when that value is falsy
    (``None`` or ``0``) the local group is the whole TP group of ``tp_size``
    ranks. The rank's position in the local group is
    ``tp_rank % local_tp_size``.

    Args:
        enable_dp_attention: When falsy, no DP split is applied and the inputs
            are returned as ``(tp_rank, tp_size, 0)``.
        tp_rank: Rank of the caller within the TP group.
        tp_size: Number of ranks in the TP group.
        dp_size: Attention data-parallel size for the whole TP group.
        moe_dense_tp_size: Size of the local TP group, or a falsy value to
            use ``tp_size``.

    Returns:
        A tuple ``(local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank)``.

    Raises:
        ZeroDivisionError: If ``moe_dense_tp_size`` exceeds ``tp_size``, or if
            the derived local DP size exceeds the local TP size (either makes
            an intermediate divisor zero).
    """
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    # A falsy moe_dense_tp_size means "not configured": use the full TP group.
    local_tp_size = moe_dense_tp_size or tp_size
    local_tp_rank = tp_rank % local_tp_size

    # The TP group splits into this many local groups; the DP degree is divided
    # by the same factor and clamped so it never drops below 1.
    num_local_groups = tp_size // local_tp_size
    local_dp_size = max(1, dp_size // num_local_groups)

    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank, local_attn_tp_rank = divmod(
        local_tp_rank, local_attn_tp_size
    )
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
|
||||||
|
|
||||||
|
|
||||||
def initialize_dp_attention(
|
def initialize_dp_attention(
|
||||||
@@ -43,22 +63,32 @@ def initialize_dp_attention(
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
dp_size: int,
|
dp_size: int,
|
||||||
|
moe_dense_tp_size: int,
|
||||||
pp_size: int,
|
pp_size: int,
|
||||||
):
|
):
|
||||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
|
||||||
|
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
|
||||||
|
|
||||||
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||||
|
|
||||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
|
||||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||||
)
|
)
|
||||||
|
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
|
||||||
|
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
|
||||||
|
)
|
||||||
|
|
||||||
if enable_dp_attention:
|
if enable_dp_attention:
|
||||||
local_rank = tp_rank % (tp_size // dp_size)
|
local_rank = tp_rank % (tp_size // dp_size)
|
||||||
_DP_SIZE = dp_size
|
_ATTN_DP_SIZE = dp_size
|
||||||
|
if moe_dense_tp_size is None:
|
||||||
|
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
|
||||||
|
else:
|
||||||
|
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
|
||||||
else:
|
else:
|
||||||
local_rank = tp_rank
|
local_rank = tp_rank
|
||||||
_DP_SIZE = 1
|
_ATTN_DP_SIZE = 1
|
||||||
|
_LOCAL_ATTN_DP_SIZE = 1
|
||||||
|
|
||||||
tp_group = get_tp_group()
|
tp_group = get_tp_group()
|
||||||
_ATTN_TP_GROUP = GroupCoordinator(
|
_ATTN_TP_GROUP = GroupCoordinator(
|
||||||
@@ -93,13 +123,33 @@ def get_attention_tp_size():
|
|||||||
|
|
||||||
|
|
||||||
def get_attention_dp_rank():
|
def get_attention_dp_rank():
|
||||||
assert _DP_RANK is not None, "dp attention not initialized!"
|
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||||
return _DP_RANK
|
return _ATTN_DP_RANK
|
||||||
|
|
||||||
|
|
||||||
def get_attention_dp_size():
|
def get_attention_dp_size():
|
||||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||||
return _DP_SIZE
|
return _ATTN_DP_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_attention_dp_rank():
    """Return the module-level ``_LOCAL_ATTN_DP_RANK`` value.

    Raises:
        AssertionError: If ``_LOCAL_ATTN_DP_RANK`` is still ``None``.
    """
    # Compare against None explicitly: rank 0 is a valid, initialized value.
    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_RANK


def get_local_attention_dp_size():
    """Return the module-level ``_LOCAL_ATTN_DP_SIZE`` value.

    Raises:
        AssertionError: If ``_LOCAL_ATTN_DP_SIZE`` is still ``None``.
    """
    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_SIZE
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -112,19 +162,19 @@ def disable_dp_size():
|
|||||||
Args:
|
Args:
|
||||||
tp_group (GroupCoordinator): the tp group coordinator
|
tp_group (GroupCoordinator): the tp group coordinator
|
||||||
"""
|
"""
|
||||||
global _DP_SIZE
|
global _ATTN_DP_SIZE
|
||||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||||
|
|
||||||
old_dp_size = _DP_SIZE
|
old_dp_size = _ATTN_DP_SIZE
|
||||||
_DP_SIZE = 1
|
_ATTN_DP_SIZE = 1
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
_DP_SIZE = old_dp_size
|
_ATTN_DP_SIZE = old_dp_size
|
||||||
|
|
||||||
|
|
||||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||||
dp_rank = get_attention_dp_rank()
|
dp_rank = get_local_attention_dp_rank()
|
||||||
|
|
||||||
if forward_batch.dp_local_start_pos is None:
|
if forward_batch.dp_local_start_pos is None:
|
||||||
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||||
|
|||||||
@@ -30,9 +30,10 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
attn_tp_all_gather,
|
attn_tp_all_gather,
|
||||||
dp_gather_replicate,
|
dp_gather_replicate,
|
||||||
dp_scatter,
|
dp_scatter,
|
||||||
get_attention_dp_rank,
|
|
||||||
get_attention_dp_size,
|
get_attention_dp_size,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
|
get_local_attention_dp_rank,
|
||||||
|
get_local_attention_dp_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
@@ -46,6 +47,18 @@ from sglang.srt.utils import dump_to_file
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import dump_to_file
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||||
@@ -170,7 +183,7 @@ class LogitsMetadata:
|
|||||||
return
|
return
|
||||||
|
|
||||||
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||||
dp_rank = get_attention_dp_rank()
|
dp_rank = get_local_attention_dp_rank()
|
||||||
if dp_rank == 0:
|
if dp_rank == 0:
|
||||||
dp_local_start_pos = torch.zeros_like(
|
dp_local_start_pos = torch.zeros_like(
|
||||||
self.global_num_tokens_for_logprob_gpu[0]
|
self.global_num_tokens_for_logprob_gpu[0]
|
||||||
@@ -324,7 +337,8 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
if self.debug_tensor_dump_output_folder:
|
if self.debug_tensor_dump_output_folder:
|
||||||
assert (
|
assert (
|
||||||
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
|
not self.do_tensor_parallel_all_gather
|
||||||
|
or get_local_attention_dp_size() == 1
|
||||||
), "dp attention + sharded lm_head doesn't support full logits"
|
), "dp attention + sharded lm_head doesn't support full logits"
|
||||||
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
|
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
|
||||||
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
|
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
|
||||||
|
|||||||
@@ -207,7 +207,8 @@ class Scheduler(
|
|||||||
self.page_size = server_args.page_size
|
self.page_size = server_args.page_size
|
||||||
|
|
||||||
# Distributed rank info
|
# Distributed rank info
|
||||||
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
|
self.dp_size = server_args.dp_size
|
||||||
|
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
||||||
compute_dp_attention_world_info(
|
compute_dp_attention_world_info(
|
||||||
server_args.enable_dp_attention,
|
server_args.enable_dp_attention,
|
||||||
self.tp_rank,
|
self.tp_rank,
|
||||||
@@ -768,7 +769,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# send out reqs to the next stage
|
# send out reqs to the next stage
|
||||||
dp_offset = self.dp_rank * self.attn_tp_size
|
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
point_to_point_pyobj(
|
point_to_point_pyobj(
|
||||||
recv_reqs,
|
recv_reqs,
|
||||||
@@ -815,7 +816,7 @@ class Scheduler(
|
|||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
else:
|
else:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
dp_offset = self.dp_rank * self.attn_tp_size
|
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
||||||
recv_reqs = point_to_point_pyobj(
|
recv_reqs = point_to_point_pyobj(
|
||||||
[],
|
[],
|
||||||
self.pp_rank * self.tp_size + dp_offset,
|
self.pp_rank * self.tp_size + dp_offset,
|
||||||
@@ -1610,6 +1611,7 @@ class Scheduler(
|
|||||||
local_batch,
|
local_batch,
|
||||||
dp_size=self.server_args.dp_size,
|
dp_size=self.server_args.dp_size,
|
||||||
attn_tp_size=self.attn_tp_size,
|
attn_tp_size=self.attn_tp_size,
|
||||||
|
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||||
tp_cpu_group=self.tp_cpu_group,
|
tp_cpu_group=self.tp_cpu_group,
|
||||||
get_idle_batch=self.get_idle_batch,
|
get_idle_batch=self.get_idle_batch,
|
||||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||||
@@ -1622,6 +1624,7 @@ class Scheduler(
|
|||||||
local_batch: ScheduleBatch,
|
local_batch: ScheduleBatch,
|
||||||
dp_size,
|
dp_size,
|
||||||
attn_tp_size: int,
|
attn_tp_size: int,
|
||||||
|
moe_dense_tp_size: Optional[int],
|
||||||
tp_cpu_group,
|
tp_cpu_group,
|
||||||
get_idle_batch,
|
get_idle_batch,
|
||||||
disable_cuda_graph: bool,
|
disable_cuda_graph: bool,
|
||||||
@@ -1631,15 +1634,15 @@ class Scheduler(
|
|||||||
# Check if other DP workers have running batches
|
# Check if other DP workers have running batches
|
||||||
if local_batch is None:
|
if local_batch is None:
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
global_num_tokens_for_logprob = 0
|
num_tokens_for_logprob = 0
|
||||||
elif local_batch.forward_mode.is_decode():
|
elif local_batch.forward_mode.is_decode():
|
||||||
num_tokens = local_batch.batch_size()
|
num_tokens = local_batch.batch_size()
|
||||||
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
|
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
|
||||||
num_tokens = num_tokens * speculative_num_draft_tokens
|
num_tokens = num_tokens * speculative_num_draft_tokens
|
||||||
global_num_tokens_for_logprob = num_tokens
|
num_tokens_for_logprob = num_tokens
|
||||||
else:
|
else:
|
||||||
num_tokens = local_batch.extend_num_tokens
|
num_tokens = local_batch.extend_num_tokens
|
||||||
global_num_tokens_for_logprob = sum(
|
num_tokens_for_logprob = sum(
|
||||||
[
|
[
|
||||||
# We should have at least 1 token for sample in every case.
|
# We should have at least 1 token for sample in every case.
|
||||||
max(extend_len - logprob_start_len, 1)
|
max(extend_len - logprob_start_len, 1)
|
||||||
@@ -1666,7 +1669,7 @@ class Scheduler(
|
|||||||
[
|
[
|
||||||
num_tokens,
|
num_tokens,
|
||||||
can_cuda_graph,
|
can_cuda_graph,
|
||||||
global_num_tokens_for_logprob,
|
num_tokens_for_logprob,
|
||||||
is_extend_in_batch,
|
is_extend_in_batch,
|
||||||
],
|
],
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
@@ -1689,8 +1692,15 @@ class Scheduler(
|
|||||||
local_batch = get_idle_batch()
|
local_batch = get_idle_batch()
|
||||||
|
|
||||||
if local_batch is not None:
|
if local_batch is not None:
|
||||||
local_batch.global_num_tokens = global_num_tokens
|
# TODO: handle the case when moe_dense_tp_size != 1
|
||||||
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
|
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
|
||||||
|
local_batch.global_num_tokens = [num_tokens]
|
||||||
|
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
|
||||||
|
else:
|
||||||
|
local_batch.global_num_tokens = global_num_tokens
|
||||||
|
local_batch.global_num_tokens_for_logprob = (
|
||||||
|
global_num_tokens_for_logprob
|
||||||
|
)
|
||||||
|
|
||||||
# Check forward mode for cuda graph
|
# Check forward mode for cuda graph
|
||||||
if not disable_cuda_graph:
|
if not disable_cuda_graph:
|
||||||
@@ -2177,8 +2187,8 @@ class Scheduler(
|
|||||||
|
|
||||||
def get_print_prefix(self):
|
def get_print_prefix(self):
|
||||||
prefix = ""
|
prefix = ""
|
||||||
if self.dp_rank is not None:
|
if self.attn_dp_rank is not None:
|
||||||
prefix += f" DP{self.dp_rank}"
|
prefix += f" DP{self.attn_dp_rank}"
|
||||||
if self.server_args.tp_size > 1:
|
if self.server_args.tp_size > 1:
|
||||||
prefix += f" TP{self.tp_rank}"
|
prefix += f" TP{self.tp_rank}"
|
||||||
if self.pp_size > 1:
|
if self.pp_size > 1:
|
||||||
|
|||||||
@@ -401,6 +401,7 @@ class ModelRunner:
|
|||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
dp_size=self.server_args.dp_size,
|
dp_size=self.server_args.dp_size,
|
||||||
|
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||||
pp_size=self.server_args.pp_size,
|
pp_size=self.server_args.pp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
attn_tp_reduce_scatter,
|
attn_tp_reduce_scatter,
|
||||||
dp_gather_partial,
|
dp_gather_partial,
|
||||||
dp_scatter,
|
dp_scatter,
|
||||||
get_attention_dp_size,
|
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
|
get_local_attention_dp_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.v_head_dim = v_head_dim
|
self.v_head_dim = v_head_dim
|
||||||
self.q_lora_rank = q_lora_rank
|
self.q_lora_rank = q_lora_rank
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
self.dp_size = get_attention_dp_size()
|
|
||||||
attn_tp_rank = get_attention_tp_rank()
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
attn_tp_size = get_attention_tp_size()
|
attn_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.dp_size = get_attention_dp_size()
|
self.local_dp_size = get_local_attention_dp_size()
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
self.self_attn = DeepseekV2AttentionMLA(
|
self.self_attn = DeepseekV2AttentionMLA(
|
||||||
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.input_is_scattered = (
|
self.input_is_scattered = (
|
||||||
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
layer_id > 0
|
||||||
|
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||||
)
|
)
|
||||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||||
|
|
||||||
@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
# Gather
|
# Gather
|
||||||
if get_tensor_model_parallel_world_size() > 1:
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
# all gather and all reduce
|
# all gather and all reduce
|
||||||
if self.dp_size != 1:
|
if self.local_dp_size != 1:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||||
# Scatter
|
# Scatter
|
||||||
if self.dp_size != 1:
|
if self.local_dp_size != 1:
|
||||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
# be careful about this!
|
# be careful about this!
|
||||||
hidden_states, global_hidden_states = (
|
hidden_states, global_hidden_states = (
|
||||||
@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.dp_size = get_attention_dp_size()
|
self.dp_size = get_local_attention_dp_size()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> torch.Tensor:
|
def get_input_embeddings(self) -> torch.Tensor:
|
||||||
return self.embed_tokens
|
return self.embed_tokens
|
||||||
@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.dp_size = get_attention_dp_size()
|
self.dp_size = get_local_attention_dp_size()
|
||||||
|
|
||||||
def determine_n_share_experts_fusion(
|
def determine_n_share_experts_fusion(
|
||||||
self, architecture: str = "DeepseekV3ForCausalLM"
|
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||||
|
|||||||
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
dp_gather_partial,
|
dp_gather_partial,
|
||||||
dp_scatter,
|
dp_scatter,
|
||||||
get_attention_dp_size,
|
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
|
get_local_attention_dp_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
|
|||||||
self.use_rope = int((layer_id + 1) % 4 != 0)
|
self.use_rope = int((layer_id + 1) % 4 != 0)
|
||||||
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
self.use_qk_norm = config.use_qk_norm and self.use_rope
|
||||||
|
|
||||||
self.dp_size = get_attention_dp_size()
|
|
||||||
attn_tp_rank = get_attention_tp_rank()
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
attn_tp_size = get_attention_tp_size()
|
attn_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
rope_theta = config.rope_theta
|
rope_theta = config.rope_theta
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
self.dp_size = get_attention_dp_size()
|
self.local_dp_size = get_local_attention_dp_size()
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
# Gather
|
# Gather
|
||||||
if get_tensor_model_parallel_world_size() > 1:
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
# all gather and all reduce
|
# all gather and all reduce
|
||||||
if self.dp_size != 1:
|
if self.local_dp_size != 1:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
@@ -430,7 +429,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||||
# Scatter
|
# Scatter
|
||||||
if self.dp_size != 1:
|
if self.local_dp_size != 1:
|
||||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
# be careful about this!
|
# be careful about this!
|
||||||
hidden_states, global_hidden_states = (
|
hidden_states, global_hidden_states = (
|
||||||
|
|||||||
Reference in New Issue
Block a user