Fix sync() when fork(1) (#412)
This commit is contained in:
@@ -232,9 +232,15 @@ register_chat_template(
|
||||
name="c4ai-command-r",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
||||
"system": (
|
||||
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
|
||||
"<|END_OF_TURN_TOKEN|>",
|
||||
),
|
||||
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
||||
"assistant": ("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
|
||||
"assistant": (
|
||||
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
||||
"<|END_OF_TURN_TOKEN|>",
|
||||
),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""The interpreter that executes SGL programs"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import contextvars
|
||||
import tqdm
|
||||
|
||||
from sglang.global_config import global_config
|
||||
@@ -222,7 +222,10 @@ class StreamExecutor:
|
||||
|
||||
def _run_worker_in_context():
|
||||
self._thread_worker_func()
|
||||
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
|
||||
|
||||
self.worker = threading.Thread(
|
||||
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
|
||||
)
|
||||
self.worker.start()
|
||||
|
||||
# For streaming
|
||||
@@ -265,12 +268,11 @@ class StreamExecutor:
|
||||
self,
|
||||
number: int,
|
||||
position_ids_offset: Optional[List[int]] = None,
|
||||
copy: bool = False,
|
||||
):
|
||||
if number > 1 or copy:
|
||||
if number > 1:
|
||||
self.submit(SglCommitLazy())
|
||||
self.sync()
|
||||
|
||||
self.sync()
|
||||
number = int(number)
|
||||
|
||||
exes = [
|
||||
@@ -656,16 +658,15 @@ class ProgramState:
|
||||
self,
|
||||
number: int = 1,
|
||||
position_ids_offset: Optional[List[int]] = None,
|
||||
copy: bool = False,
|
||||
):
|
||||
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
|
||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
||||
states = [ProgramState(x) for x in stream_executors]
|
||||
state_group = ProgramStateGroup(states, self)
|
||||
return state_group
|
||||
|
||||
@contextmanager
|
||||
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
||||
state_group = self.fork(1, position_ids_offset, True)
|
||||
state_group = self.fork(1, position_ids_offset)
|
||||
try:
|
||||
yield state_group[0]
|
||||
finally:
|
||||
|
||||
@@ -42,7 +42,9 @@ class DetokenizerManager:
|
||||
output_strs = self.tokenizer.batch_decode(
|
||||
output_tokens,
|
||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
|
||||
0
|
||||
],
|
||||
)
|
||||
|
||||
# Trim stop str
|
||||
|
||||
Reference in New Issue
Block a user