Expose max total num tokens from Runtime & Engine API (#2092)

Author: Henry Hyeonmok Ko
Date: 2024-11-22 15:10:10 -08:00
Committed by: GitHub
Parent: 72f87b723b
Commit: c35cd1f8c7
4 changed files with 81 additions and 7 deletions
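
In short: each scheduler process now sends {"status": "ready", "max_total_num_tokens": ...} over its startup pipe instead of a bare "ready", the DataParallelController and launch_engine collect that dict, and the value is exposed three ways: a native /get_max_total_num_tokens HTTP endpoint, Runtime.get_max_total_num_tokens(), and Engine.get_max_total_num_tokens().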


@@ -167,9 +167,12 @@ class DataParallelController:
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
 
-        # Wait for model to finish loading
+        # Wait for model to finish loading and get max token nums
+        scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
-            scheduler_pipe_readers[i].recv()
+            scheduler_info.append(scheduler_pipe_readers[i].recv())
+
+        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
 
         return send_to
@@ -191,7 +194,10 @@
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
 
-        reader.recv()
+        scheduler_info = reader.recv()
+        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
+
         return send_to
 
     def round_robin_scheduler(self, req):
@@ -233,7 +239,9 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+        )
         controller.event_loop()
     except Exception:
         msg = get_exception_traceback()
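
Both hunks above swap the bare "ready" token for a dict, so the parent process learns max_total_num_tokens during the same startup handshake. A minimal, self-contained sketch of that pattern (the names and the 4096 value are illustrative, not sglang's):

```python
import multiprocessing as mp


def worker(pipe_writer):
    # Stand-in for scheduler startup; 4096 is a made-up capacity.
    max_total_num_tokens = 4096
    pipe_writer.send(
        {"status": "ready", "max_total_num_tokens": max_total_num_tokens}
    )


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker, args=(writer,))
    proc.start()
    info = reader.recv()  # blocks until the child reports in
    assert info["status"] == "ready"
    print(info["max_total_num_tokens"])  # -> 4096
    proc.join()
```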


@@ -1400,7 +1400,9 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
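
Here scheduler.max_total_num_tokens is the total token capacity of the scheduler's KV-cache pool, sized during model-runner memory profiling. The commit only forwards the value; as a rough, hypothetical illustration of the kind of arithmetic that produces such a capacity (not sglang's actual code):

```python
def estimate_max_total_num_tokens(
    free_gpu_mem_bytes: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    kv_dtype_bytes: int = 2,    # fp16/bf16
    mem_fraction: float = 0.9,  # share of free memory given to the KV cache
) -> int:
    # K and V caches for one token across all layers:
    # 2 (K and V) * layers * kv heads * head dim * bytes per element.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * kv_dtype_bytes
    return int(free_gpu_mem_bytes * mem_fraction) // bytes_per_token


# A Llama-7B-like config with ~20 GiB of free GPU memory:
print(estimate_max_total_num_tokens(20 * 1024**3, 32, 32, 128))  # -> 36864
```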


@@ -102,6 +102,7 @@ app.add_middleware(
 )
 
 tokenizer_manager: TokenizerManager = None
+_max_total_num_tokens = None
 
 ##### Native API endpoints #####
@@ -184,6 +185,17 @@ async def stop_profile():
     )
 
 
+@app.get("/get_max_total_num_tokens")
+async def get_max_total_num_tokens():
+    try:
+        return {"max_total_num_tokens": _get_max_total_num_tokens()}
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
 async def get_memory_pool_size():
     """Get the memory pool size in number of tokens"""
@@ -390,6 +402,7 @@ def launch_engine(
     """
 
     global tokenizer_manager
+    global _max_total_num_tokens
 
     # Configure global environment
     configure_logger(server_args)
@@ -455,9 +468,20 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for model to finish loading
+    # Wait for model to finish loading & get max token nums
+    scheduler_info = []
     for i in range(len(scheduler_pipe_readers)):
-        scheduler_pipe_readers[i].recv()
+        data = scheduler_pipe_readers[i].recv()
+        if data["status"] != "ready":
+            self.shutdown()
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
+        scheduler_info.append(data)
+
+    # Assume all schedulers have same max_total_num_tokens
+    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
 
 
 def launch_server(
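
launch_engine keeps only scheduler_info[0], as the comment says, on the assumption that every scheduler reports the same capacity. A stricter variant (hypothetical, not part of this commit) would verify that before discarding the rest:

```python
# Hypothetical stricter version of the scheduler_info aggregation above.
values = {info["max_total_num_tokens"] for info in scheduler_info}
if len(values) != 1:
    raise RuntimeError(f"Schedulers disagree on max_total_num_tokens: {values}")
_max_total_num_tokens = values.pop()
```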
@@ -518,6 +542,10 @@ def launch_server(
     t.join()
 
 
+def _get_max_total_num_tokens():
+    return _max_total_num_tokens
+
+
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -759,6 +787,15 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())
 
+    def get_max_total_num_tokens(self):
+        response = requests.get(f"{self.url}/get_max_total_num_tokens")
+        if response.status_code == 200:
+            return response.json()["max_total_num_tokens"]
+        else:
+            raise RuntimeError(
+                f"Failed to get max tokens. {response.json()['error']['message']}"
+            )
+
     def __del__(self):
         self.shutdown()
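
On the client side, Runtime.get_max_total_num_tokens() just wraps the endpoint. Typical usage might look like the following (the model path is illustrative, and the returned value depends on model and hardware):

```python
import sglang as sgl

# Model path is illustrative; any model the Runtime can load works.
runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
print(runtime.get_max_total_num_tokens())  # KV-cache token capacity
runtime.shutdown()
```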
@@ -908,3 +945,6 @@ class Engine:
         # get the current event loop
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))
+
+    def get_max_total_num_tokens(self):
+        return _get_max_total_num_tokens()
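
Engine runs in-process, so its accessor reads the module-level _max_total_num_tokens directly instead of going over HTTP. A usage sketch under the same illustrative assumptions as above:

```python
from sglang.srt.server import Engine

# Same illustrative model path; Engine loads the model in the current process.
engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
print(engine.get_max_total_num_tokens())
engine.shutdown()
```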