Remove the dependency of rpyc (#646)
@@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import (
     v1_chat_completions,
     v1_completions,
 )
-from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     API_KEY_HEADER_NAME,
     APIKeyValidatorMiddleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    receive_addrs,
-    send_addrs_to_rank_0,
 )
 from sglang.utils import get_exception_traceback
 
||||
@@ -98,6 +96,7 @@ async def flush_cache():
 
 
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
 
         async def stream_results():
@@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs):
     }
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
+def launch_server(server_args: ServerArgs,
+                  model_overide_args: Optional[dict] = None,
+                  pipe_finish_writer: Optional[mp.connection.Connection] = None):
     """Launch an HTTP server."""
     global tokenizer_manager
 
     logging.basicConfig(
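Note: the substantive change in this hunk is the launch_server signature: model_overide_args now precedes pipe_finish_writer, and both become optional. A minimal sketch of what that means for callers (illustrative stub, not the commit's code):

    from typing import Optional

    # Reordered, fully optional trailing parameters: the common case needs
    # neither an override dict nor a finish pipe.
    def launch_server(server_args,
                      model_overide_args: Optional[dict] = None,
                      pipe_finish_writer=None):
        ...

    launch_server("server_args")                         # common case
    launch_server("server_args", model_overide_args={})  # keywords avoid order bugs

Positional callers, such as the mp.Process target in the Runtime class further down, must swap their argument order to match.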
@@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
 
     _set_global_server_args(server_args)
 
     # Allocate ports
-    assert server_args.tp_size % server_args.nnodes == 0
-    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        tp_size_local,
         server_args.dp_size,
     )
 
     ports = server_args.additional_ports
-    model_port_args = []
-    for i in range(server_args.dp_size):
-        model_port_args.append(
-            ModelPortArgs(
-                nccl_port=ports[3 + i * (tp_size_local + 1)],
-                model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[
-                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
-                ],
-            )
-        )
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        router_port=ports[1],
+        controller_port=ports[1],
         detokenizer_port=ports[2],
-        model_port_args=model_port_args,
+        nccl_ports=ports[3:],
     )
 
-    # Handle multi-node tp
+    # Handle multi-node tensor parallelism
     if server_args.nnodes > 1:
         assert server_args.dp_size == 1, "Multi-node dp is not supported."
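Note: the deleted slice arithmetic is easiest to check with concrete numbers. A standalone sketch (all values illustrative) contrasting the old per-replica port carving with the new flat tail:

    # 2 data-parallel replicas, 2 tensor-parallel ranks each (illustrative).
    dp_size, tp_size_local = 2, 2
    ports = list(range(30000, 30000 + 3 + dp_size * (tp_size_local + 1)))

    # ports[0..2] stay tokenizer / controller / detokenizer in both versions.
    for i in range(dp_size):
        base = 3 + i * (tp_size_local + 1)
        nccl_port = ports[base]                       # old: one NCCL port per replica
        model_tp_ports = ports[base + 1 : base + 1 + tp_size_local]  # old: rpyc ports
        print(i, nccl_port, model_tp_ports)

    nccl_ports = ports[3:]  # new: no per-rank rpyc ports, just the NCCL tail

With rpyc removed there are no per-rank model-server ports to reserve, so PortArgs shrinks to the three front ports plus the NCCL list.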
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             gpu_ids,
             tp_rank_range,
             server_args,
-            port_args.model_port_args[0],
+            ports[3],
             model_overide_args,
         )
         while True:
@@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
-    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
+    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
-    proc_router = mp.Process(
+    proc_controller = mp.Process(
         target=start_process,
-        args=(server_args, port_args, pipe_router_writer, model_overide_args),
+        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
-    proc_router.start()
+    proc_controller.start()
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     proc_detoken.start()
 
     # Wait for the model to finish loading
-    router_init_state = pipe_router_reader.recv()
+    controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
 
-    if router_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_router.kill()
+    if controller_init_state != "init ok" or detoken_init_state != "init ok":
+        proc_controller.kill()
         proc_detoken.kill()
         print(
-            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+            f"Initialization failed. controller_init_state: {controller_init_state}", flush=True
         )
         print(
            f"Initialization failed. detoken_init_state: {detoken_init_state}",
            flush=True,
        )
        sys.exit(1)
-    assert proc_router.is_alive() and proc_detoken.is_alive()
+    assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
-    # Send a warmup request
-    def _wait_and_warmup():
-        headers = {}
-        url = server_args.url()
-        if server_args.api_key:
-            headers[API_KEY_HEADER_NAME] = server_args.api_key
-
-        # Wait until the server is launched
-        for _ in range(120):
-            time.sleep(0.5)
-            try:
-                requests.get(url + "/get_model_info", timeout=5, headers=headers)
-                break
-            except requests.exceptions.RequestException:
-                pass
-
-        # Send a warmup request
-        try:
-            for _ in range(server_args.dp_size):
-                res = requests.post(
-                    url + "/generate",
-                    json={
-                        "text": "The capital city of France is",
-                        "sampling_params": {
-                            "temperature": 0,
-                            "max_new_tokens": 8,
-                        },
-                    },
-                    headers=headers,
-                    timeout=600,
-                )
-            assert res.status_code == 200
-        except Exception as e:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(get_exception_traceback())
-            print(f"Initialization failed. warmup error: {e}", flush=True)
-            raise e
-
-        logger.info("The server is fired up and ready to roll!")
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send("init ok")
-
-    t = threading.Thread(target=_wait_and_warmup)
+    t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer))
     t.start()
 
     # Listen for requests
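Note: startup verification here uses a one-way-pipe handshake: each child holds the write end of an mp.Pipe(duplex=False) pair and sends "init ok" after loading; the parent blocks on recv() and kills the children on any other reply. A minimal self-contained sketch of the pattern (names illustrative):

    import multiprocessing as mp

    def worker(pipe_writer):
        # ... load model weights, bind sockets ...
        pipe_writer.send("init ok")  # report readiness to the parent

    if __name__ == "__main__":
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=worker, args=(writer,))
        proc.start()
        if reader.recv() != "init ok":  # blocks until the child reports in
            proc.kill()
            raise SystemExit(1)
        print("child initialized")

Hoisting _wait_and_warmup to module level (next hunk) removes its closure over launch_server's locals, which is why it now takes server_args and pipe_finish_writer explicitly.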
@@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     t.join()
 
 
+def _wait_and_warmup(server_args, pipe_finish_writer):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers[API_KEY_HEADER_NAME] = server_args.api_key
+
+    # Wait until the server is launched
+    for _ in range(120):
+        time.sleep(0.5)
+        try:
+            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            break
+        except requests.exceptions.RequestException:
+            pass
+
+    # Send a warmup request
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 8,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+        assert res.status_code == 200
+    except Exception as e:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(get_exception_traceback())
+        print(f"Initialization failed. warmup error: {e}", flush=True)
+        raise e
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("init ok")
+
+
 class Runtime:
     """
     A wrapper for the server.
@@ -354,7 +341,6 @@ class Runtime:
         self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
             self.server_args.port,
             self.server_args.additional_ports,
-            self.server_args.tp_size,
             self.server_args.dp_size,
         )
 
@@ -367,7 +353,7 @@ class Runtime:
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, pipe_writer, model_overide_args),
+            args=(self.server_args, model_overide_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
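Note: mp.Process forwards its args tuple positionally, so this call site must change in lockstep with the reordered launch_server signature; a stale caller would not fail at spawn time, the pipe would simply be bound to the wrong parameter. A runnable sketch of the corrected pattern (names illustrative):

    import multiprocessing as mp

    # Post-commit parameter order: overrides before the finish pipe.
    def launch(server_args, model_overide_args=None, pipe_finish_writer=None):
        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")

    if __name__ == "__main__":
        reader, writer = mp.Pipe(duplex=False)
        overrides = {"rope_scaling": None}  # illustrative override dict
        # The args tuple must mirror the new positional order:
        proc = mp.Process(target=launch, args=("server_args", overrides, writer))
        proc.start()
        print(reader.recv())  # -> "init ok"
        proc.join()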