[router] add base_gpu_id server args & merged radix tree python reference (#2115)

This commit is contained in:
Byron Hsu
2024-11-21 17:13:33 -08:00
committed by GitHub
parent f6f713797b
commit 30af7dfb34
6 changed files with 513 additions and 2 deletions

View File

@@ -156,7 +156,7 @@ class DataParallelController:
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),

View File

@@ -1380,6 +1380,10 @@ def run_scheduler_process(
dp_rank: Optional[int],
pipe_writer,
):
# [For Router] if the env var "DP_RANK" exists, set dp_rank to its value. NOTE(review): os.getenv falls back to -1 here (never None), so the "if dp_rank is None" check below cannot detect the unset case — confirm whether it should compare against -1 instead.
if dp_rank is None:
dp_rank = int(os.getenv("DP_RANK", -1))
if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}")
else:

View File

@@ -418,7 +418,7 @@ def launch_engine(
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),

View File

@@ -72,6 +72,7 @@ class ServerArgs:
constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300
download_dir: Optional[str] = None
base_gpu_id: int = 0
# Logging
log_level: str = "info"
@@ -412,6 +413,12 @@ class ServerArgs:
default=ServerArgs.download_dir,
help="Model download directory.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
# Logging
parser.add_argument(
@@ -736,6 +743,7 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths