Fix sync() when fork(1) (#412)

This commit is contained in:
Liangsheng Yin
2024-05-08 15:15:18 +08:00
committed by GitHub
parent 4a1c6ae2ce
commit d5de20a3ee
3 changed files with 20 additions and 11 deletions

View File

@@ -1,6 +1,7 @@
"""The interpreter that executes SGL programs"""
import asyncio
import contextvars
import multiprocessing
import queue
import threading
@@ -9,7 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
import contextvars
import tqdm
from sglang.global_config import global_config
@@ -222,7 +222,10 @@ class StreamExecutor:
def _run_worker_in_context():
self._thread_worker_func()
self.worker = threading.Thread(target=contextvars.copy_context().run, args=(_run_worker_in_context, ))
self.worker = threading.Thread(
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
)
self.worker.start()
# For streaming
@@ -265,12 +268,11 @@ class StreamExecutor:
self,
number: int,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
if number > 1 or copy:
if number > 1:
self.submit(SglCommitLazy())
self.sync()
self.sync()
number = int(number)
exes = [
@@ -656,16 +658,15 @@ class ProgramState:
self,
number: int = 1,
position_ids_offset: Optional[List[int]] = None,
copy: bool = False,
):
stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group
@contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset, True)
state_group = self.fork(1, position_ids_offset)
try:
yield state_group[0]
finally: