Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)
Co-authored-by: ZX <zx@lbx.dev> Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
@@ -266,14 +266,14 @@ class StreamExecutor:
|
||||
|
||||
def fork(
|
||||
self,
|
||||
number: int,
|
||||
size: int = 1,
|
||||
position_ids_offset: Optional[List[int]] = None,
|
||||
):
|
||||
if number > 1:
|
||||
if size > 1:
|
||||
self.submit(SglCommitLazy())
|
||||
|
||||
self.sync()
|
||||
number = int(number)
|
||||
size = int(size)
|
||||
|
||||
exes = [
|
||||
StreamExecutor(
|
||||
@@ -283,9 +283,9 @@ class StreamExecutor:
|
||||
self.chat_template,
|
||||
self.stream,
|
||||
)
|
||||
for _ in range(number)
|
||||
for _ in range(size)
|
||||
]
|
||||
for i in range(number):
|
||||
for i in range(size):
|
||||
exes[i].variables = dict(self.variables)
|
||||
exes[i].text_ = str(self.text_)
|
||||
exes[i].messages_ = list(self.messages_)
|
||||
@@ -656,10 +656,10 @@ class ProgramState:
|
||||
|
||||
def fork(
|
||||
self,
|
||||
number: int = 1,
|
||||
size: int = 1,
|
||||
position_ids_offset: Optional[List[int]] = None,
|
||||
):
|
||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
||||
stream_executors = self.stream_executor.fork(size, position_ids_offset)
|
||||
states = [ProgramState(x) for x in stream_executors]
|
||||
state_group = ProgramStateGroup(states, self)
|
||||
return state_group
|
||||
|
||||
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
|
||||
########### Public API ###########
|
||||
##################################
|
||||
|
||||
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
|
||||
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||
assert (size >= 1)
|
||||
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
|
||||
fork_node = SglFork(number)
|
||||
fork_node = SglFork(size)
|
||||
fork_node.prev_node = self.last_node
|
||||
|
||||
states = [
|
||||
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
|
||||
for _ in range(number)
|
||||
for _ in range(size)
|
||||
]
|
||||
|
||||
for i in range(number):
|
||||
for i in range(size):
|
||||
node = SglGetForkItem(i)
|
||||
node.prev_node = fork_node
|
||||
states[i].last_node = node
|
||||
|
||||
Reference in New Issue
Block a user