Remove the dependency of rpyc (#646)

Mingyi, 2024-07-18 02:13:54 -07:00 (committed by GitHub)
parent d93388da3e
commit d774acad5c
11 changed files with 294 additions and 542 deletions


@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +96,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
         async def stream_results():
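Note: the streaming branch above is easiest to understand from the client side. A minimal consumer sketch; it assumes the server emits SSE-style "data:" lines terminated by "[DONE]", which this hunk does not show:

    import json
    import requests

    # "stream": True selects the obj.stream branch handled above.
    # The data:/[DONE] framing is an assumption, not shown in this diff.
    with requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
            "stream": True,
        },
        stream=True,
    ) as res:
        for line in res.iter_lines():
            if line.startswith(b"data:"):
                payload = line[len(b"data:"):].strip()
                if payload == b"[DONE]":
                    break
                print(json.loads(payload))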
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }
-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(server_args: ServerArgs,
+                  model_overide_args: Optional[dict] = None,
+                  pipe_finish_writer: Optional[mp.connection.Connection] = None):
     """Launch an HTTP server."""
     global tokenizer_manager
     logging.basicConfig(
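Note: with pipe_finish_writer moved to the end and made optional, positional callers must be updated (the Runtime change at the bottom of this diff does exactly that). A sketch of spawning the server with the new order; the model path is illustrative:

    import multiprocessing as mp

    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        # New positional order: (server_args, model_overide_args, pipe_finish_writer)
        proc = mp.Process(target=launch_server, args=(server_args, None, pipe_writer))
        proc.start()
        print(pipe_reader.recv())  # "init ok" once warmup succeeds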
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
     _set_global_server_args(server_args)
     # Allocate ports
     assert server_args.tp_size % server_args.nnodes == 0
     tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )
-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
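Note: the per-replica ModelPortArgs list collapses into one flat nccl_ports slice. A sketch of the new wiring; the field names come from the constructor call above, and the allocate_init_ports signature without tp_size is inferred from this diff:

    from sglang.srt.server_args import PortArgs
    from sglang.srt.utils import allocate_init_ports

    port, additional_ports = allocate_init_ports(30000, None, 1)  # (port, additional_ports, dp_size)
    ports = additional_ports
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],   # renamed from router_port
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],       # one flat list, no ModelPortArgs
    )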
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )
         while True:
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
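Note: the rename also makes the startup handshake easier to follow. A self-contained sketch of the pattern used here: the child process sends "init ok" over a one-way pipe once it finishes loading, and the parent blocks on recv() before serving:

    import multiprocessing as mp

    def child(pipe_writer):
        # ... load the model, bind ports, etc. ...
        pipe_writer.send("init ok")

    if __name__ == "__main__":
        reader, writer = mp.Pipe(duplex=False)  # one-way: child -> parent
        proc = mp.Process(target=child, args=(writer,))
        proc.start()
        if reader.recv() != "init ok":
            proc.kill()
            raise SystemExit("Initialization failed.")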
@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
         )
         print(
             f"Initialization failed. detoken_init_state: {detoken_init_state}",
             flush=True,
         )
         sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
-    # Send a warmup request
-    def _wait_and_warmup():
-        headers = {}
-        url = server_args.url()
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-                assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
     t.start()
     # Listen for requests
@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         t.join()
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
 class Runtime:
     """
     A wrapper for the server.
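Note: the extracted _wait_and_warmup boils down to two HTTP calls, which can also be run by hand against a live server (default URL assumed; add the API-key header if one is configured):

    import requests

    url = "http://127.0.0.1:30000"
    requests.get(url + "/get_model_info", timeout=5)  # readiness probe
    res = requests.post(
        url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=600,
    )
    assert res.status_code == 200
    print(res.json())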
@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
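Note: end to end, the Runtime wrapper remains the usual entry point for this path. A usage sketch; the model path is illustrative:

    import sglang as sgl

    # Runtime forwards kwargs to ServerArgs and spawns launch_server with
    # the new (server_args, model_overide_args, pipe_writer) order shown above.
    runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    # ... issue HTTP requests against the runtime's endpoint ...
    runtime.shutdown()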