Expose max total num tokens from Runtime & Engine API (#2092)
This commit is contained in:
committed by
GitHub
parent
72f87b723b
commit
c35cd1f8c7
@@ -167,9 +167,12 @@ class DataParallelController:
|
||||
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||
)
|
||||
|
||||
# Wait for model to finish loading
|
||||
# Wait for model to finish loading and get max token nums
|
||||
scheduler_info = []
|
||||
for i in range(len(scheduler_pipe_readers)):
|
||||
scheduler_pipe_readers[i].recv()
|
||||
scheduler_info.append(scheduler_pipe_readers[i].recv())
|
||||
|
||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
|
||||
return send_to
|
||||
|
||||
@@ -191,7 +194,10 @@ class DataParallelController:
|
||||
send_to = get_zmq_socket(
|
||||
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||
)
|
||||
reader.recv()
|
||||
|
||||
scheduler_info = reader.recv()
|
||||
self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
|
||||
|
||||
return send_to
|
||||
|
||||
def round_robin_scheduler(self, req):
|
||||
@@ -233,7 +239,9 @@ def run_data_parallel_controller_process(
|
||||
|
||||
try:
|
||||
controller = DataParallelController(server_args, port_args)
|
||||
pipe_writer.send("ready")
|
||||
pipe_writer.send(
|
||||
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
|
||||
)
|
||||
controller.event_loop()
|
||||
except Exception:
|
||||
msg = get_exception_traceback()
|
||||
|
||||
@@ -1400,7 +1400,9 @@ def run_scheduler_process(
|
||||
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||
pipe_writer.send("ready")
|
||||
pipe_writer.send(
|
||||
{"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
|
||||
)
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap()
|
||||
else:
|
||||
|
||||
@@ -102,6 +102,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
tokenizer_manager: TokenizerManager = None
|
||||
_max_total_num_tokens = None
|
||||
|
||||
##### Native API endpoints #####
|
||||
|
||||
@@ -184,6 +185,17 @@ async def stop_profile():
|
||||
)
|
||||
|
||||
|
||||
@app.get("/get_max_total_num_tokens")
|
||||
async def get_max_total_num_tokens():
|
||||
try:
|
||||
return {"max_total_num_tokens": _get_max_total_num_tokens()}
|
||||
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
|
||||
async def get_memory_pool_size():
|
||||
"""Get the memory pool size in number of tokens"""
|
||||
@@ -390,6 +402,7 @@ def launch_engine(
|
||||
"""
|
||||
|
||||
global tokenizer_manager
|
||||
global _max_total_num_tokens
|
||||
|
||||
# Configure global environment
|
||||
configure_logger(server_args)
|
||||
@@ -455,9 +468,20 @@ def launch_engine(
|
||||
if server_args.chat_template:
|
||||
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||
|
||||
# Wait for model to finish loading
|
||||
# Wait for model to finish loading & get max token nums
|
||||
scheduler_info = []
|
||||
for i in range(len(scheduler_pipe_readers)):
|
||||
scheduler_pipe_readers[i].recv()
|
||||
data = scheduler_pipe_readers[i].recv()
|
||||
|
||||
if data["status"] != "ready":
|
||||
self.shutdown()
|
||||
raise RuntimeError(
|
||||
"Initialization failed. Please see the error messages above."
|
||||
)
|
||||
scheduler_info.append(data)
|
||||
|
||||
# Assume all schedulers have same max_total_num_tokens
|
||||
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
|
||||
|
||||
def launch_server(
|
||||
@@ -518,6 +542,10 @@ def launch_server(
|
||||
t.join()
|
||||
|
||||
|
||||
def _get_max_total_num_tokens():
|
||||
return _max_total_num_tokens
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set global environments
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
@@ -759,6 +787,15 @@ class Runtime:
|
||||
response = requests.post(self.url + "/encode", json=json_data)
|
||||
return json.dumps(response.json())
|
||||
|
||||
def get_max_total_num_tokens(self):
|
||||
response = requests.get(f"{self.url}/get_max_total_num_tokens")
|
||||
if response.status_code == 200:
|
||||
return response.json()["max_total_num_tokens"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Failed to get max tokens. {response.json()['error']['message']}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
@@ -908,3 +945,6 @@ class Engine:
|
||||
# get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(encode_request(obj, None))
|
||||
|
||||
def get_max_total_num_tokens(self):
|
||||
return _get_max_total_num_tokens()
|
||||
|
||||
Reference in New Issue
Block a user