Fix sync() when fork(1) (#412)
This commit is contained in:
@@ -232,9 +232,15 @@ register_chat_template(
|
|||||||
name="c4ai-command-r",
|
name="c4ai-command-r",
|
||||||
default_system_prompt=None,
|
default_system_prompt=None,
|
||||||
role_prefix_and_suffix={
|
role_prefix_and_suffix={
|
||||||
"system": ("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
"system": (
|
||||||
|
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
|
||||||
|
"<|END_OF_TURN_TOKEN|>",
|
||||||
|
),
|
||||||
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
||||||
"assistant": ("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
"assistant": (
|
||||||
|
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
||||||
|
"<|END_OF_TURN_TOKEN|>",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
style=ChatTemplateStyle.PLAIN,
|
style=ChatTemplateStyle.PLAIN,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""The interpreter that executes SGL programs"""
|
"""The interpreter that executes SGL programs"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import contextvars
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -222,7 +222,10 @@ class StreamExecutor:
|
|||||||
|
|
||||||
def _run_worker_in_context():
|
def _run_worker_in_context():
|
||||||
self._thread_worker_func()
|
self._thread_worker_func()
|
||||||
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
|
|
||||||
|
self.worker = threading.Thread(
|
||||||
|
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
|
||||||
|
)
|
||||||
self.worker.start()
|
self.worker.start()
|
||||||
|
|
||||||
# For streaming
|
# For streaming
|
||||||
@@ -265,12 +268,11 @@ class StreamExecutor:
|
|||||||
self,
|
self,
|
||||||
number: int,
|
number: int,
|
||||||
position_ids_offset: Optional[List[int]] = None,
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
copy: bool = False,
|
|
||||||
):
|
):
|
||||||
if number > 1 or copy:
|
if number > 1:
|
||||||
self.submit(SglCommitLazy())
|
self.submit(SglCommitLazy())
|
||||||
self.sync()
|
|
||||||
|
|
||||||
|
self.sync()
|
||||||
number = int(number)
|
number = int(number)
|
||||||
|
|
||||||
exes = [
|
exes = [
|
||||||
@@ -656,16 +658,15 @@ class ProgramState:
|
|||||||
self,
|
self,
|
||||||
number: int = 1,
|
number: int = 1,
|
||||||
position_ids_offset: Optional[List[int]] = None,
|
position_ids_offset: Optional[List[int]] = None,
|
||||||
copy: bool = False,
|
|
||||||
):
|
):
|
||||||
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
|
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
||||||
states = [ProgramState(x) for x in stream_executors]
|
states = [ProgramState(x) for x in stream_executors]
|
||||||
state_group = ProgramStateGroup(states, self)
|
state_group = ProgramStateGroup(states, self)
|
||||||
return state_group
|
return state_group
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
||||||
state_group = self.fork(1, position_ids_offset, True)
|
state_group = self.fork(1, position_ids_offset)
|
||||||
try:
|
try:
|
||||||
yield state_group[0]
|
yield state_group[0]
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ class DetokenizerManager:
|
|||||||
output_strs = self.tokenizer.batch_decode(
|
output_strs = self.tokenizer.batch_decode(
|
||||||
output_tokens,
|
output_tokens,
|
||||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
|
||||||
|
0
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trim stop str
|
# Trim stop str
|
||||||
|
|||||||
Reference in New Issue
Block a user