Add endpoints to dump selected expert ids (#4435)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,9 @@
|
|||||||
"- `/update_weights`\n",
|
"- `/update_weights`\n",
|
||||||
"- `/encode`(embedding model)\n",
|
"- `/encode`(embedding model)\n",
|
||||||
"- `/classify`(reward model)\n",
|
"- `/classify`(reward model)\n",
|
||||||
|
"- `/start_expert_distribution_record`\n",
|
||||||
|
"- `/stop_expert_distribution_record`\n",
|
||||||
|
"- `/dump_expert_distribution_record`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
|
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
|
||||||
]
|
]
|
||||||
@@ -362,6 +365,67 @@
|
|||||||
"terminate_process(reward_process)"
|
"terminate_process(reward_process)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Capture expert selection distribution in MoE models\n",
|
||||||
|
"\n",
|
||||||
|
"SGLang Runtime supports recording how many times each expert is selected during a MoE model run. This is useful for analyzing the throughput of the model and planning optimizations."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"expert_record_server_process, port = launch_server_cmd(\n",
|
||||||
|
" \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"wait_for_server(f\"http://localhost:{port}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n",
|
||||||
|
"print_highlight(response)\n",
|
||||||
|
"\n",
|
||||||
|
"url = f\"http://localhost:{port}/generate\"\n",
|
||||||
|
"data = {\"text\": \"What is the capital of France?\"}\n",
|
||||||
|
"\n",
|
||||||
|
"response = requests.post(url, json=data)\n",
|
||||||
|
"print_highlight(response.json())\n",
|
||||||
|
"\n",
|
||||||
|
"response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n",
|
||||||
|
"print_highlight(response)\n",
|
||||||
|
"\n",
|
||||||
|
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
|
||||||
|
"print_highlight(response)\n",
|
||||||
|
"\n",
|
||||||
|
"import glob\n",
|
||||||
|
"\n",
|
||||||
|
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
|
||||||
|
"with open(output_file, \"r\") as f:\n",
|
||||||
|
" print_highlight(\"Content of dumped record:\")\n",
|
||||||
|
" for line in f:\n",
|
||||||
|
" print_highlight(line.strip())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(expert_record_server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
|||||||
@@ -343,6 +343,36 @@ async def stop_profile_async():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async():
    """Start recording the expert distribution. Clear the previous record if any."""
    # Hand the command to the tokenizer manager; nothing is awaited here, so
    # the 200 below only acknowledges that the request was accepted.
    tokenizer_manager = _global_state.tokenizer_manager
    tokenizer_manager.start_expert_distribution_record()

    acknowledgement = "Start recording the expert distribution.\n"
    return Response(content=acknowledgement, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async():
    """Stop recording the expert distribution."""
    # Hand the command to the tokenizer manager; nothing is awaited here, so
    # the 200 below only acknowledges that the request was accepted.
    tokenizer_manager = _global_state.tokenizer_manager
    tokenizer_manager.stop_expert_distribution_record()

    acknowledgement = "Stop recording the expert distribution.\n"
    return Response(content=acknowledgement, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async():
    """Dump expert distribution record."""
    # Hand the command to the tokenizer manager; nothing is awaited here, so
    # the 200 below only acknowledges that the request was accepted.
    tokenizer_manager = _global_state.tokenizer_manager
    tokenizer_manager.dump_expert_distribution_record()

    acknowledgement = "Dump expert distribution record.\n"
    return Response(content=acknowledgement, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights_from_disk")
|
@app.post("/update_weights_from_disk")
|
||||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||||
"""Update the weights from disk inplace without re-launching the server."""
|
"""Update the weights from disk inplace without re-launching the server."""
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda
|
|||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
from sglang.srt.managers.utils import ExpertDistributionRecorder
|
||||||
|
|
||||||
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
|
|
||||||
def fused_topk_native(
|
def fused_topk_native(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -223,4 +227,6 @@ def select_experts(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expert_distribution_recorder.record_new_token(topk_ids)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|||||||
@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
|
|||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertDistributionReq(Enum):
    """Commands that control the expert distribution recorder."""

    # Begin a recording session.
    START_RECORD = 1
    # End the current recording session.
    STOP_RECORD = 2
    # Write the collected record out.
    DUMP_RECORD = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileReq:
|
class ProfileReq:
|
||||||
type: ProfileReqType
|
type: ProfileReqType
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
ExpertDistributionReq,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
|||||||
from sglang.srt.managers.session_controller import Session
|
from sglang.srt.managers.session_controller import Session
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||||
from sglang.srt.managers.utils import validate_input_length
|
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||||
|
|
||||||
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Test retract decode for debugging purposes
|
# Test retract decode for debugging purposes
|
||||||
@@ -403,6 +406,7 @@ class Scheduler(
|
|||||||
(GetInternalStateReq, self.get_internal_state),
|
(GetInternalStateReq, self.get_internal_state),
|
||||||
(SetInternalStateReq, self.set_internal_state),
|
(SetInternalStateReq, self.set_internal_state),
|
||||||
(RpcReqInput, self.handle_rpc_request),
|
(RpcReqInput, self.handle_rpc_request),
|
||||||
|
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1892,6 +1896,16 @@ class Scheduler(
|
|||||||
ProfileReqOutput(success=True, message="Succeeded.")
|
ProfileReqOutput(success=True, message="Succeeded.")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    """Run the recorder action that matches the received command.

    Raises:
        ValueError: if ``recv_req`` is not one of the known commands.
    """
    # Ordered (command, action) pairs, checked with the same equality test
    # an if/elif chain would use.
    command_table = (
        (ExpertDistributionReq.START_RECORD, expert_distribution_recorder.start_record),
        (ExpertDistributionReq.STOP_RECORD, expert_distribution_recorder.stop_record),
        (ExpertDistributionReq.DUMP_RECORD, expert_distribution_recorder.dump_record),
    )
    for command, action in command_table:
        if recv_req == command:
            action()
            return
    raise ValueError("Unrecognized ExpertDistributionReq value")
|
||||||
|
|
||||||
def open_session(self, recv_req: OpenSessionReqInput):
|
def open_session(self, recv_req: OpenSessionReqInput):
|
||||||
# handle error
|
# handle error
|
||||||
session_id = recv_req.session_id
|
session_id = recv_req.session_id
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
|
ExpertDistributionReq,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
@@ -638,6 +639,18 @@ class TokenizerManager:
|
|||||||
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
|
def start_expert_distribution_record(self):
    """Send the START_RECORD command to the scheduler."""
    # Fire-and-forget: the command is pushed out and no reply is awaited here.
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.START_RECORD)
|
||||||
|
|
||||||
|
def stop_expert_distribution_record(self):
    """Send the STOP_RECORD command to the scheduler."""
    # Fire-and-forget: the command is pushed out and no reply is awaited here.
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.STOP_RECORD)
|
||||||
|
|
||||||
|
def dump_expert_distribution_record(self):
    """Send the DUMP_RECORD command to the scheduler."""
    # Fire-and-forget: the command is pushed out and no reply is awaited here.
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.DUMP_RECORD)
|
||||||
|
|
||||||
async def update_weights_from_disk(
|
async def update_weights_from_disk(
|
||||||
self,
|
self,
|
||||||
obj: UpdateWeightFromDiskReqInput,
|
obj: UpdateWeightFromDiskReqInput,
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||||
|
|
||||||
@@ -42,3 +47,75 @@ def validate_input_length(
|
|||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# global expert distribution recording
|
||||||
|
class ExpertDistributionRecorder:
    """Process-wide recorder of the expert ids selected for each token.

    The class is a singleton: every ``ExpertDistributionRecorder()`` call
    returns the same object, so the modules that each create their own
    handle all share one record.
    """

    def __new__(cls):
        # Create the shared instance on first use and reuse it afterwards.
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Python calls __init__ on every ExpertDistributionRecorder() call,
        # even when __new__ returned the existing instance. Without this
        # guard, a later instantiation would silently wipe the shared record
        # and switch recording off.
        if getattr(self, "_initialized", False):
            return
        # the length of the dictionary is the number of layers
        # the length of the list is the number of tokens
        # the length of the tuple is topk's k value
        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = (
            defaultdict(list)
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"
        self._initialized = True

    def set_current_layer(self, layer_idx):
        """Remember which layer subsequent ``record_new_token`` calls belong to."""
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        """Append the selected expert ids of each row of ``topk_ids``.

        Does nothing unless recording is active. ``topk_ids`` is iterated
        row by row, and each row is stored as a tuple under the current
        layer id.
        """
        if not self._record:
            return
        # Tensor.tolist() copies to the host synchronously, so the values are
        # complete before they are read. The previous non-blocking copy was
        # read before the synchronize call, and that call fails outright on
        # builds without CUDA.
        rows = topk_ids.tolist()
        self._expert_distribution_record[self._current_layer_id].extend(
            tuple(row) for row in rows
        )

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if not self._record:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping.

        The output is a CSV with the header ``layer_id,expert_id,count``,
        written to a relative path named after the distributed rank and the
        current timestamp.
        """
        # Count how often each expert was selected within each layer.
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            counts = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    counts[expert_idx] += 1
            results[layer_idx] = counts

        # get_rank() raises when no process group has been initialized, so
        # fall back to rank 0 in that case.
        dist = torch.distributed
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

        with open(
            f"expert_distribution_rank{rank}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.managers.utils import ExpertDistributionRecorder
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
||||||
@@ -80,6 +81,8 @@ if _is_cuda:
|
|||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MLP(nn.Module):
|
class DeepseekV2MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
|
expert_distribution_recorder.set_current_layer(i)
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, forward_batch, residual
|
positions, hidden_states, forward_batch, residual
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ suites = {
|
|||||||
TestFile("test_ebnf_constrained.py"),
|
TestFile("test_ebnf_constrained.py"),
|
||||||
TestFile("test_fp8_kernel.py", 2),
|
TestFile("test_fp8_kernel.py", 2),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
|
TestFile("test_expert_distribution.py", 31),
|
||||||
TestFile("test_gguf.py", 78),
|
TestFile("test_gguf.py", 78),
|
||||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
TestFile("test_gptqmodel_dynamic.py", 72),
|
||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
|
|||||||
111
test/srt/test_expert_distribution.py
Executable file
111
test/srt/test_expert_distribution.py
Executable file
@@ -0,0 +1,111 @@
|
|||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpertDistribution(unittest.TestCase):
    """End-to-end test of the expert distribution record endpoints."""

    # Glob pattern of the CSV dump files looked up in the working directory.
    DUMP_FILE_PATTERN = "expert_distribution_*.csv"

    def _remove_dump_files(self):
        """Delete every dump file matching the pattern."""
        for path in glob.glob(self.DUMP_FILE_PATTERN):
            os.remove(path)

    def _wait_for_dump_files(self, timeout_s=10.0, poll_interval_s=0.2):
        """Poll until dump files exist and have stopped growing.

        The dump endpoint replies as soon as the command has been handed
        off, so the CSV may not exist, or may be half written, right after
        the 200 response. Returns the matching paths, which may be empty if
        ``timeout_s`` elapses first.
        """
        # Local import keeps this helper self-contained.
        import time

        deadline = time.monotonic() + timeout_s
        previous = None
        while True:
            snapshot = [
                (path, os.path.getsize(path))
                for path in sorted(glob.glob(self.DUMP_FILE_PATTERN))
            ]
            # Settled means non-empty files whose sizes match the last poll.
            settled = (
                bool(snapshot)
                and snapshot == previous
                and all(size > 0 for _, size in snapshot)
            )
            if settled or time.monotonic() >= deadline:
                return [path for path, _ in snapshot]
            previous = snapshot
            time.sleep(poll_interval_s)

    def setUp(self):
        # Clean up any existing expert distribution files before each test
        self._remove_dump_files()

    def tearDown(self):
        # Clean up any expert distribution files after each test
        self._remove_dump_files()

    def test_expert_distribution_record(self):
        """Test expert distribution record endpoints"""
        process = popen_launch_server(
            DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

        try:
            # Start recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Make some requests to generate expert distribution data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            self.assertEqual(response.status_code, 200)

            # Stop recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Dump the recorded data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Verify the dumped file exists and has correct format. Wait for
            # it first, because the 200 above does not mean it is written yet.
            csv_files = self._wait_for_dump_files()
            self.assertEqual(
                len(csv_files), 1, "Expected exactly one expert distribution CSV file"
            )

            # Check CSV file format
            with open(csv_files[0], "r") as f:
                csv_reader = csv.reader(f)

                # Check header
                header = next(csv_reader)
                self.assertEqual(
                    header,
                    ["layer_id", "expert_id", "count"],
                    "CSV header should be 'layer_id,expert_id,count'",
                )

                # Check data rows
                rows = list(csv_reader)
                self.assertGreater(len(rows), 0, "CSV file should contain data rows")

                for row in rows:
                    # Verify each row has 3 columns
                    self.assertEqual(
                        len(row),
                        3,
                        "Each row should have layer_id, expert_id and count",
                    )

                    # Verify data types
                    layer_id, expert_id, count = row
                    self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
                    self.assertTrue(
                        expert_id.isdigit(), "expert_id should be an integer"
                    )
                    self.assertTrue(count.isdigit(), "count should be an integer")

        finally:
            kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user