[feat] support minimum token load balance in dp attention (#7379)
This commit is contained in:
@@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| Arguments | Description | Defaults |
|
| Arguments | Description | Defaults |
|
||||||
|-----------|-------------|----------|
|
|-----------|-------------|----------|
|
||||||
| `--dp-size` | The data parallelism size. | 1 |
|
| `--dp-size` | The data parallelism size. | 1 |
|
||||||
| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
|
| `--load-balance-method` | The load balancing strategy for data parallelism. Options include: 'round_robin', 'minimum_tokens'. The Minimum Token algorithm can only be used when DP attention is applied. This algorithm performs load balancing based on the real-time token load of the DP workers. | round_robin |
|
||||||
|
|
||||||
## Multi-node distributed serving
|
## Multi-node distributed serving
|
||||||
|
|
||||||
|
|||||||
@@ -732,6 +732,7 @@ def _launch_subprocesses(
|
|||||||
pp_rank,
|
pp_rank,
|
||||||
None,
|
None,
|
||||||
writer,
|
writer,
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import signal
|
import signal
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import setproctitle
|
import setproctitle
|
||||||
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.managers.utils import DPBalanceMeta
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
|
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
|
||||||
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
|
|||||||
|
|
||||||
ROUND_ROBIN = auto()
|
ROUND_ROBIN = auto()
|
||||||
SHORTEST_QUEUE = auto()
|
SHORTEST_QUEUE = auto()
|
||||||
|
MINIMUM_TOKENS = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_str(cls, method: str):
|
def from_str(cls, method: str):
|
||||||
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
|
|||||||
class DataParallelController:
|
class DataParallelController:
|
||||||
"""A controller that dispatches requests to multiple data parallel workers."""
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||||
|
|
||||||
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
port_args: PortArgs,
|
||||||
|
dp_balance_meta: DPBalanceMeta,
|
||||||
|
) -> None:
|
||||||
|
# for dp balance
|
||||||
|
self.global_balance_id = 0
|
||||||
|
self.balance_meta = dp_balance_meta
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
self.max_total_num_tokens = None
|
self.max_total_num_tokens = None
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
@@ -79,6 +94,7 @@ class DataParallelController:
|
|||||||
dispatch_lookup = {
|
dispatch_lookup = {
|
||||||
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
||||||
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
||||||
|
LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
|
||||||
}
|
}
|
||||||
self.dispatching = dispatch_lookup[self.load_balance_method]
|
self.dispatching = dispatch_lookup[self.load_balance_method]
|
||||||
|
|
||||||
@@ -234,6 +250,7 @@ class DataParallelController:
|
|||||||
pp_rank,
|
pp_rank,
|
||||||
dp_rank,
|
dp_rank,
|
||||||
writer,
|
writer,
|
||||||
|
self.balance_meta,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
with memory_saver_adapter.configure_subprocess():
|
with memory_saver_adapter.configure_subprocess():
|
||||||
@@ -269,6 +286,33 @@ class DataParallelController:
|
|||||||
def shortest_queue_scheduler(self, input_requests):
|
def shortest_queue_scheduler(self, input_requests):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def minimum_tokens_scheduler(self, req):
|
||||||
|
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
||||||
|
# We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
|
||||||
|
def get_next_global_balance_id() -> int:
|
||||||
|
INT32_MAX = 2147483647
|
||||||
|
current_id = self.global_balance_id
|
||||||
|
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
|
||||||
|
return current_id
|
||||||
|
|
||||||
|
req.dp_balance_id = get_next_global_balance_id()
|
||||||
|
with self.balance_meta.mutex:
|
||||||
|
# 1. local_tokens represents the tokens currently inferring on the worker,
|
||||||
|
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
|
||||||
|
onfly_info = self.balance_meta.get_shared_onfly()
|
||||||
|
local_tokens = self.balance_meta.get_shared_local_tokens()
|
||||||
|
total_tokens = [
|
||||||
|
local_token + sum(onfly_dict.values())
|
||||||
|
for local_token, onfly_dict in zip(local_tokens, onfly_info)
|
||||||
|
]
|
||||||
|
target_worker = total_tokens.index(min(total_tokens))
|
||||||
|
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
|
||||||
|
# 2. write the new onfly info to the shm
|
||||||
|
self.balance_meta.set_shared_onfly_info(onfly_info)
|
||||||
|
|
||||||
|
# logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
|
||||||
|
self.workers[target_worker].send_pyobj(req)
|
||||||
|
|
||||||
def event_loop(self):
|
def event_loop(self):
|
||||||
while True:
|
while True:
|
||||||
while True:
|
while True:
|
||||||
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
|
|||||||
setproctitle.setproctitle("sglang::data_parallel_controller")
|
setproctitle.setproctitle("sglang::data_parallel_controller")
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
balance_meta = DPBalanceMeta(server_args.dp_size)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
controller = DataParallelController(server_args, port_args)
|
controller = DataParallelController(
|
||||||
|
server_args, port_args, dp_balance_meta=balance_meta
|
||||||
|
)
|
||||||
pipe_writer.send(
|
pipe_writer.send(
|
||||||
{
|
{
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
|
|||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
logger.error(f"DataParallelController hit an exception: {traceback}")
|
logger.error(f"DataParallelController hit an exception: {traceback}")
|
||||||
parent_process.send_signal(signal.SIGQUIT)
|
parent_process.send_signal(signal.SIGQUIT)
|
||||||
|
finally:
|
||||||
|
# we need to destruct mp.Manager() in balance_meta
|
||||||
|
balance_meta.destructor()
|
||||||
|
|||||||
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
|
|||||||
# For data parallel rank routing
|
# For data parallel rank routing
|
||||||
data_parallel_rank: Optional[int] = None
|
data_parallel_rank: Optional[int] = None
|
||||||
|
|
||||||
|
# For dp balance
|
||||||
|
dp_balance_id: int = -1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
|
|||||||
token_type_ids: List[int]
|
token_type_ids: List[int]
|
||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
# For dp balance
|
||||||
|
dp_balance_id: int = -1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
|
|||||||
from sglang.srt.managers.session_controller import Session
|
from sglang.srt.managers.session_controller import Session
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||||
from sglang.srt.managers.utils import validate_input_length
|
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
@@ -203,6 +203,7 @@ class Scheduler(
|
|||||||
moe_ep_rank: int,
|
moe_ep_rank: int,
|
||||||
pp_rank: int,
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
|
dp_balance_meta: Optional[DPBalanceMeta] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
@@ -522,6 +523,15 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.balance_meta = dp_balance_meta
|
||||||
|
if (
|
||||||
|
server_args.enable_dp_attention
|
||||||
|
and server_args.load_balance_method == "minimum_tokens"
|
||||||
|
):
|
||||||
|
assert dp_balance_meta is not None
|
||||||
|
|
||||||
|
self.recv_dp_balance_id_this_term = []
|
||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
@@ -1049,6 +1059,12 @@ class Scheduler(
|
|||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
):
|
):
|
||||||
|
if (
|
||||||
|
self.server_args.enable_dp_attention
|
||||||
|
and self.server_args.load_balance_method == "minimum_tokens"
|
||||||
|
):
|
||||||
|
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
|
||||||
|
|
||||||
# Create a new request
|
# Create a new request
|
||||||
if (
|
if (
|
||||||
recv_req.session_params is None
|
recv_req.session_params is None
|
||||||
@@ -1459,6 +1475,11 @@ class Scheduler(
|
|||||||
|
|
||||||
# Handle DP attention
|
# Handle DP attention
|
||||||
if need_dp_attn_preparation:
|
if need_dp_attn_preparation:
|
||||||
|
if (
|
||||||
|
self.server_args.load_balance_method == "minimum_tokens"
|
||||||
|
and self.forward_ct % 40 == 0
|
||||||
|
):
|
||||||
|
self.handle_dp_balance_data(ret)
|
||||||
ret = self.prepare_mlp_sync_batch(ret)
|
ret = self.prepare_mlp_sync_batch(ret)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
@@ -1786,6 +1807,86 @@ class Scheduler(
|
|||||||
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def handle_dp_balance_data(self, local_batch: ScheduleBatch):
|
||||||
|
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
|
||||||
|
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
|
||||||
|
recv_list = self.recv_dp_balance_id_this_term
|
||||||
|
assert len(recv_list) <= 511, (
|
||||||
|
"The number of requests received this round is too large. "
|
||||||
|
"Please increase gather_tensor_size and onfly_info_size."
|
||||||
|
)
|
||||||
|
# The maximum size of the tensor used for gathering data from all workers.
|
||||||
|
gather_tensor_size = 512
|
||||||
|
|
||||||
|
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
|
||||||
|
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||||
|
recv_tensor[0] = holding_tokens_list
|
||||||
|
recv_tensor[1] = len(
|
||||||
|
recv_list
|
||||||
|
) # The first element is the length of the list.
|
||||||
|
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
|
||||||
|
recv_list, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
gathered_list = [
|
||||||
|
torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||||
|
for _ in range(self.balance_meta.num_workers)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
gathered_list = None
|
||||||
|
|
||||||
|
torch.distributed.gather(
|
||||||
|
recv_tensor, gathered_list, group=self.tp_cpu_group
|
||||||
|
)
|
||||||
|
|
||||||
|
gathered_id_list_per_worker = None
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
gathered_id_list_per_worker = []
|
||||||
|
holding_tokens_list = []
|
||||||
|
for tensor in gathered_list:
|
||||||
|
holding_tokens_list.append(tensor[0].item())
|
||||||
|
list_length = tensor[1].item()
|
||||||
|
gathered_id_list_per_worker.append(
|
||||||
|
tensor[2 : list_length + 2].tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
return gathered_id_list_per_worker, holding_tokens_list
|
||||||
|
|
||||||
|
def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
|
||||||
|
meta = self.balance_meta
|
||||||
|
|
||||||
|
with meta.mutex:
|
||||||
|
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
|
||||||
|
assert len(new_recv_rid_lists) == len(
|
||||||
|
onfly_list
|
||||||
|
), "num_worker not equal"
|
||||||
|
# 1.Check if the rid received by each worker this round is present in onfly.
|
||||||
|
# If it is, remove the corresponding onfly item.
|
||||||
|
worker_id = 0
|
||||||
|
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
|
||||||
|
for new_recv_rid in new_recv_rids:
|
||||||
|
assert (
|
||||||
|
new_recv_rid in on_fly_reqs
|
||||||
|
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
|
||||||
|
del on_fly_reqs[new_recv_rid]
|
||||||
|
worker_id += 1
|
||||||
|
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
||||||
|
meta.set_shared_onfly_info(onfly_list)
|
||||||
|
meta.set_shared_local_tokens(local_tokens)
|
||||||
|
|
||||||
|
holding_tokens = self.get_load()
|
||||||
|
|
||||||
|
new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
|
||||||
|
holding_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
self.recv_dp_balance_id_this_term.clear()
|
||||||
|
if self.tp_rank == 0: # only first worker write info
|
||||||
|
write_shared_dp_balance_info(
|
||||||
|
new_recv_dp_balance_id_list, holding_token_list
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_mlp_sync_batch_raw(
|
def prepare_mlp_sync_batch_raw(
|
||||||
local_batch: ScheduleBatch,
|
local_batch: ScheduleBatch,
|
||||||
@@ -2394,6 +2495,7 @@ def run_scheduler_process(
|
|||||||
pp_rank: int,
|
pp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
|
balance_meta: Optional[DPBalanceMeta] = None,
|
||||||
):
|
):
|
||||||
# Generate the prefix
|
# Generate the prefix
|
||||||
prefix = ""
|
prefix = ""
|
||||||
@@ -2427,7 +2529,14 @@ def run_scheduler_process(
|
|||||||
# Create a scheduler and run the event loop
|
# Create a scheduler and run the event loop
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(
|
scheduler = Scheduler(
|
||||||
server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
|
server_args,
|
||||||
|
port_args,
|
||||||
|
gpu_id,
|
||||||
|
tp_rank,
|
||||||
|
moe_ep_rank,
|
||||||
|
pp_rank,
|
||||||
|
dp_rank,
|
||||||
|
dp_balance_meta=balance_meta,
|
||||||
)
|
)
|
||||||
pipe_writer.send(
|
pipe_writer.send(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||||
|
|
||||||
@@ -38,3 +39,46 @@ def validate_input_length(
|
|||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class DPBalanceMeta:
|
||||||
|
"""
|
||||||
|
This class will be use in scheduler and dp controller
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_workers: int):
|
||||||
|
self.num_workers = num_workers
|
||||||
|
self._manager = mp.Manager()
|
||||||
|
self.mutex = self._manager.Lock()
|
||||||
|
|
||||||
|
init_local_tokens = [0] * self.num_workers
|
||||||
|
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
|
||||||
|
|
||||||
|
self.shared_state = self._manager.Namespace()
|
||||||
|
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
|
||||||
|
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
|
||||||
|
|
||||||
|
def destructor(self):
|
||||||
|
# we must destructor this class manually
|
||||||
|
self._manager.shutdown()
|
||||||
|
|
||||||
|
def get_shared_onfly(self) -> List[Dict[int, int]]:
|
||||||
|
return [dict(d) for d in self.shared_state.onfly_info]
|
||||||
|
|
||||||
|
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
|
||||||
|
self.shared_state.onfly_info = data
|
||||||
|
|
||||||
|
def get_shared_local_tokens(self) -> List[int]:
|
||||||
|
return list(self.shared_state.local_tokens)
|
||||||
|
|
||||||
|
def set_shared_local_tokens(self, data: List[int]):
|
||||||
|
self.shared_state.local_tokens = data
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
del state["_manager"]
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._manager = None
|
||||||
|
|||||||
@@ -1171,6 +1171,7 @@ class ServerArgs:
|
|||||||
choices=[
|
choices=[
|
||||||
"round_robin",
|
"round_robin",
|
||||||
"shortest_queue",
|
"shortest_queue",
|
||||||
|
"minimum_tokens",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -137,5 +137,60 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
|
|||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--tp",
|
||||||
|
"2",
|
||||||
|
"--enable-dp-attention",
|
||||||
|
"--dp",
|
||||||
|
"2",
|
||||||
|
"--enable-torch-compile",
|
||||||
|
"--torch-compile-max-bs",
|
||||||
|
"2",
|
||||||
|
"--load-balance-method",
|
||||||
|
"minimum_tokens",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_mmlu(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["score"], 0.5)
|
||||||
|
|
||||||
|
def test_mgsm_en(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mgsm_en",
|
||||||
|
num_examples=None,
|
||||||
|
num_threads=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["score"], 0.8)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user