Minor add metrics to expert location updater (#6816)

This commit is contained in:
fzyzcjy
2025-06-03 14:59:11 +08:00
committed by GitHub
parent 0ea330ca34
commit b6d0ce9f78

View File

@@ -12,8 +12,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import logging import logging
from typing import Dict, List, Tuple from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import einops
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import P2POp from torch.distributed import P2POp
@@ -22,6 +24,7 @@ from sglang.srt.managers.expert_location import (
ExpertLocationMetadata, ExpertLocationMetadata,
get_global_expert_location_metadata, get_global_expert_location_metadata,
) )
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +62,8 @@ def _update_expert_weights(
nnodes: int, nnodes: int,
rank: int, rank: int,
): ):
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
temp_buffers = create_temp_buffers( temp_buffers = create_temp_buffers(
next(iter(routed_experts_weights_of_layer.values())) next(iter(routed_experts_weights_of_layer.values()))
) )
@@ -83,6 +88,8 @@ def _update_expert_weights(
num_local_physical_experts=num_local_physical_experts, num_local_physical_experts=num_local_physical_experts,
num_gpu_per_node=num_gpu_per_node, num_gpu_per_node=num_gpu_per_node,
rank=rank, rank=rank,
world_size=world_size,
log_metrics=log_metrics,
) )
@@ -98,7 +105,9 @@ def update_expert_weights_single_layer(
num_local_physical_experts: int, num_local_physical_experts: int,
num_gpu_per_node: int, num_gpu_per_node: int,
rank: int, rank: int,
world_size: Optional[int] = None,
debug: bool = False, debug: bool = False,
log_metrics: bool = False,
): ):
assert all( assert all(
tensor.shape[0] == num_local_physical_experts tensor.shape[0] == num_local_physical_experts
@@ -130,6 +139,14 @@ def update_expert_weights_single_layer(
_execute_p2p_ops(p2p_op_infos) _execute_p2p_ops(p2p_op_infos)
_execute_buffer2weight_copies(buffer2weight_copy_infos) _execute_buffer2weight_copies(buffer2weight_copy_infos)
if log_metrics:
_log_p2p_op_metrics(
p2p_op_infos,
world_size=world_size,
num_gpu_per_node=num_gpu_per_node,
self_node_id=self_node_id,
)
if debug: if debug:
output_logs.append(f"{p2p_op_infos=}") output_logs.append(f"{p2p_op_infos=}")
output_logs.append(f"{buffer2weight_copy_infos=}") output_logs.append(f"{buffer2weight_copy_infos=}")
@@ -429,3 +446,53 @@ def _deduplicate_ordered(arr: List[int]):
if len(output) == 0 or item != output[-1]: if len(output) == 0 or item != output[-1]:
output.append(item) output.append(item)
return output return output
def _log_p2p_op_metrics(
    p2p_op_infos: List[Tuple[int, List[P2POp]]],
    num_gpu_per_node: int,
    world_size: Optional[int],
    self_node_id: int,
):
    """Log per-GPU, per-node, and cross-node byte volumes of pending P2P ops.

    For each direction (isend / irecv) this sums ``op.tensor.nbytes`` bucketed
    by peer rank, folds GPU buckets into node buckets, and emits one INFO line.

    Args:
        p2p_op_infos: List of ``(layer_id, ops)`` pairs; only the ops are used.
        num_gpu_per_node: GPUs per node; ``world_size`` must be a multiple of it.
        world_size: Total number of ranks; may be ``None`` because the caller
            (``update_expert_weights_single_layer``) defaults it to ``None``
            for backward compatibility.
        self_node_id: Node index of the current rank, used to split the
            current-node volume from the cross-node volume.
    """
    if world_size is None:
        # Without world_size we cannot size the per-GPU buckets
        # ([0] * None would raise TypeError), so skip metrics gracefully.
        logger.warning(
            "[ExpertLocationUpdater] world_size unknown; skipping P2P op metrics"
        )
        return

    text = ""
    all_ops = [op for _, ops in p2p_op_infos for op in ops]
    for direction, ops in _group_by(all_ops, _get_direction_from_op).items():
        nbytes_of_gpu = [0] * world_size
        for op in ops:
            nbytes_of_gpu[op.peer] += op.tensor.nbytes
        nbytes_of_gpu = torch.tensor(nbytes_of_gpu, dtype=torch.int64)
        # Sum the GPUs belonging to each node:
        # (num_nodes * num_gpu_per_node,) -> (num_nodes,)
        nbytes_of_node = nbytes_of_gpu.view(-1, num_gpu_per_node).sum(dim=-1)
        nbytes_curr_node = nbytes_of_node[self_node_id]
        nbytes_cross_node = torch.sum(nbytes_of_node) - nbytes_curr_node
        text += (
            f"{direction}_nbytes_of_gpu={nbytes_of_gpu.tolist()} "
            f"{direction}_nbytes_of_node={nbytes_of_node.tolist()} "
            f"{direction}_nbytes_curr_node={nbytes_curr_node.item()} "
            f"{direction}_nbytes_cross_node={nbytes_cross_node.item()} "
        )
    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logger.info("[ExpertLocationUpdater] %s", text)
def _get_direction_from_op(op: P2POp):
if op.op == torch.distributed.isend:
return "isend"
if op.op == torch.distributed.irecv:
return "irecv"
raise NotImplementedError
def _group_by(items, keyfunc):
ans = defaultdict(list)
for item in items:
ans[keyfunc(item)].append(item)
return dict(ans)