Improve: Rename TokenizerManager to StdOrchestrator (#3116)
@@ -426,7 +426,7 @@
     "from sglang.srt.managers.io_struct import Tool, Function\n",
     "\n",
     "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
-    "tokenizer = llm.tokenizer_manager.tokenizer\n",
+    "tokenizer = llm.orchestrator.tokenizer\n",
     "input_ids = tokenizer.apply_chat_template(\n",
     "    messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
     ")\n",
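The hunk above updates the function-calling notebook: user code now reaches the Hugging Face tokenizer through `llm.orchestrator` instead of `llm.tokenizer_manager`. A minimal sketch of the updated pattern; the `messages` and `tools` values below are illustrative placeholders (the notebook builds its own with the `Tool`/`Function` structures it imports), not part of this commit:

```python
import sglang as sgl

# Placeholder chat history and one JSON-schema tool, purely for illustration.
messages = [{"role": "user", "content": "What is the weather like in Boston today?"}]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
tokenizer = llm.orchestrator.tokenizer  # was: llm.tokenizer_manager.tokenizer
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, tools=tools
)
print(len(input_ids))
```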
@@ -48,8 +48,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -74,12 +74,12 @@ class Engine:
     The entry point to the inference engine.

     - The engine consists of three components:
-        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+        1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
         3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server, Engine, and TokenizerManager both run in the main process.
+    1. The HTTP server, Engine, and StdOrchestrator both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """

@@ -102,10 +102,8 @@ class Engine:
         atexit.register(self.shutdown)

         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
-        )
-        self.tokenizer_manager = tokenizer_manager
+        orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args)
+        self.orchestrator = orchestrator
         self.scheduler_info = scheduler_info

     def generate(
@@ -147,7 +145,7 @@ class Engine:
             stream=stream,
         )
         loop = asyncio.get_event_loop()
-        generator = self.tokenizer_manager.generate_request(obj, None)
+        generator = self.orchestrator.generate_request(obj, None)

         if stream:

@@ -197,7 +195,7 @@ class Engine:
             stream=stream,
             custom_logit_processor=custom_logit_processor,
         )
-        generator = self.tokenizer_manager.generate_request(obj, None)
+        generator = self.orchestrator.generate_request(obj, None)

         if stream is True:
             return generator
@@ -215,7 +213,7 @@ class Engine:

         obj = EmbeddingReqInput(text=prompt)
         loop = asyncio.get_event_loop()
-        generator = self.tokenizer_manager.generate_request(obj, None)
+        generator = self.orchestrator.generate_request(obj, None)
         ret = loop.run_until_complete(generator.__anext__())
         return ret

@@ -224,14 +222,14 @@ class Engine:
         kill_process_tree(os.getpid(), include_parent=False)

     def start_profile(self):
-        self.tokenizer_manager.start_profile()
+        self.orchestrator.start_profile()

     def stop_profile(self):
-        self.tokenizer_manager.stop_profile()
+        self.orchestrator.stop_profile()

     def get_server_info(self):
         return {
-            **dataclasses.asdict(self.tokenizer_manager.server_args), # server args
+            **dataclasses.asdict(self.orchestrator.server_args), # server args
             **self.scheduler_info,
             "version": __version__,
         }
@@ -256,7 +254,7 @@ class Engine:
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
-            self.tokenizer_manager.init_weights_update_group(obj, None)
+            self.orchestrator.init_weights_update_group(obj, None)
         )

     def update_weights_from_distributed(self, name: str, dtype, shape):
@@ -268,7 +266,7 @@ class Engine:
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
-            self.tokenizer_manager.update_weights_from_distributed(obj, None)
+            self.orchestrator.update_weights_from_distributed(obj, None)
         )

     def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
@@ -278,23 +276,21 @@ class Engine:
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
-            self.tokenizer_manager.update_weights_from_tensor(obj, None)
+            self.orchestrator.update_weights_from_tensor(obj, None)
         )

     def get_weights_by_name(self, name: str, truncate_size: int = 100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            self.tokenizer_manager.get_weights_by_name(obj, None)
-        )
+        return loop.run_until_complete(self.orchestrator.get_weights_by_name(obj, None))

     def release_memory_occupation(self):
         """Release GPU occupation temporarily."""
         obj = ReleaseMemoryOccupationReqInput()
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
-            self.tokenizer_manager.release_memory_occupation(obj, None)
+            self.orchestrator.release_memory_occupation(obj, None)
         )

     def resume_memory_occupation(self):
@@ -302,7 +298,7 @@ class Engine:
         obj = ResumeMemoryOccupationReqInput()
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
-            self.tokenizer_manager.resume_memory_occupation(obj, None)
+            self.orchestrator.resume_memory_occupation(obj, None)
         )


@@ -351,9 +347,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]:
     """
-    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
+    Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
     # Configure global environment
     configure_logger(server_args)
@@ -436,10 +432,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     detoken_proc.start()

     # Launch tokenizer process
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    orchestrator = StdOrchestrator(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
+            orchestrator, server_args.chat_template, server_args.model_path
         )

     # Wait for the model to finish loading
@@ -463,5 +459,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic

     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
-    tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"])
-    return tokenizer_manager, scheduler_info
+    orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"])
+    return orchestrator, scheduler_info
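The engine-side hunks above only rename the internal attribute that `Engine`'s public methods delegate to; `generate`, `encode`, `get_server_info`, profiling, and the weight-update helpers keep their signatures. A rough sketch of what that means for user code after the rename (the model path is an assumption; only the attribute name actually changes):

```python
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

# Public API, untouched by this commit.
print(engine.generate("The capital of France is", {"max_new_tokens": 8}))
print(engine.get_server_info()["version"])

# Internal handle, renamed by this commit.
assert engine.orchestrator is not None  # previously engine.tokenizer_manager
engine.start_profile()                  # delegates to engine.orchestrator.start_profile()
engine.stop_profile()

engine.shutdown()
```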
@@ -54,7 +54,6 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.openai_api.adapter import (
     v1_batches,
@@ -69,6 +68,7 @@ from sglang.srt.openai_api.adapter import (
     v1_retrieve_file_content,
 )
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
@@ -97,7 +97,7 @@ app.add_middleware(
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: TokenizerManager
+    orchestrator: StdOrchestrator
     scheduler_info: Dict


@@ -124,7 +124,7 @@ async def health_generate(request: Request) -> Response:

     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}

-    if _global_state.tokenizer_manager.is_generation:
+    if _global_state.orchestrator.is_generation:
         gri = GenerateReqInput(
             input_ids=[0], sampling_params=sampling_params, log_metrics=False
         )
@@ -134,7 +134,7 @@ async def health_generate(request: Request) -> Response:
         )

     try:
-        async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
+        async for _ in _global_state.orchestrator.generate_request(gri, request):
             break
         return Response(status_code=200)
     except Exception as e:
@@ -146,9 +146,9 @@ async def health_generate(request: Request) -> Response:
 async def get_model_info():
     """Get the model information."""
     result = {
-        "model_path": _global_state.tokenizer_manager.model_path,
-        "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
-        "is_generation": _global_state.tokenizer_manager.is_generation,
+        "model_path": _global_state.orchestrator.model_path,
+        "tokenizer_path": _global_state.orchestrator.server_args.tokenizer_path,
+        "is_generation": _global_state.orchestrator.is_generation,
     }
     return result

@@ -156,7 +156,7 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
+        **dataclasses.asdict(_global_state.orchestrator.server_args),
         **_global_state.scheduler_info,
         "version": __version__,
     }
@@ -170,7 +170,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):

     async def stream_results() -> AsyncIterator[bytes]:
         try:
-            async for out in _global_state.tokenizer_manager.generate_request(
+            async for out in _global_state.orchestrator.generate_request(
                 obj, request
             ):
                 yield b"data: " + orjson.dumps(
@@ -186,11 +186,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return StreamingResponse(
             stream_results(),
             media_type="text/event-stream",
-            background=_global_state.tokenizer_manager.create_abort_task(obj),
+            background=_global_state.orchestrator.create_abort_task(obj),
         )
     else:
         try:
-            ret = await _global_state.tokenizer_manager.generate_request(
+            ret = await _global_state.orchestrator.generate_request(
                 obj, request
             ).__anext__()
             return ret
@@ -203,7 +203,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
     try:
-        ret = await _global_state.tokenizer_manager.generate_request(
+        ret = await _global_state.orchestrator.generate_request(
             obj, request
         ).__anext__()
         return ret
@@ -215,7 +215,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
-        ret = await _global_state.tokenizer_manager.generate_request(
+        ret = await _global_state.orchestrator.generate_request(
             obj, request
         ).__anext__()
         return ret
@@ -226,7 +226,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
-    _global_state.tokenizer_manager.flush_cache()
+    _global_state.orchestrator.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -237,7 +237,7 @@ async def flush_cache():
 @app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
-    _global_state.tokenizer_manager.start_profile()
+    _global_state.orchestrator.start_profile()
     return Response(
         content="Start profiling.\n",
         status_code=200,
@@ -247,7 +247,7 @@ async def start_profile_async():
 @app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
-    _global_state.tokenizer_manager.stop_profile()
+    _global_state.orchestrator.stop_profile()
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,
@@ -257,7 +257,7 @@ async def stop_profile_async():
 @app.post("/update_weights_from_disk")
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
     """Update the weights from disk in-place without re-launching the server."""
-    success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
+    success, message = await _global_state.orchestrator.update_weights_from_disk(
         obj, request
     )
     content = {"success": success, "message": message}
@@ -278,7 +278,7 @@ async def init_weights_update_group(
     obj: InitWeightsUpdateGroupReqInput, request: Request
 ):
     """Initialize the parameter update group."""
-    success, message = await _global_state.tokenizer_manager.init_weights_update_group(
+    success, message = await _global_state.orchestrator.init_weights_update_group(
         obj, request
     )
     content = {"success": success, "message": message}
@@ -293,10 +293,8 @@ async def update_weights_from_distributed(
     obj: UpdateWeightsFromDistributedReqInput, request: Request
 ):
     """Update model parameter from distributed online."""
-    success, message = (
-        await _global_state.tokenizer_manager.update_weights_from_distributed(
-            obj, request
-        )
-    )
+    success, message = await _global_state.orchestrator.update_weights_from_distributed(
+        obj, request
+    )
     content = {"success": success, "message": message}
     if success:
@@ -309,7 +307,7 @@ async def update_weights_from_distributed(
 async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     """Get model parameter by name."""
     try:
-        ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
+        ret = await _global_state.orchestrator.get_weights_by_name(obj, request)
         if ret is None:
             return _create_error_response("Get parameter by name failed")
         else:
@@ -324,7 +322,7 @@ async def release_memory_occupation(
 ):
     """Release GPU occupation temporarily"""
     try:
-        await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
+        await _global_state.orchestrator.release_memory_occupation(obj, request)
     except Exception as e:
         return _create_error_response(e)

@@ -335,7 +333,7 @@ async def resume_memory_occupation(
 ):
     """Resume GPU occupation"""
     try:
-        await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
+        await _global_state.orchestrator.resume_memory_occupation(obj, request)
     except Exception as e:
         return _create_error_response(e)

@@ -344,7 +342,7 @@ async def resume_memory_occupation(
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
     try:
-        session_id = await _global_state.tokenizer_manager.open_session(obj, request)
+        session_id = await _global_state.orchestrator.open_session(obj, request)
         if session_id is None:
             raise Exception(
                 "Failed to open the session. Check if a session with the same id is still open."
@@ -358,7 +356,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
 async def close_session(obj: CloseSessionReqInput, request: Request):
     """Close the session"""
     try:
-        await _global_state.tokenizer_manager.close_session(obj, request)
+        await _global_state.orchestrator.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
         return _create_error_response(e)
@@ -367,7 +365,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
 @app.api_route("/configure_logging", methods=["GET", "POST"])
 async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     """Close the session"""
-    _global_state.tokenizer_manager.configure_logging(obj)
+    _global_state.orchestrator.configure_logging(obj)
     return Response(status_code=200)


@@ -398,24 +396,24 @@ async def function_call_request(obj: FunctionCallReqInput, request: Request):

 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
-    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+    return await v1_completions(_global_state.orchestrator, raw_request)


 @app.post("/v1/chat/completions")
 async def openai_v1_chat_completions(raw_request: Request):
-    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+    return await v1_chat_completions(_global_state.orchestrator, raw_request)


 @app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
-    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
+    response = await v1_embeddings(_global_state.orchestrator, raw_request)
     return response


 @app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
-    served_model_names = [_global_state.tokenizer_manager.served_model_name]
+    served_model_names = [_global_state.orchestrator.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
         model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
@@ -425,7 +423,7 @@ def available_models():
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
-        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
+        file, purpose, _global_state.orchestrator.server_args.file_storage_pth
     )


@@ -437,13 +435,13 @@ async def delete_file(file_id: str):

 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
-    return await v1_batches(_global_state.tokenizer_manager, raw_request)
+    return await v1_batches(_global_state.orchestrator, raw_request)


 @app.post("/v1/batches/{batch_id}/cancel")
 async def cancel_batches(batch_id: str):
     # https://platform.openai.com/docs/api-reference/batch/cancel
-    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
+    return await v1_cancel_batch(_global_state.orchestrator, batch_id)


 @app.get("/v1/batches/{batch_id}")
@@ -492,18 +490,18 @@ def launch_server(

     - HTTP server: A FastAPI server that routes requests to the engine.
     - The engine consists of three components:
-        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+        1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
         3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server, Engine, and TokenizerManager both run in the main process.
+    1. The HTTP server, Engine, and StdOrchestrator both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
-    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args)
     set_global_state(
         _GlobalState(
-            tokenizer_manager=tokenizer_manager,
+            orchestrator=orchestrator,
             scheduler_info=scheduler_info,
         )
     )
@@ -523,7 +521,7 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            _global_state.orchestrator.image_token_id,
         ),
     )
     t.start()
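On the HTTP side, every handler above now goes through `_global_state.orchestrator`, so the wire-level API is unchanged for clients. A small client sketch against two of the endpoints touched above; the host and port are assumptions (30000 is a common default in SGLang examples):

```python
import requests

base = "http://localhost:30000"  # assumed server address

# /get_model_info returns the fields assembled in the hunk above.
info = requests.get(f"{base}/get_model_info").json()
print(info["model_path"], info["is_generation"])

# /generate is served by _global_state.orchestrator.generate_request(...).
out = requests.post(
    f"{base}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8}},
).json()
print(out["text"])
```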
@@ -241,7 +241,7 @@ class LlavaImageProcessor(BaseImageProcessor):

             return pixel_values, image_hash, image.size
         except Exception:
-            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+            logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback())

     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
@@ -491,7 +491,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):

             return pixel_values, image_hash, image.size, image_grid_thws
         except Exception:
-            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+            logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback())

     async def _process_single_image(self, image_data: Union[bytes, str]):
         if self.executor is not None:
@@ -13,7 +13,7 @@
 # ==============================================================================
 """
 The definition of objects transfered between different
-processes (TokenizerManager, DetokenizerManager, Controller).
+processes (StdOrchestrator, DetokenizerManager, Controller).
 """

 import uuid
@@ -173,7 +173,7 @@ class Scheduler:
         )

         if server_args.skip_tokenizer_init:
-            # Directly send to the TokenizerManager
+            # Directly send to the StdOrchestrator
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
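The scheduler hunk above only renames a comment; the transport it describes stays the same: one process pushes Python objects over a ZMQ socket bound to an IPC endpoint and the peer pulls them. A self-contained sketch of that PUSH/PULL pattern, run in a single process purely for illustration and with a made-up endpoint name (the real name comes from `port_args`):

```python
import zmq

ctx = zmq.Context.instance()
endpoint = "ipc:///tmp/sglang_example_ipc"  # hypothetical endpoint, not the real port_args value

# Receiving side (in SGLang this lives in another process).
receiver = ctx.socket(zmq.PULL)
receiver.bind(endpoint)

# Sending side, mirroring get_zmq_socket(context, zmq.PUSH, ..., False).
sender = ctx.socket(zmq.PUSH)
sender.connect(endpoint)

sender.send_pyobj({"rid": "req-0", "output_ids": [1, 2, 3]})
print(receiver.recv_pyobj())
```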
@@ -117,7 +117,7 @@ def create_streaming_error_response(
     return json_str


-def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path):
+def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_path):
     global chat_template_name

     logger.info(
@@ -133,9 +133,7 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
     if chat_template_arg.endswith(".jinja"):
         with open(chat_template_arg, "r") as f:
             chat_template = "".join(f.readlines()).strip("\n")
-        tokenizer_manager.tokenizer.chat_template = chat_template.replace(
-            "\\n", "\n"
-        )
+        orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n")
         chat_template_name = None
     else:
         assert chat_template_arg.endswith(
@@ -231,7 +229,7 @@ async def v1_delete_file(file_id: str):
     return FileDeleteResponse(id=file_id, deleted=True)


-async def v1_batches(tokenizer_manager, raw_request: Request):
+async def v1_batches(orchestrator, raw_request: Request):
     try:
         body = await raw_request.json()

@@ -252,7 +250,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request):
         batch_storage[batch_id] = batch_response

         # Start processing the batch asynchronously
-        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+        asyncio.create_task(process_batch(orchestrator, batch_id, batch_request))

         # Return the initial batch_response
         return batch_response
@@ -263,7 +261,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request):
         return {"error": str(e)}


-async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest):
     try:
         # Update the batch status to "in_progress"
         batch_storage[batch_id].status = "in_progress"
@@ -306,7 +304,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe

         if end_point == "/v1/chat/completions":
             adapted_request, request = v1_chat_generate_request(
-                all_requests, tokenizer_manager, request_ids=request_ids
+                all_requests, orchestrator, request_ids=request_ids
             )
         elif end_point == "/v1/completions":
             adapted_request, request = v1_generate_request(
@@ -314,7 +312,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             )

         try:
-            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            ret = await orchestrator.generate_request(adapted_request).__anext__()
             if not isinstance(ret, list):
                 ret = [ret]
             if end_point == "/v1/chat/completions":
@@ -322,12 +320,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                     request,
                     ret,
                     to_file=True,
-                    cache_report=tokenizer_manager.server_args.enable_cache_report,
-                    tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                    cache_report=orchestrator.server_args.enable_cache_report,
+                    tool_call_parser=orchestrator.server_args.tool_call_parser,
                 )
             else:
                 responses = v1_generate_response(
-                    request, ret, tokenizer_manager, to_file=True
+                    request, ret, orchestrator, to_file=True
                 )

         except Exception as e:
@@ -399,7 +397,7 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response


-async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+async def v1_cancel_batch(orchestrator, batch_id: str):
     # Retrieve the batch job from the in-memory storage
     batch_response = batch_storage.get(batch_id)
     if batch_response is None:
@@ -410,7 +408,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str):
         # Start cancelling the batch asynchronously
         asyncio.create_task(
             cancel_batch(
-                tokenizer_manager=tokenizer_manager,
+                orchestrator=orchestrator,
                 batch_id=batch_id,
                 input_file_id=batch_response.input_file_id,
             )
@@ -427,7 +425,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str):
         )


-async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+async def cancel_batch(orchestrator, batch_id: str, input_file_id: str):
     try:
         # Update the batch status to "cancelling"
         batch_storage[batch_id].status = "cancelling"
@@ -451,7 +449,7 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):

         # Cancel requests by request_ids
         for rid in request_ids:
-            tokenizer_manager.abort_request(rid=rid)
+            orchestrator.abort_request(rid=rid)

         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "cancelled"
@@ -579,7 +577,7 @@ def v1_generate_request(
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


-def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
+def v1_generate_response(request, ret, orchestrator, to_file=False):
     choices = []
     echo = False

@@ -591,15 +589,13 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
             # for the case of multiple token ids prompts
             prompts = [
-                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                orchestrator.tokenizer.decode(prompt, skip_special_tokens=True)
                 for prompt in request.prompt
             ]
         elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
             # for the case of single token ids prompt
             prompts = [
-                tokenizer_manager.tokenizer.decode(
-                    request.prompt, skip_special_tokens=True
-                )
+                orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True)
             ]
         else:
             # for the case of single str prompt
@@ -709,7 +705,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     return response


-async def v1_completions(tokenizer_manager, raw_request: Request):
+async def v1_completions(orchestrator, raw_request: Request):
     request_json = await raw_request.json()
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)
@@ -722,7 +718,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             try:
-                async for content in tokenizer_manager.generate_request(
+                async for content in orchestrator.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
@@ -745,14 +741,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                     prompts = request.prompt[index // request.n]
                                 elif isinstance(request.prompt[0], int):
                                     # for the case of single token ids prompt
-                                    prompts = tokenizer_manager.tokenizer.decode(
+                                    prompts = orchestrator.tokenizer.decode(
                                         request.prompt, skip_special_tokens=True
                                     )
                                 elif isinstance(request.prompt[0], list) and isinstance(
                                     request.prompt[0][0], int
                                 ):
                                     # for the case of multiple token ids prompts
-                                    prompts = tokenizer_manager.tokenizer.decode(
+                                    prompts = orchestrator.tokenizer.decode(
                                         request.prompt[index // request.n],
                                         skip_special_tokens=True,
                                     )
@@ -847,12 +843,12 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return StreamingResponse(
             generate_stream_resp(),
             media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(adapted_request),
+            background=orchestrator.create_abort_task(adapted_request),
         )

     # Non-streaming response.
     try:
-        ret = await tokenizer_manager.generate_request(
+        ret = await orchestrator.generate_request(
             adapted_request, raw_request
         ).__anext__()
     except ValueError as e:
@@ -861,13 +857,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_generate_response(request, ret, tokenizer_manager)
+    response = v1_generate_response(request, ret, orchestrator)
     return response


 def v1_chat_generate_request(
     all_requests: List[ChatCompletionRequest],
-    tokenizer_manager,
+    orchestrator,
     request_ids: List[str] = None,
 ):
     input_ids = []
@@ -922,7 +918,7 @@ def v1_chat_generate_request(
                     assistant_prefix = None

                 try:
-                    prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    prompt_ids = orchestrator.tokenizer.apply_chat_template(
                         openai_compatible_messages,
                         tokenize=True,
                         add_generation_prompt=True,
@@ -933,7 +929,7 @@ def v1_chat_generate_request(
                     # has a different tools input format that is not compatiable
                     # with openAI's apply_chat_template tool_call format, like Mistral.
                     tools = [t if "function" in t else {"function": t} for t in tools]
-                    prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    prompt_ids = orchestrator.tokenizer.apply_chat_template(
                         openai_compatible_messages,
                         tokenize=True,
                         add_generation_prompt=True,
@@ -941,11 +937,8 @@ def v1_chat_generate_request(
                     )

                 if assistant_prefix:
-                    encoded = tokenizer_manager.tokenizer.encode(assistant_prefix)
-                    if (
-                        encoded
-                        and encoded[0] == tokenizer_manager.tokenizer.bos_token_id
-                    ):
+                    encoded = orchestrator.tokenizer.encode(assistant_prefix)
+                    if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id:
                         encoded = encoded[1:]
                     prompt_ids += encoded
                 stop = request.stop
@@ -962,7 +955,7 @@ def v1_chat_generate_request(
                         stop.append(request.stop)
                     else:
                         stop.extend(request.stop)
-                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
+                prompt_ids = orchestrator.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt_ids = request.messages
@@ -1201,10 +1194,10 @@ def v1_chat_generate_response(
     return response


-async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+async def v1_chat_completions(orchestrator, raw_request: Request):
     request_json = await raw_request.json()
     all_requests = [ChatCompletionRequest(**request_json)]
-    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
+    adapted_request, request = v1_chat_generate_request(all_requests, orchestrator)

     if adapted_request.stream:
         parser_dict = {}
@@ -1216,7 +1209,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             prompt_tokens = {}
             completion_tokens = {}
             try:
-                async for content in tokenizer_manager.generate_request(
+                async for content in orchestrator.generate_request(
                     adapted_request, raw_request
                 ):
                     index = content.get("index", 0)
@@ -1306,7 +1299,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     if index not in parser_dict:
                         parser_dict[index] = FunctionCallParser(
                             tools=request.tools,
-                            tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                            tool_call_parser=orchestrator.server_args.tool_call_parser,
                         )
                     parser = parser_dict[index]

@@ -1438,12 +1431,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         return StreamingResponse(
             generate_stream_resp(),
             media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(adapted_request),
+            background=orchestrator.create_abort_task(adapted_request),
         )

     # Non-streaming response.
     try:
-        ret = await tokenizer_manager.generate_request(
+        ret = await orchestrator.generate_request(
             adapted_request, raw_request
         ).__anext__()
     except ValueError as e:
@@ -1454,14 +1447,14 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     response = v1_chat_generate_response(
         request,
         ret,
-        cache_report=tokenizer_manager.server_args.enable_cache_report,
-        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+        cache_report=orchestrator.server_args.enable_cache_report,
+        tool_call_parser=orchestrator.server_args.tool_call_parser,
     )

     return response


-def v1_embedding_request(all_requests, tokenizer_manager):
+def v1_embedding_request(all_requests, orchestrator):
     prompts = []
     sampling_params_list = []
     first_prompt_type = type(all_requests[0].input)
@@ -1516,13 +1509,13 @@ def v1_embedding_response(ret, model_path, to_file=False):
     )


-async def v1_embeddings(tokenizer_manager, raw_request: Request):
+async def v1_embeddings(orchestrator, raw_request: Request):
     request_json = await raw_request.json()
     all_requests = [EmbeddingRequest(**request_json)]
-    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+    adapted_request, request = v1_embedding_request(all_requests, orchestrator)

     try:
-        ret = await tokenizer_manager.generate_request(
+        ret = await orchestrator.generate_request(
             adapted_request, raw_request
         ).__anext__()
     except ValueError as e:
@@ -1531,7 +1524,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_embedding_response(ret, tokenizer_manager.model_path)
+    response = v1_embedding_response(ret, orchestrator.model_path)

     return response

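All of the OpenAI-compatible helpers above (`v1_completions`, `v1_chat_completions`, `v1_embeddings`, the batch endpoints) now take the orchestrator as their first argument, but the request and response schema they implement is unchanged, so existing OpenAI clients keep working. A minimal client sketch; the server address, placeholder API key, and model name are assumptions for illustration:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Name one prime number."}],
    max_tokens=8,
)
print(resp.choices[0].message.content)
```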
New files added by this commit (both empty package markers):

  python/sglang/srt/orchestration/__init__.py
  python/sglang/srt/orchestration/std/__init__.py
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""TokenizerManager is a process that tokenizes the text."""

 import asyncio
 import logging
@@ -66,8 +65,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 logger = logging.getLogger(__name__)


-class TokenizerManager:
-    """TokenizerManager is a process that tokenizes the text."""
+class StdOrchestrator:
+    """StdOrchestrator is the primary entrypoint of orchestration.std package"""

     def __init__(
         self,
@@ -439,20 +438,20 @@ async def print_exception_wrapper(func):
         await func()
     except Exception:
         traceback = get_exception_traceback()
-        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        logger.error(f"StdOrchestrator hit an exception: {traceback}")
         kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(1)


 class SignalHandler:
-    def __init__(self, tokenizer_manager):
-        self.tokenizer_manager = tokenizer_manager
+    def __init__(self, orchestrator):
+        self.orchestrator = orchestrator

     def signal_handler(self, signum=None, frame=None):
         logger.warning(
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
-        self.tokenizer_manager.gracefully_exit = True
+        self.orchestrator.gracefully_exit = True


 T = TypeVar("T")
@@ -1039,7 +1039,7 @@ class PortArgs:
         if dp_rank is None:
             scheduler_input_port = (
                 port_base + 2
-            ) # TokenizerManager to DataParallelController
+            ) # StdOrchestrator to DataParallelController
         else:
             scheduler_input_port = port_base + 2 + 1 + dp_rank
