Compat with latest VLLM 0.4.2 main + fork.number rename + Flashinfer 0.0.4 (#380)

Co-authored-by: ZX <zx@lbx.dev>
Co-authored-by: ZhouXingg <165115237+ZhouXingg@users.noreply.github.com>
This commit is contained in:
Qubitium
2024-05-12 07:37:49 +08:00
committed by GitHub
parent a511a2d089
commit 33b242df30
20 changed files with 611 additions and 187 deletions

View File

@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
########### Public API ###########
##################################
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
assert (size >= 1)
if self.only_trace_prefix:
raise StopTracing()
fork_node = SglFork(number)
fork_node = SglFork(size)
fork_node.prev_node = self.last_node
states = [
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
for _ in range(number)
for _ in range(size)
]
for i in range(number):
for i in range(size):
node = SglGetForkItem(i)
node.prev_node = fork_node
states[i].last_node = node