Fix sync() when fork(1) (#412)

This commit is contained in:
Liangsheng Yin
2024-05-08 15:15:18 +08:00
committed by GitHub
parent 4a1c6ae2ce
commit d5de20a3ee
3 changed files with 20 additions and 11 deletions

View File

@@ -232,9 +232,15 @@ register_chat_template(
name="c4ai-command-r", name="c4ai-command-r",
default_system_prompt=None, default_system_prompt=None,
role_prefix_and_suffix={ role_prefix_and_suffix={
"system": ("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "system": (
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"assistant": ("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "<|END_OF_TURN_TOKEN|>"), "assistant": (
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
}, },
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
) )

View File

@@ -1,6 +1,7 @@
"""The interpreter that executes SGL programs""" """The interpreter that executes SGL programs"""
import asyncio import asyncio
import contextvars
import multiprocessing import multiprocessing
import queue import queue
import threading import threading
@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import contextvars
import tqdm import tqdm
from sglang.global_config import global_config from sglang.global_config import global_config
@@ -222,7 +222,10 @@ class StreamExecutor:
def _run_worker_in_context(): def _run_worker_in_context():
self._thread_worker_func() self._thread_worker_func()
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
self.worker = threading.Thread(
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
)
self.worker.start() self.worker.start()
# For streaming # For streaming
@@ -265,12 +268,11 @@ class StreamExecutor:
self, self,
number: int, number: int,
position_ids_offset: Optional[List[int]] = None, position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
): ):
if number > 1 or copy: if number > 1:
self.submit(SglCommitLazy()) self.submit(SglCommitLazy())
self.sync()
self.sync()
number = int(number) number = int(number)
exes = [ exes = [
@@ -656,16 +658,15 @@ class ProgramState:
self, self,
number: int = 1, number: int = 1,
position_ids_offset: Optional[List[int]] = None, position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
): ):
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy) stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors] states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self) state_group = ProgramStateGroup(states, self)
return state_group return state_group
@contextmanager @contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None): def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset, True) state_group = self.fork(1, position_ids_offset)
try: try:
yield state_group[0] yield state_group[0]
finally: finally:

View File

@@ -42,7 +42,9 @@ class DetokenizerManager:
output_strs = self.tokenizer.batch_decode( output_strs = self.tokenizer.batch_decode(
output_tokens, output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0], skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0
],
) )
# Trim stop str # Trim stop str