[router] add base_gpu_id server args & merged radix tree python reference (#2115)
This commit is contained in:
@@ -156,7 +156,7 @@ class DataParallelController:
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
|
||||
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
|
||||
|
||||
@@ -1380,6 +1380,10 @@ def run_scheduler_process(
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
):
|
||||
# [For Router] if the env var "DP_RANK" exists, set dp_rank to the value of the env var
|
||||
if dp_rank is None:
|
||||
dp_rank = int(os.getenv("DP_RANK", -1))
|
||||
|
||||
if dp_rank is None:
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
else:
|
||||
|
||||
@@ -418,7 +418,7 @@ def launch_engine(
|
||||
)
|
||||
for tp_rank in tp_rank_range:
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
gpu_id = tp_rank % tp_size_per_node
|
||||
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
||||
proc = mp.Process(
|
||||
target=run_scheduler_process,
|
||||
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||
|
||||
@@ -72,6 +72,7 @@ class ServerArgs:
|
||||
constrained_json_whitespace_pattern: Optional[str] = None
|
||||
watchdog_timeout: float = 300
|
||||
download_dir: Optional[str] = None
|
||||
base_gpu_id: int = 0
|
||||
|
||||
# Logging
|
||||
log_level: str = "info"
|
||||
@@ -412,6 +413,12 @@ class ServerArgs:
|
||||
default=ServerArgs.download_dir,
|
||||
help="Model download directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-gpu-id",
|
||||
type=int,
|
||||
default=ServerArgs.base_gpu_id,
|
||||
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
||||
)
|
||||
|
||||
# Logging
|
||||
parser.add_argument(
|
||||
@@ -736,6 +743,7 @@ class ServerArgs:
|
||||
and (self.lora_paths is None or self.disable_cuda_graph)
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
|
||||
Reference in New Issue
Block a user