Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)

Co-authored-by: ZX <zx@lbx.dev>
Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
Qubitium
2024-05-12 07:37:49 +08:00
committed by GitHub
parent a511a2d089
commit 33b242df30
20 changed files with 611 additions and 187 deletions

View File

@@ -266,14 +266,14 @@ class StreamExecutor:
def fork(
self,
number: int,
size: int = 1,
position_ids_offset: Optional[List[int]] = None,
):
if number > 1:
if size > 1:
self.submit(SglCommitLazy())
self.sync()
number = int(number)
size = int(size)
exes = [
StreamExecutor(
@@ -283,9 +283,9 @@ class StreamExecutor:
self.chat_template,
self.stream,
)
for _ in range(number)
for _ in range(size)
]
for i in range(number):
for i in range(size):
exes[i].variables = dict(self.variables)
exes[i].text_ = str(self.text_)
exes[i].messages_ = list(self.messages_)
@@ -656,10 +656,10 @@ class ProgramState:
def fork(
self,
number: int = 1,
size: int = 1,
position_ids_offset: Optional[List[int]] = None,
):
stream_executors = self.stream_executor.fork(number, position_ids_offset)
stream_executors = self.stream_executor.fork(size, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group