Revert "Rename TokenizerManager to StdOrchestrator" (#3828)

This commit is contained in:
Lianmin Zheng
2025-02-24 14:47:59 -08:00
committed by GitHub
parent c9745ee082
commit f2388f6b95
11 changed files with 130 additions and 116 deletions

View File

@@ -117,7 +117,7 @@ def create_streaming_error_response(
return json_str
def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_path):
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
global chat_template_name
logger.info(
@@ -133,7 +133,9 @@ def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_pat
if chat_template_arg.endswith(".jinja"):
with open(chat_template_arg, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace(
"\\n", "\n"
)
chat_template_name = None
else:
assert chat_template_arg.endswith(
@@ -229,7 +231,7 @@ async def v1_delete_file(file_id: str):
return FileDeleteResponse(id=file_id, deleted=True)
async def v1_batches(orchestrator, raw_request: Request):
async def v1_batches(tokenizer_manager, raw_request: Request):
try:
body = await raw_request.json()
@@ -250,7 +252,7 @@ async def v1_batches(orchestrator, raw_request: Request):
batch_storage[batch_id] = batch_response
# Start processing the batch asynchronously
asyncio.create_task(process_batch(orchestrator, batch_id, batch_request))
asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
# Return the initial batch_response
return batch_response
@@ -261,7 +263,7 @@ async def v1_batches(orchestrator, raw_request: Request):
return {"error": str(e)}
async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest):
async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
try:
# Update the batch status to "in_progress"
batch_storage[batch_id].status = "in_progress"
@@ -304,7 +306,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request(
all_requests, orchestrator, request_ids=request_ids
all_requests, tokenizer_manager, request_ids=request_ids
)
elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request(
@@ -312,7 +314,7 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
)
try:
ret = await orchestrator.generate_request(adapted_request).__anext__()
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
if not isinstance(ret, list):
ret = [ret]
if end_point == "/v1/chat/completions":
@@ -320,12 +322,12 @@ async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest
request,
ret,
to_file=True,
cache_report=orchestrator.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
else:
responses = v1_generate_response(
request, ret, orchestrator, to_file=True
request, ret, tokenizer_manager, to_file=True
)
except Exception as e:
@@ -397,7 +399,7 @@ async def v1_retrieve_batch(batch_id: str):
return batch_response
async def v1_cancel_batch(orchestrator, batch_id: str):
async def v1_cancel_batch(tokenizer_manager, batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
@@ -408,7 +410,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str):
# Start cancelling the batch asynchronously
asyncio.create_task(
cancel_batch(
orchestrator=orchestrator,
tokenizer_manager=tokenizer_manager,
batch_id=batch_id,
input_file_id=batch_response.input_file_id,
)
@@ -425,7 +427,7 @@ async def v1_cancel_batch(orchestrator, batch_id: str):
)
async def cancel_batch(orchestrator, batch_id: str, input_file_id: str):
async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
try:
# Update the batch status to "cancelling"
batch_storage[batch_id].status = "cancelling"
@@ -449,7 +451,7 @@ async def cancel_batch(orchestrator, batch_id: str, input_file_id: str):
# Cancel requests by request_ids
for rid in request_ids:
orchestrator.abort_request(rid=rid)
tokenizer_manager.abort_request(rid=rid)
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "cancelled"
@@ -577,7 +579,7 @@ def v1_generate_request(
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_generate_response(request, ret, orchestrator, to_file=False):
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
choices = []
echo = False
@@ -589,13 +591,15 @@ def v1_generate_response(request, ret, orchestrator, to_file=False):
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts
prompts = [
orchestrator.tokenizer.decode(prompt, skip_special_tokens=True)
tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
for prompt in request.prompt
]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = [
orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True)
tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
]
else:
# for the case of single str prompt
@@ -705,7 +709,7 @@ def v1_generate_response(request, ret, orchestrator, to_file=False):
return response
async def v1_completions(orchestrator, raw_request: Request):
async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests)
@@ -718,7 +722,7 @@ async def v1_completions(orchestrator, raw_request: Request):
prompt_tokens = {}
completion_tokens = {}
try:
async for content in orchestrator.generate_request(
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
@@ -741,14 +745,14 @@ async def v1_completions(orchestrator, raw_request: Request):
prompts = request.prompt[index // request.n]
elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = orchestrator.tokenizer.decode(
prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int
):
# for the case of multiple token ids prompts
prompts = orchestrator.tokenizer.decode(
prompts = tokenizer_manager.tokenizer.decode(
request.prompt[index // request.n],
skip_special_tokens=True,
)
@@ -843,12 +847,12 @@ async def v1_completions(orchestrator, raw_request: Request):
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=orchestrator.create_abort_task(adapted_request),
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await orchestrator.generate_request(
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -857,13 +861,13 @@ async def v1_completions(orchestrator, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_generate_response(request, ret, orchestrator)
response = v1_generate_response(request, ret, tokenizer_manager)
return response
def v1_chat_generate_request(
all_requests: List[ChatCompletionRequest],
orchestrator,
tokenizer_manager,
request_ids: List[str] = None,
):
input_ids = []
@@ -918,7 +922,7 @@ def v1_chat_generate_request(
assistant_prefix = None
try:
prompt_ids = orchestrator.tokenizer.apply_chat_template(
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
@@ -929,7 +933,7 @@ def v1_chat_generate_request(
# has a different tools input format that is not compatible
# with OpenAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = orchestrator.tokenizer.apply_chat_template(
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
@@ -937,8 +941,11 @@ def v1_chat_generate_request(
)
if assistant_prefix:
encoded = orchestrator.tokenizer.encode(assistant_prefix)
if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id:
encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
if (
encoded
and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
):
encoded = encoded[1:]
prompt_ids += encoded
stop = request.stop
@@ -955,7 +962,7 @@ def v1_chat_generate_request(
stop.append(request.stop)
else:
stop.extend(request.stop)
prompt_ids = orchestrator.tokenizer.encode(prompt)
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
@@ -1194,10 +1201,10 @@ def v1_chat_generate_response(
return response
async def v1_chat_completions(orchestrator, raw_request: Request):
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, orchestrator)
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
if adapted_request.stream:
parser_dict = {}
@@ -1209,7 +1216,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
prompt_tokens = {}
completion_tokens = {}
try:
async for content in orchestrator.generate_request(
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
@@ -1299,7 +1306,7 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=orchestrator.server_args.tool_call_parser,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
parser = parser_dict[index]
@@ -1431,12 +1438,12 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=orchestrator.create_abort_task(adapted_request),
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await orchestrator.generate_request(
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -1447,14 +1454,14 @@ async def v1_chat_completions(orchestrator, raw_request: Request):
response = v1_chat_generate_response(
request,
ret,
cache_report=orchestrator.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
return response
def v1_embedding_request(all_requests, orchestrator):
def v1_embedding_request(all_requests, tokenizer_manager):
prompts = []
sampling_params_list = []
first_prompt_type = type(all_requests[0].input)
@@ -1509,13 +1516,13 @@ def v1_embedding_response(ret, model_path, to_file=False):
)
async def v1_embeddings(orchestrator, raw_request: Request):
async def v1_embeddings(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
all_requests = [EmbeddingRequest(**request_json)]
adapted_request, request = v1_embedding_request(all_requests, orchestrator)
adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
try:
ret = await orchestrator.generate_request(
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -1524,7 +1531,7 @@ async def v1_embeddings(orchestrator, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_embedding_response(ret, orchestrator.model_path)
response = v1_embedding_response(ret, tokenizer_manager.model_path)
return response