diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index 00691ca50..ef3d9fb1f 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -256,9 +256,15 @@ class StreamExecutor:
         ret = self.meta_info.get(name, None)
         return ret
 
-    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
-        self.submit(SglCommitLazy())
-        self.sync()
+    def fork(
+        self,
+        number: int,
+        position_ids_offset: Optional[List[int]] = None,
+        copy: bool = False,
+    ):
+        if number > 1 or copy:
+            self.submit(SglCommitLazy())
+            self.sync()
 
         number = int(number)
 
@@ -641,15 +647,20 @@ class ProgramState:
         yield
         self.stream_executor.submit(SglVarScopeEnd(name))
 
-    def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
-        stream_executors = self.stream_executor.fork(number, position_ids_offset)
+    def fork(
+        self,
+        number: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+        copy: bool = False,
+    ):
+        stream_executors = self.stream_executor.fork(number, position_ids_offset, copy)
         states = [ProgramState(x) for x in stream_executors]
         state_group = ProgramStateGroup(states, self)
         return state_group
 
     @contextmanager
     def copy(self, position_ids_offset: Optional[List[int]] = None):
-        state_group = self.fork(1, position_ids_offset)
+        state_group = self.fork(1, position_ids_offset, True)
         try:
             yield state_group[0]
         finally:
diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py
index 7bb8a4b2a..c7bd9cb6b 100644
--- a/python/sglang/srt/managers/router/radix_cache.py
+++ b/python/sglang/srt/managers/router/radix_cache.py
@@ -179,7 +179,9 @@ class RadixCache:
 
     def _print_helper(self, node, indent):
         for _, child in node.children.items():
-            print(" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}")
+            print(
+                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
+            )
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):