Improve: Rename TokenizerManager to StdOrchestrator (#3116)

fzyzcjy
2025-02-23 16:30:58 +08:00
committed by GitHub
parent 3f41b18455
commit 45360b2fa9
11 changed files with 116 additions and 130 deletions


@@ -117,7 +117,7 @@ def create_streaming_error_response(
return json_str
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_path):
global chat_template_name
logger.info(
@@ -133,9 +133,7 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
if chat_template_arg.endswith(".jinja"):
with open(chat_template_arg, "r") as f:
chat_template = "".join(f.readlines()).strip("\n")
tokenizer_manager.tokenizer.chat_template = chat_template.replace(
"\\n", "\n"
)
orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n")
chat_template_name = None
else:
assert chat_template_arg.endswith(
@@ -231,7 +229,7 @@ async def v1_delete_file(file_id: str):
return FileDeleteResponse(id=file_id, deleted=True)
async def v1_batches(tokenizer_manager, raw_request: Request):
async def v1_batches(orchestrator, raw_request: Request):
try:
body = await raw_request.json()
@@ -252,7 +250,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request):
batch_storage[batch_id] = batch_response
# Start processing the batch asynchronously
asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
asyncio.create_task(process_batch(orchestrator, batch_id, batch_request))
# Return the initial batch_response
return batch_response
@@ -263,7 +261,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request):
return {"error": str(e)}
async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest):
try:
# Update the batch status to "in_progress"
batch_storage[batch_id].status = "in_progress"
@@ -306,7 +304,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request(
all_requests, tokenizer_manager, request_ids=request_ids
all_requests, orchestrator, request_ids=request_ids
)
elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request(
@@ -314,7 +312,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
)
try:
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
ret = await orchestrator.generate_request(adapted_request).__anext__()
if not isinstance(ret, list):
ret = [ret]
if end_point == "/v1/chat/completions":
@@ -322,12 +320,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
request,
ret,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
cache_report=orchestrator.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser,
)
else:
responses = v1_generate_response(
request, ret, tokenizer_manager, to_file=True
request, ret, orchestrator, to_file=True
)
except Exception as e:
@@ -399,7 +397,7 @@ async def v1_retrieve_batch(batch_id: str):
return batch_response
async def v1_cancel_batch(tokenizer_manager, batch_id: str):
async def v1_cancel_batch(orchestrator, batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
@@ -410,7 +408,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str):
# Start cancelling the batch asynchronously
asyncio.create_task(
cancel_batch(
tokenizer_manager=tokenizer_manager,
orchestrator=orchestrator,
batch_id=batch_id,
input_file_id=batch_response.input_file_id,
)
@@ -427,7 +425,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str):
)
async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
async def cancel_batch(orchestrator, batch_id: str, input_file_id: str):
try:
# Update the batch status to "cancelling"
batch_storage[batch_id].status = "cancelling"
@@ -451,7 +449,7 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
# Cancel requests by request_ids
for rid in request_ids:
tokenizer_manager.abort_request(rid=rid)
orchestrator.abort_request(rid=rid)
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "cancelled"
@@ -579,7 +577,7 @@ def v1_generate_request(
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
def v1_generate_response(request, ret, orchestrator, to_file=False):
choices = []
echo = False
@@ -591,15 +589,13 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts
prompts = [
tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
orchestrator.tokenizer.decode(prompt, skip_special_tokens=True)
for prompt in request.prompt
]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = [
tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True)
]
else:
# for the case of single str prompt
@@ -709,7 +705,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
return response
async def v1_completions(tokenizer_manager, raw_request: Request):
async def v1_completions(orchestrator, raw_request: Request):
request_json = await raw_request.json()
all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests)
@@ -722,7 +718,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompt_tokens = {}
completion_tokens = {}
try:
async for content in tokenizer_manager.generate_request(
async for content in orchestrator.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
@@ -745,14 +741,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
prompts = request.prompt[index // request.n]
elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = tokenizer_manager.tokenizer.decode(
prompts = orchestrator.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int
):
# for the case of multiple token ids prompts
prompts = tokenizer_manager.tokenizer.decode(
prompts = orchestrator.tokenizer.decode(
request.prompt[index // request.n],
skip_special_tokens=True,
)
@@ -847,12 +843,12 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
background=orchestrator.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
ret = await orchestrator.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -861,13 +857,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_generate_response(request, ret, tokenizer_manager)
response = v1_generate_response(request, ret, orchestrator)
return response
def v1_chat_generate_request(
all_requests: List[ChatCompletionRequest],
tokenizer_manager,
orchestrator,
request_ids: List[str] = None,
):
input_ids = []
@@ -922,7 +918,7 @@ def v1_chat_generate_request(
assistant_prefix = None
try:
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
prompt_ids = orchestrator.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
@@ -933,7 +929,7 @@ def v1_chat_generate_request(
# has a different tools input format that is not compatiable
# with openAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
prompt_ids = orchestrator.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
@@ -941,11 +937,8 @@ def v1_chat_generate_request(
)
if assistant_prefix:
encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
if (
encoded
and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
):
encoded = orchestrator.tokenizer.encode(assistant_prefix)
if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id:
encoded = encoded[1:]
prompt_ids += encoded
stop = request.stop
@@ -962,7 +955,7 @@ def v1_chat_generate_request(
stop.append(request.stop)
else:
stop.extend(request.stop)
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
prompt_ids = orchestrator.tokenizer.encode(prompt)
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
@@ -1201,10 +1194,10 @@ def v1_chat_generate_response(
return response
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
async def v1_chat_completions(orchestrator, raw_request: Request):
request_json = await raw_request.json()
all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
adapted_request, request = v1_chat_generate_request(all_requests, orchestrator)
if adapted_request.stream:
parser_dict = {}
@@ -1216,7 +1209,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
prompt_tokens = {}
completion_tokens = {}
try:
async for content in tokenizer_manager.generate_request(
async for content in orchestrator.generate_request(
adapted_request, raw_request
):
index = content.get("index", 0)
@@ -1306,7 +1299,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
tool_call_parser=orchestrator.server_args.tool_call_parser,
)
parser = parser_dict[index]
@@ -1438,12 +1431,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
background=orchestrator.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
ret = await orchestrator.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -1454,14 +1447,14 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
response = v1_chat_generate_response(
request,
ret,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
cache_report=orchestrator.server_args.enable_cache_report,
tool_call_parser=orchestrator.server_args.tool_call_parser,
)
return response
def v1_embedding_request(all_requests, tokenizer_manager):
def v1_embedding_request(all_requests, orchestrator):
prompts = []
sampling_params_list = []
first_prompt_type = type(all_requests[0].input)
@@ -1516,13 +1509,13 @@ def v1_embedding_response(ret, model_path, to_file=False):
)
async def v1_embeddings(tokenizer_manager, raw_request: Request):
async def v1_embeddings(orchestrator, raw_request: Request):
request_json = await raw_request.json()
all_requests = [EmbeddingRequest(**request_json)]
adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
adapted_request, request = v1_embedding_request(all_requests, orchestrator)
try:
ret = await tokenizer_manager.generate_request(
ret = await orchestrator.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
@@ -1531,7 +1524,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_embedding_response(ret, tokenizer_manager.model_path)
response = v1_embedding_response(ret, orchestrator.model_path)
return response
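
For orientation, the handlers above touch only a small surface of the renamed object (tokenizer, server_args, model_path, generate_request, abort_request). The sketch below stubs that surface to show the call pattern after the rename; the StubOrchestrator class, its field values, and the fake generated stream are illustrative assumptions, not code from this commit.

# Minimal sketch of the interface the renamed `orchestrator` argument provides,
# based only on the attributes used in the diff above. Everything here (class
# name, field values, fake stream) is illustrative, not actual sglang code.
import asyncio
from types import SimpleNamespace


class StubOrchestrator:
    """Stands in for StdOrchestrator (formerly TokenizerManager) in this sketch."""

    def __init__(self):
        self.model_path = "dummy/model"  # read by v1_embedding_response
        self.server_args = SimpleNamespace(
            enable_cache_report=False, tool_call_parser=None
        )
        # Handlers call .encode/.decode/.apply_chat_template on this in real code.
        self.tokenizer = None

    async def generate_request(self, adapted_request, raw_request=None):
        # Consumed as an async generator (`__anext__` or `async for`) by the handlers.
        yield {"index": 0, "text": "hello", "meta_info": {}}

    def abort_request(self, rid):
        # Called per request id by cancel_batch; no-op in this stub.
        pass


async def main():
    orchestrator = StubOrchestrator()
    # Same access pattern as the non-streaming branch of v1_completions above.
    ret = await orchestrator.generate_request(adapted_request=None).__anext__()
    print(ret["text"])


asyncio.run(main())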