Expert distribution recording without overhead for EPLB (#4957)
@@ -390,7 +390,7 @@
    "outputs": [],
    "source": [
     "expert_record_server_process, port = launch_server_cmd(\n",
-    "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
+    "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
     ")\n",
     "\n",
     "wait_for_server(f\"http://localhost:{port}\")"
@@ -415,19 +415,7 @@
     "print_highlight(response)\n",
     "\n",
     "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
-    "print_highlight(response)\n",
-    "\n",
-    "import glob\n",
-    "\n",
-    "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
-    "with open(output_file, \"r\") as f:\n",
-    "    print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
-    "    print_highlight(\"|----------|-----------|--------|\")\n",
-    "    next(f)\n",
-    "    for i, line in enumerate(f):\n",
-    "        if i < 9:\n",
-    "            layer_id, expert_id, count = line.strip().split(\",\")\n",
-    "            print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
+    "print_highlight(response)"
    ]
   },
   {
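The CSV-parsing cells are removed because this commit changes the recorder's output format: in `stat` mode, rank 0 writes a single torch `.pt` file (see `_dump_to_file` in the recorder diff below) instead of per-rank `expert_distribution_*.csv` files. A minimal client-side sketch of the new cycle follows; the port value and output directory are assumptions (the server writes to `/tmp` unless `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR` is set):

import glob

import requests
import torch

port = 30000  # hypothetical; use the port returned by launch_server_cmd
base = f"http://localhost:{port}"

requests.post(f"{base}/start_expert_distribution_record")
# ... send generation requests here so tokens get routed to experts ...
requests.post(f"{base}/stop_expert_distribution_record")
requests.post(f"{base}/dump_expert_distribution_record")

# In "stat" mode rank 0 saves a dict with "rank" and "logical_count" entries.
output_file = sorted(glob.glob("/tmp/expert_distribution_recorder_*.pt"))[-1]
record = torch.load(output_file, map_location="cpu", weights_only=True)
print(record["rank"], record["logical_count"].shape)  # (buffered steps, layers, logical experts)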
@@ -1,6 +1,9 @@
 import logging
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import DeepEPMode, load_json_config
 
@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             config=_DeepEPConfig.get_instance().normal_dispatch_config,
         )
 
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
+
         return (
             recv_x,
             recv_topk_idx,
@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ):
             hook() if self.return_recv_hook else event.current_stream_wait()
 
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
         reorder_topk_ids = seg_indptr = None
 
         return (
@@ -18,7 +18,10 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 
@@ -31,8 +34,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 
 def fused_topk_native(
     hidden_states: torch.Tensor,
@@ -353,6 +354,6 @@ def select_experts(
         renormalize=renormalize,
     )
 
-    expert_distribution_recorder.record_new_token(topk_ids)
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
     return topk_weights, topk_ids
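This call-site change carries the "without overhead" idea in the commit title: hot paths such as `select_experts` always invoke a hook on the global recorder, and when recording is disabled the global object is a noop subclass whose hooks are empty, so the steady-state cost is a single empty method call. A minimal sketch of the pattern (illustrative only, not SGLang's exact code):

class _RecorderNoop:
    # Disabled mode pays only this empty call on the hot path.
    def on_select_experts(self, topk_ids):
        pass


class _RecorderReal(_RecorderNoop):
    def on_select_experts(self, topk_ids):
        print("recording", topk_ids)  # the real version gathers statistics


_recorder = _RecorderNoop()  # swapped for _RecorderReal only when recording is enabled


def get_recorder():
    return _recorder


get_recorder().on_select_experts(topk_ids=[0, 2])  # call sites stay unconditional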
@@ -1,81 +1,620 @@
-import json
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import logging
+import os
 import time
-from collections import defaultdict
-from typing import Dict, List, Tuple
+from abc import ABC
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, Type
 
 import torch
+import torch.distributed
+
+from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import Withable
 
 logger = logging.getLogger(__name__)
 
-# global expert distribution recording
-class ExpertDistributionRecorder:
-    # This class is a singleton class
-    def __new__(cls):
-        if not hasattr(cls, "instance"):
-            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
-        return cls.instance
-
-    def __init__(self):
-        # the length of the dictionary is the number of layers
-        # the length of the list is the number of tokens
-        # the length of the tuple is topk's k value
-        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
-            list
-        )
-        self._record = False
-        self._current_layer_id = "UNKNOWN"
-
-    def set_current_layer(self, layer_idx):
-        self._current_layer_id = layer_idx
-
-    def record_new_token(self, topk_ids):
-        if not self._record:
-            return
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-        for i in topk_ids_list:
-            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
-
-    def reset(self):
-        """Reset the expert distribution recorder."""
-        logger.info("Resetting expert distribution record...")
-        self._record = False
-        self._expert_distribution_record.clear()
-        self._current_layer_id = "UNKNOWN"
+# --------------------------------------- Entrypoint -----------------------------------------
+
+_OutputMode = Literal["file", "object"]
+
+
+class ExpertDistributionRecorder(ABC):
+    """Global expert distribution recording"""
+
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        if server_args.expert_distribution_recorder_mode is not None:
+            return _ExpertDistributionRecorderReal(
+                server_args, expert_location_metadata, rank
+            )
+        else:
+            return _ExpertDistributionRecorderNoop()
+
+    @contextmanager
+    def with_current_layer(self, layer_idx):
+        yield
+
+    @contextmanager
+    def with_debug_name(self, debug_name):
+        yield
+
+    @contextmanager
+    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+        yield
+
+    def on_select_experts(self, topk_ids: torch.Tensor):
+        pass
+
+    def on_deepep_dispatch_normal(
+        self,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        pass
+
+    def on_deepep_dispatch_low_latency(
+        self, local_physical_count_of_layer: torch.Tensor
+    ):
+        pass
 
     def start_record(self):
-        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
-        if self._record == True:
+        self._on_not_implemented()
+
+    def stop_record(self):
+        self._on_not_implemented()
+
+    def dump_record(self, output_mode: _OutputMode = "file"):
+        self._on_not_implemented()
+
+    def _on_not_implemented(self):
+        raise Exception(
+            "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
+        )
+
+
+class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
+    pass
+
+
+class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        self._server_args = server_args
+        self._expert_location_metadata = expert_location_metadata
+
+        self._recording = False
+        self._current_forward_pass_id = Withable()
+        self._current_layer_idx = Withable()
+        self._current_debug_name = Withable()
+        self._accumulator = _Accumulator.init_new(
+            server_args, expert_location_metadata, rank
+        )
+        self._single_pass_gatherers = {
+            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
+            for k in self._accumulator.get_single_pass_gatherer_keys()
+        }
+
+    def with_current_layer(self, layer_idx):
+        return self._current_layer_idx.with_value(layer_idx)
+
+    def with_debug_name(self, debug_name):
+        return self._current_debug_name.with_value(debug_name)
+
+    @contextmanager
+    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
+        with self._current_forward_pass_id.with_value(forward_pass_id):
+            self._on_forward_pass_start(forward_batch)
+            try:
+                yield
+            finally:
+                self._on_forward_pass_end(forward_pass_id)
+
+    def _on_forward_pass_start(self, forward_batch: ForwardBatch):
+        if not self._recording:
+            return
+        for gatherer_key, gatherer in self._single_pass_gatherers.items():
+            gatherer.reset()
+            gatherer.on_forward_pass_start(forward_batch)
+
+    def _on_forward_pass_end(self, forward_pass_id: int):
+        if not self._recording:
+            return
+        for gatherer_key, gatherer in self._single_pass_gatherers.items():
+            single_pass_data = gatherer.collect()
+            self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
+
+    def on_select_experts(self, topk_ids: torch.Tensor):
+        self._on_hook("on_select_experts", topk_ids=topk_ids)
+
+    def on_deepep_dispatch_normal(
+        self,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        self._on_hook(
+            "on_deepep_dispatch_normal",
+            local_physical_count_of_layer=local_physical_count_of_layer,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
+
+    def on_deepep_dispatch_low_latency(
+        self, local_physical_count_of_layer: torch.Tensor
+    ):
+        self._on_hook(
+            "on_deepep_dispatch_low_latency",
+            local_physical_count_of_layer=local_physical_count_of_layer,
+        )
+
+    def _on_hook(self, hook_name: str, **kwargs):
+        if not (self._recording or torch.cuda.is_current_stream_capturing()):
+            return
+        gatherer = self._single_pass_gatherers[
+            self._accumulator.get_single_pass_gatherer_key(
+                self._current_debug_name.value
+            )
+        ]
+        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)
+
+    def _reset(self):
+        """Reset the expert distribution recorder."""
+        logger.info("Resetting ExpertDistributionRecorder...")
+        assert (
+            self._current_layer_idx.value is None
+        ), f"{self._current_layer_idx.value=}"
+        for gatherer in self._single_pass_gatherers.values():
+            gatherer.reset()
+        self._accumulator.reset()
+
+    def start_record(self):
+        """Start recording the expert distribution."""
+        if self._recording:
             logger.warning(
                 "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
             )
-        self.reset()
-        self._record = True
+        self._reset()
+        self._recording = True
 
     def stop_record(self):
-        """Stop recording the expert distribution. Set the recording flag to False."""
-        if self._record == False:
+        """Stop recording the expert distribution."""
+        if not self._recording:
             logger.warning(
                 "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
             )
-        self._record = False
+        self._recording = False
 
-    def dump_record(self):
-        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
-        results = {}
-        for layer_idx, layer_record in self._expert_distribution_record.items():
-            results[layer_idx] = defaultdict(int)
-            for token_record in layer_record:
-                for expert_idx in token_record:
-                    results[layer_idx][expert_idx] += 1
-        with open(
-            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
-            "w",
-        ) as fd:
-            fd.write("layer_id,expert_id,count\n")
-            for layer_idx, layer_results in results.items():
-                for expert_idx, count in layer_results.items():
-                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
-        self.reset()
+    def dump_record(self, output_mode: _OutputMode = "file"):
+        """Dump the expert distribution record and reset the recorder after dumping."""
+        output = self._accumulator.dump(output_mode=output_mode)
+        self._reset()
+        return output
+
+
+_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
+    _ExpertDistributionRecorderNoop()
+)
+
+
+def get_global_expert_distribution_recorder():
+    return _global_expert_distribution_recorder
+
+
+def set_global_expert_distribution_recorder(value):
+    global _global_expert_distribution_recorder
+    _global_expert_distribution_recorder = value
+
+
+# --------------------------------------- SinglePassGatherer -----------------------------------------
+
+
+class _SinglePassGatherer(ABC):
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ) -> "_SinglePassGatherer":
+        if server_args.expert_distribution_recorder_mode == "per_token":
+            return _DetailSinglePassGatherer(
+                server_args, expert_location_metadata, rank
+            )
+        if server_args.enable_deepep_moe:
+            if server_args.deepep_mode == "normal":
+                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
+            elif server_args.deepep_mode == "low_latency":
+                return _DeepepLowLatencySinglePassGatherer(
+                    expert_location_metadata, rank
+                )
+            else:
+                raise NotImplementedError
+        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
+
+    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
+        self._expert_location_metadata = expert_location_metadata
+        self._rank = rank
+
+    def on_forward_pass_start(self, forward_batch: ForwardBatch):
+        pass
+
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        pass
+
+    def on_deepep_dispatch_normal(
+        self,
+        layer_idx: int,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        pass
+
+    def on_deepep_dispatch_low_latency(
+        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+    ):
+        pass
+
+    def reset(self):
+        raise NotImplementedError
+
+    def collect(self) -> Dict:
+        raise NotImplementedError
+
+
+class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._objects_of_layer = {}
+
+    def _on_layer_data(self, layer_idx: int, objects: List[int]):
+        assert 0 <= layer_idx < self._expert_location_metadata.num_layers
+        if layer_idx in self._objects_of_layer:
+            self._objects_of_layer[layer_idx] = _list_sum(
+                self._objects_of_layer[layer_idx], objects
+            )
+        else:
+            self._objects_of_layer[layer_idx] = objects
+
+    def reset(self):
+        self._objects_of_layer.clear()
+
+    def _collect_objects(self, pad_len: int) -> torch.Tensor:
+        data = [
+            self._objects_of_layer.get(layer_index) or ([0] * pad_len)
+            for layer_index in range(self._expert_location_metadata.num_layers)
+        ]
+        return torch.tensor(data)
+
+
+def _list_sum(a: List, b: List) -> List:
+    return [x + y for x, y in zip(a, b, strict=True)]
+
+
+class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
+    # pretty slow, but we will use the DeepEP Gatherer in production
+    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
+        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+        torch.cuda.synchronize()
+
+        global_physical_count = [
+            0
+        ] * self._expert_location_metadata.num_physical_experts
+        for token_record in topk_ids_list:
+            for global_physical_expert_idx in token_record:
+                global_physical_count[global_physical_expert_idx] += 1
+
+        self._on_layer_data(layer_idx, global_physical_count)
+
+    def collect(self) -> Dict:
+        global_physical_count = super()._collect_objects(
+            pad_len=self._expert_location_metadata.num_physical_experts
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
+    def on_deepep_dispatch_normal(
+        self,
+        layer_idx: int,
+        local_physical_count_of_layer: List[int],
+        num_tokens_per_rank,
+        num_tokens_per_rdma_rank,
+        num_tokens_per_expert,
+    ):
+        assert isinstance(local_physical_count_of_layer, list)
+        self._on_layer_data(layer_idx, local_physical_count_of_layer)
+
+    def collect(self) -> Dict:
+        local_physical_count = super()._collect_objects(
+            pad_len=self._expert_location_metadata.num_local_physical_experts
+        )
+        global_physical_count = _convert_local_to_global_physical_count(
+            local_physical_count,
+            rank=self._rank,
+            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+            num_physical_experts=self._expert_location_metadata.num_physical_experts,
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._data = torch.zeros(
+            (
+                self._expert_location_metadata.num_layers,
+                self._expert_location_metadata.num_local_physical_experts,
+            ),
+            dtype=torch.int,
+            device="cuda",
+        )
+
+    def on_deepep_dispatch_low_latency(
+        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
+    ):
+        # Most naive implementation, can optimize later
+        self._data[layer_idx, :] += local_physical_count_of_layer
+
+    def reset(self):
+        self._data[...] = 0
+
+    def collect(self) -> Dict:
+        # Can optimize if bottleneck
+        global_physical_count = _convert_local_to_global_physical_count(
+            self._data,
+            rank=self._rank,
+            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
+            num_physical_experts=self._expert_location_metadata.num_physical_experts,
+        )
+        return dict(global_physical_count=global_physical_count)
+
+
+def _convert_local_to_global_physical_count(
+    local_physical_count: torch.Tensor,
+    rank: int,
+    num_local_physical_experts: int,
+    num_physical_experts: int,
+) -> torch.Tensor:
+    dtype = local_physical_count.dtype
+    device = local_physical_count.device
+    num_layers, _ = local_physical_count.shape
+
+    ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
+    ans[
+        :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
+    ] = local_physical_count
+    return ans
+
+
+# --------------------------------------- Accumulator -----------------------------------------
+
+_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
+
+
+class _Accumulator(ABC):
+    @staticmethod
+    def init_new(
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ) -> "_Accumulator":
+        return _Accumulator.get_class(server_args)(
+            server_args, expert_location_metadata, rank
+        )
+
+    @staticmethod
+    def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
+        return {
+            "stat": _StatAccumulator,
+            # TODO pr-chain: enable this later
+            # "per_pass": _DetailAccumulator,
+            # "per_token": _DetailAccumulator,
+        }[server_args.expert_distribution_recorder_mode]
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        expert_location_metadata: "ExpertLocationMetadata",
+        rank: int,
+    ):
+        self._server_args = server_args
+        self._expert_location_metadata = expert_location_metadata
+        self._rank = rank
+
+    def get_single_pass_gatherer_keys(self):
+        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]
+
+    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
+        return _SINGLE_PASS_GATHERER_KEY_PRIMARY
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        pass
+
+    def reset(self):
+        pass
+
+    def dump(self, output_mode: _OutputMode):
+        pass
+
+
+class _StatAccumulator(_Accumulator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._global_physical_count_of_buffered_step = _Buffer.init_new(
+            item_shape=(
+                self._expert_location_metadata.num_layers,
+                # Cannot use local_physical_count to support select_experts
+                self._expert_location_metadata.num_physical_experts,
+            ),
+            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
+            dtype=torch.int32,
+            device=self._server_args.device,
+        )
+
+    def append(
+        self,
+        forward_pass_id: int,
+        gatherer_key: str,
+        single_pass_data: Dict,
+    ):
+        super().append(forward_pass_id, gatherer_key, single_pass_data)
+        # Can optimize if overhead here is large
+        self._global_physical_count_of_buffered_step.append(
+            single_pass_data["global_physical_count"]
+        )
+
+    def reset(self):
+        super().reset()
+        self._global_physical_count_of_buffered_step.reset()
+
+    def dump(self, output_mode: _OutputMode):
+        logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
+            self._global_physical_count_of_buffered_step.get_all(),
+            num_layers=self._expert_location_metadata.num_layers,
+            num_logical_experts=self._expert_location_metadata.num_logical_experts,
+            physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
+        )
+        torch.distributed.all_reduce(
+            logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
+        )
+        output = dict(
+            rank=self._rank,
+            logical_count=logical_count_of_buffered_step,
+        )
+
+        if output_mode == "file":
+            if self._rank == 0:
+                _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
+        elif output_mode == "object":
+            return output
+        else:
+            raise NotImplementedError
+
+
+def _dump_to_file(name, data):
+    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
+    path_output = save_dir / name
+    logger.info(f"Write expert distribution to {path_output}")
+    if not save_dir.exists():
+        save_dir.mkdir(parents=True, exist_ok=True)
+    torch.save(data, str(path_output))
+
+
+class _Buffer:
+    @staticmethod
+    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
+        if buffer_size < 0:
+            return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
+        else:
+            return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)
+
+    def append(self, value: torch.Tensor):
+        raise NotImplementedError
+
+    def get_all(self) -> torch.Tensor:
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+
+class _CircularBuffer(_Buffer):
+    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros(
+            (buffer_size, *item_shape), dtype=dtype, device=device
+        )
+        self._curr_index = 0
+
+    def append(self, value: torch.Tensor):
+        self._buffer[self._curr_index] = value
+        self._curr_index = (self._curr_index + 1) % len(self._buffer)
+
+    def get_all(self) -> torch.Tensor:
+        return self._buffer
+
+    def reset(self):
+        self._buffer[...] = 0
+
+
+class _InfiniteBuffer(_Buffer):
+    def __init__(self, item_shape: Tuple, dtype, device):
+        self._item_shape = item_shape
+        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
+        self._size = 0
+
+    def append(self, value: torch.Tensor):
+        curr_buffer_size = len(self._buffer)
+        dtype = self._buffer.dtype
+        device = self._buffer.device
+
+        if self._size == curr_buffer_size:
+            new_buffer = torch.zeros(
+                (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
+            )
+            new_buffer[:curr_buffer_size] = self._buffer
+            self._buffer = new_buffer
+
+        self._buffer[self._size] = value
+        self._size += 1
+
+    def get_all(self) -> torch.Tensor:
+        return self._buffer[: self._size]
+
+    def reset(self):
+        self._buffer[...] = 0
+        self._size = 0
+
+
+def _convert_global_physical_count_to_logical_count(
+    # (whatever, num_layers, num_physical_experts)
+    global_physical_count: torch.Tensor,
+    num_layers: int,
+    num_logical_experts: int,
+    physical_to_logical_map: torch.Tensor,
+):
+    dim_extra, _, _ = global_physical_count.shape
+    dtype = global_physical_count.dtype
+    device = global_physical_count.device
+    logical_count = torch.zeros(
+        (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
+    )
+    logical_count.scatter_add_(
+        dim=2,
+        index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
+        src=global_physical_count,
+    )
+    return logical_count
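Two notes on the file above. First, `Withable` is imported from `sglang.srt.utils` but its definition is not visible in this excerpt (the truncated utils hunk at the end of the commit is presumably where it lands); from its usage it needs only a `.value` property and a `with_value(...)` context manager, roughly:

from contextlib import contextmanager


class Withable:
    # Hypothetical sketch matching the usage above, not the actual implementation.
    def __init__(self):
        self._value = None

    @property
    def value(self):
        return self._value

    @contextmanager
    def with_value(self, new_value):
        assert self._value is None
        self._value = new_value
        try:
            yield
        finally:
            self._value = None

Second, the two conversion helpers are easiest to see with tiny hypothetical sizes. Each rank owns a contiguous slice of physical experts, so the local-to-global conversion places the local counts into that slice and zero-fills the rest; the physical-to-logical conversion then sums all redundant physical copies of a logical expert via scatter_add_. An illustrative run inside this module:

import torch

# Rank 1 of 2 with 2 local / 4 global physical experts: counts land in columns 2..3.
local = torch.tensor([[5, 7]])  # (num_layers=1, num_local_physical_experts=2)
global_count = _convert_local_to_global_physical_count(
    local, rank=1, num_local_physical_experts=2, num_physical_experts=4
)
assert global_count.tolist() == [[0, 0, 5, 7]]

# Physical experts 0 and 2 are both copies of logical expert 0, so their counts merge.
logical = _convert_global_physical_count_to_logical_count(
    torch.tensor([[[3, 4, 5]]]),  # (1 buffered pass, 1 layer, 3 physical experts)
    num_layers=1,
    num_logical_experts=2,
    physical_to_logical_map=torch.tensor([[0, 1, 0]]),
)
assert logical.tolist() == [[[8, 4]]]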
python/sglang/srt/managers/expert_location.py (new file, 273 lines)
@@ -0,0 +1,273 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import json
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.model_loader import get_model_architecture
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ExpertLocationMetadata:
+    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
+    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
+    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
+
+    # -------------------------------- properties ------------------------------------
+
+    @property
+    def num_layers(self) -> int:
+        return self.physical_to_logical_map.shape[0]
+
+    @property
+    def num_physical_experts(self) -> int:
+        return self.physical_to_logical_map.shape[1]
+
+    @property
+    def num_local_physical_experts(self) -> int:
+        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
+        assert remainder == 0
+        return ans
+
+    @property
+    def num_logical_experts(self) -> int:
+        return self.logical_to_all_physical_map.shape[1]
+
+    @property
+    def ep_size(self):
+        # TODO change when EP size != world size
+        return torch.distributed.get_world_size()
+
+    def __post_init__(self):
+        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
+        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
+            self.logical_to_all_physical_map.shape
+        )
+        num_layers_2, num_logical_experts_1 = (
+            self.logical_to_all_physical_map_num_valid.shape
+        )
+        # TODO pr-chain: enable this later
+        # assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
+        # assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
+        assert num_physical_experts_0 == num_physical_experts_1
+
+    # -------------------------------- construction ------------------------------------
+
+    @staticmethod
+    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
+        """Trivial location - logical expert i corresponds to physical expert i"""
+        common = ExpertLocationMetadata._init_common(server_args, model_config)
+        num_physical_experts = common["num_physical_experts"]
+        model_config_for_expert_location = common["model_config_for_expert_location"]
+        num_layers = model_config_for_expert_location.num_layers
+        num_logical_experts = model_config_for_expert_location.num_logical_experts
+
+        physical_to_logical_map = (
+            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
+            % num_logical_experts
+        )
+
+        return ExpertLocationMetadata.init_by_mapping(
+            server_args,
+            model_config,
+            physical_to_logical_map=physical_to_logical_map,
+        )
+
+    @staticmethod
+    def init_by_mapping(
+        server_args: ServerArgs,
+        model_config: ModelConfig,
+        physical_to_logical_map,
+    ):
+        if not isinstance(physical_to_logical_map, torch.Tensor):
+            physical_to_logical_map = torch.tensor(physical_to_logical_map)
+        physical_to_logical_map = physical_to_logical_map.to(server_args.device)
+
+        common = ExpertLocationMetadata._init_common(server_args, model_config)
+        model_config_for_expert_location = common["model_config_for_expert_location"]
+        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
+            physical_to_logical_map,
+            num_logical_experts=model_config_for_expert_location.num_logical_experts,
+        )
+
+        return ExpertLocationMetadata._init_raw(
+            ep_size=common["ep_size"],
+            physical_to_logical_map=physical_to_logical_map,
+            logical_to_all_physical_map=logical_to_all_physical_map,
+        )
+
+    @staticmethod
+    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
+        model_config_for_expert_location = (
+            ModelConfigForExpertLocation.from_model_config(model_config)
+        )
+
+        num_physical_experts = (
+            model_config_for_expert_location.num_logical_experts
+            # TODO pr-chain: enable this later
+            # + server_args.ep_num_redundant_experts
+        )
+        ep_size = server_args.ep_size
+        assert num_physical_experts % ep_size == 0
+        num_local_physical_experts = num_physical_experts // ep_size
+
+        return dict(
+            model_config_for_expert_location=model_config_for_expert_location,
+            num_physical_experts=num_physical_experts,
+            num_local_physical_experts=num_local_physical_experts,
+            ep_size=ep_size,
+        )
+
+    @staticmethod
+    def _init_raw(
+        ep_size: int,
+        physical_to_logical_map: torch.Tensor,
+        logical_to_all_physical_map: torch.Tensor,
+    ):
+        _, num_physical_experts = physical_to_logical_map.shape
+
+        logical_to_all_physical_map_padded = F.pad(
+            logical_to_all_physical_map,
+            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
+            value=-1,
+        )
+
+        logical_to_all_physical_map_num_valid = torch.count_nonzero(
+            logical_to_all_physical_map != -1, dim=-1
+        )
+
+        return ExpertLocationMetadata(
+            physical_to_logical_map=physical_to_logical_map,
+            logical_to_all_physical_map=logical_to_all_physical_map_padded,
+            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
+        )
+
+
+_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
+
+
+def get_global_expert_location_metadata():
+    return _global_expert_location_metadata
+
+
+def set_global_expert_location_metadata(value):
+    global _global_expert_location_metadata
+    assert _global_expert_location_metadata is None
+    _global_expert_location_metadata = value
+
+
+def _compute_logical_to_all_physical_map(
+    physical_to_logical_map: torch.Tensor, num_logical_experts: int
+):
+    # This is rarely called, so we use for loops for maximum clarity
+
+    num_layers, num_physical_experts = physical_to_logical_map.shape
+
+    logical_to_all_physical_map = [
+        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
+    ]
+    for layer_id in range(num_layers):
+        for physical_expert_id in range(num_physical_experts):
+            logical_expert_id = physical_to_logical_map[
+                layer_id, physical_expert_id
+            ].item()
+            logical_to_all_physical_map[layer_id][logical_expert_id].append(
+                physical_expert_id
+            )
+
+    logical_to_all_physical_map = _pad_nested_array(
+        logical_to_all_physical_map, pad_value=-1
+    )
+
+    return torch.tensor(
+        logical_to_all_physical_map, device=physical_to_logical_map.device
+    )
+
+
+def _pad_nested_array(arr, pad_value):
+    max_len = max(len(inner) for outer in arr for inner in outer)
+    padded = [
+        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
+        for outer in arr
+    ]
+    return padded
+
+
+@dataclass
+class ModelConfigForExpertLocation:
+    num_layers: int
+    num_logical_experts: int
+    num_groups: Optional[int] = None
+
+    @staticmethod
+    def init_dummy():
+        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
+
+    @staticmethod
+    def from_model_config(model_config: ModelConfig):
+        model_class, _ = get_model_architecture(model_config)
+        if hasattr(model_class, "get_model_config_for_expert_location"):
+            return model_class.get_model_config_for_expert_location(
+                model_config.hf_config
+            )
+        else:
+            return ModelConfigForExpertLocation.init_dummy()
+
+
+def compute_initial_expert_location_metadata(
+    server_args: ServerArgs, model_config: ModelConfig
+) -> ExpertLocationMetadata:
+    data = server_args.init_expert_location
+    if data == "trivial":
+        logger.info("init_expert_location from trivial")
+        return ExpertLocationMetadata.init_trivial(server_args, model_config)
+
+    # TODO unify with the utils function
+    if data.endswith(".pt"):
+        data_dict = torch.load(data, weights_only=True)
+    elif data.endswith(".json"):
+        data_dict = json.loads(Path(data).read_text())
+    else:
+        data_dict = json.loads(data)
+
+    if "physical_to_logical_map" in data_dict:
+        logger.info(
+            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
+        )
+        return ExpertLocationMetadata.init_by_mapping(
+            server_args, model_config, **data_dict
+        )
+    elif "logical_count" in data_dict:
+        # TODO pr-chain: enable this later
+        raise NotImplementedError
+        # logger.info(
+        #     "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
+        # )
+        # return ExpertLocationMetadata.init_by_eplb(
+        #     server_args, model_config, logical_count=data_dict["logical_count"]
+        # )
+    else:
+        raise NotImplementedError(
+            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
+        )
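A tiny worked example of the maps this new module computes, using the trivial placement with hypothetical sizes (1 layer, 4 physical slots, 3 logical experts, so logical expert 0 receives one redundant copy):

import torch

physical_to_logical_map = torch.arange(0, 4).repeat(1, 1) % 3  # [[0, 1, 2, 0]]
m = _compute_logical_to_all_physical_map(physical_to_logical_map, num_logical_experts=3)
# Logical expert 0 lives at physical slots 0 and 3; experts 1 and 2 get one slot each,
# padded with -1 (and padded further to num_physical_experts columns by _init_raw).
assert m.tolist() == [[[0, 3], [1, -1], [2, -1]]]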
@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -142,8 +145,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
@@ -2162,11 +2163,11 @@ class Scheduler(
 
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -161,6 +171,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
+        self.forward_pass_id = 0
+
         # Model-specific adjustment
         self.model_specific_adjustment()
 
@@ -219,6 +231,25 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )
 
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+        set_global_expert_distribution_recorder(
+            ExpertDistributionRecorder.init_new(
+                server_args,
+                get_global_expert_location_metadata(),
+                rank=self.tp_rank,
+            )
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -1093,6 +1124,22 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            return self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 
@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
             forward_mode, hidden_states
         ):
@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch)
 
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
 
         residual = None
         for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 
@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
             residual = pp_proxy_tensors["residual"]
 
         for i in range(self.start_layer, self.end_layer):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {
@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+
 
 EntryClass = Qwen2MoeForCausalLM
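Both DeepseekV2ForCausalLM and Qwen2MoeForCausalLM now expose the same classmethod, so expert-location logic can stay model-agnostic. A hedged sketch of a consumer (the helper function is hypothetical; only the classmethod and ModelConfigForExpertLocation come from this diff):

# Hypothetical consumer of the new hook; not part of this commit.
def describe_expert_layout(model_cls, hf_config):
    meta = model_cls.get_model_config_for_expert_location(hf_config)
    return (
        f"layers={meta.num_layers}, "
        f"logical_experts={meta.num_logical_experts}, "
        f"groups={meta.num_groups}"
    )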
@@ -170,6 +170,11 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    init_expert_location: str = "trivial"
+    expert_distribution_recorder_mode: Optional[
+        Literal["stat", "per_pass", "per_token"]
+    ] = None
+    expert_distribution_recorder_buffer_size: Optional[int] = None
     deepep_config: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
@@ -361,6 +366,15 @@ class ServerArgs:
                 "Pipeline parallelism is incompatible with overlap schedule."
             )
 
+        if self.expert_distribution_recorder_buffer_size is None:
+            # TODO pr-chain: enable this later
+            # if (x := self.eplb_rebalance_num_iterations) is not None:
+            #     self.expert_distribution_recorder_buffer_size = x
+            if False:
+                pass
+            elif self.expert_distribution_recorder_mode is not None:
+                self.expert_distribution_recorder_buffer_size = 1000
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
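The fallback above means setting any recorder mode implicitly provisions a 1000-entry circular buffer unless a size is given explicitly. A self-contained sketch of that rule (the real check lives in ServerArgs validation; this dataclass is a stand-in):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _ArgsSketch:
    expert_distribution_recorder_mode: Optional[str] = None
    expert_distribution_recorder_buffer_size: Optional[int] = None

    def __post_init__(self):
        # mirrors the fallback added above
        if self.expert_distribution_recorder_buffer_size is None:
            if self.expert_distribution_recorder_mode is not None:
                self.expert_distribution_recorder_buffer_size = 1000

assert _ArgsSketch().expert_distribution_recorder_buffer_size is None
assert _ArgsSketch(expert_distribution_recorder_mode="stat").expert_distribution_recorder_buffer_size == 1000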
@@ -1257,6 +1271,24 @@ class ServerArgs:
             default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+        parser.add_argument(
+            "--init-expert-location",
+            type=str,
+            default=ServerArgs.init_expert_location,
+            help="Initial location of EP experts.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-mode",
+            type=str,
+            default=ServerArgs.expert_distribution_recorder_mode,
+            help="Mode of expert distribution recorder.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-buffer-size",
+            type=int,
+            default=ServerArgs.expert_distribution_recorder_buffer_size,
+            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+        )
         parser.add_argument(
             "--deepep-config",
             type=str,
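To sanity-check the wiring, the new flags parse like any existing server argument; a sketch assuming ServerArgs.add_cli_args is the method the parser.add_argument calls above live in (flag values are placeholders):

import argparse

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed to register the arguments shown above
args = parser.parse_args(
    [
        "--model-path", "dummy-model",  # placeholder
        "--expert-distribution-recorder-mode", "stat",
        "--expert-distribution-recorder-buffer-size", "-1",  # -1 = infinite buffer
    ]
)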
@@ -46,7 +46,19 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 import psutil
@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
 
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
+
+
+T = TypeVar("T")
+
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
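The Withable helper added above is a small re-entrancy-checked slot: a value may only be set inside a with-block and is cleared on exit, and the asserts catch overlapping use. Example usage, assuming it backs context helpers like with_current_layer:

current_layer = Withable()  # Withable[int]

def report():
    print(f"current layer: {current_layer.value}")

with current_layer.with_value(3):
    report()                        # prints: current layer: 3
assert current_layer.value is None  # slot is cleared after the block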
@@ -1,9 +1,10 @@
-import csv
-import glob
 import os
+import tempfile
 import unittest
+from pathlib import Path
 
 import requests
+import torch
 
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
 
 
 class TestExpertDistribution(CustomTestCase):
-    def setUp(self):
-        # Clean up any existing expert distribution files before each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
-    def tearDown(self):
-        # Clean up any expert distribution files after each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
     def test_expert_distribution_record(self):
+        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
+        for info in [
+            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
+            # TODO enable in next PR
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
+        ]:
+            with self.subTest(info=info):
+                self._execute_core(**info)
+
+    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
         """Test expert distribution record endpoints"""
-        process = popen_launch_server(
-            # The feature is only implemented in deepseek_v2.py
-            "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
-            DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--trust-remote-code",
-            ],
-        )
-
-        try:
-            # Start recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Make some requests to generate expert distribution data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/generate",
-                json={
-                    "text": "The capital of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 32,
-                    },
-                },
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Stop recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Dump the recorded data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Verify the dumped file exists and has correct format
-            csv_files = glob.glob("expert_distribution_*.csv")
-            self.assertEqual(
-                len(csv_files),
-                1,
-                f"Expected exactly one expert distribution CSV file {csv_files=}",
-            )
-
-            # Check CSV file format
-            with open(csv_files[0], "r") as f:
-                csv_reader = csv.reader(f)
-
-                # Check header
-                header = next(csv_reader)
-                self.assertEqual(
-                    header,
-                    ["layer_id", "expert_id", "count"],
-                    "CSV header should be 'layer_id,expert_id,count'",
-                )
-
-                # Check data rows
-                rows = list(csv_reader)
-                self.assertGreater(len(rows), 0, "CSV file should contain data rows")
-
-                for row in rows:
-                    # Verify each row has 3 columns
-                    self.assertEqual(
-                        len(row),
-                        3,
-                        "Each row should have layer_id, expert_id and count",
-                    )
-
-                    # Verify data types
-                    layer_id, expert_id, count = row
-                    self.assertTrue(
-                        layer_id.isdigit(),
-                        f"layer_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        expert_id.isdigit(),
-                        f"expert_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        count.isdigit(), f"count should be an integer {row=} {rows=}"
-                    )
-
-        finally:
-            kill_process_tree(process.pid)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
+
+            process = popen_launch_server(
+                model_path,
+                DEFAULT_URL_FOR_TEST,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=[
+                    "--trust-remote-code",
+                    "--tp-size",
+                    str(tp_size),
+                    "--expert-distribution-recorder-mode",
+                    mode,
+                    "--disable-cuda-graph",
+                    "--disable-overlap-schedule",
+                ],
+            )
+
+            try:
+                # Start recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Make some requests to generate expert distribution data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/generate",
+                    json={
+                        "text": "The capital of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 32,
+                        },
+                    },
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Stop recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Dump the recorded data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Check data rows
+                data = torch.load(
+                    list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
+                )
+                print(f"{data=}")
+
+                if mode in ["per_pass", "per_token"]:
+                    self.assertGreater(len(data), 0, "Should contain data rows")
+                else:
+                    logical_count = data["logical_count"]
+                    print(f"{logical_count.sum()=} {logical_count=}")
+                    self.assertTrue(logical_count.sum() > 0)
+
+            finally:
+                kill_process_tree(process.pid)
 
 
 if __name__ == "__main__":
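Outside the test, a dumped "stat" record can be post-processed the same way. A hedged sketch that assumes the dump directory comes from SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR and that logical_count has a (num_layers, num_logical_experts) layout; only the torch.load call and the "logical_count" key are confirmed by the test above:

from pathlib import Path

import torch

dump_dir = Path("/tmp/expert_distribution")  # placeholder directory
data = torch.load(next(dump_dir.glob("*.pt")), weights_only=True)
logical_count = data["logical_count"]  # assumed shape: (num_layers, num_logical_experts)

# rough per-layer load imbalance: max expert load over mean expert load
per_layer_max = logical_count.max(dim=-1).values.float()
per_layer_mean = logical_count.float().mean(dim=-1).clamp(min=1e-6)
print(per_layer_max / per_layer_mean)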