Fix flush cache (#627)

This commit is contained in:
Lianmin Zheng
2024-07-15 19:56:55 -07:00
parent 56f5fc4ab5
commit 41d1f67704
5 changed files with 30 additions and 20 deletions

View File

@@ -284,23 +284,26 @@ def main(server_args, bench_args):
else:
work_func = latency_test
workers = []
for tp_rank in range(server_args.tp_size):
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
bench_args,
tp_rank,
),
)
proc.start()
workers.append(proc)
if server_args.tp_size == 1:
work_func(server_args, bench_args, 0)
else:
workers = []
for tp_rank in range(server_args.tp_size):
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
bench_args,
tp_rank,
),
)
proc.start()
workers.append(proc)
for proc in workers:
proc.join()
for proc in workers:
proc.join()
proc.terminate()
proc.terminate()
if __name__ == "__main__":

View File

@@ -96,6 +96,7 @@ class ControllerSingle:
def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
# Parse args
self.server_args = server_args
self.tp_procs = []
# Init communication
context = zmq.Context(2)

View File

@@ -98,6 +98,8 @@ class TokenToKVPool:
self.can_use_mem_size += len(free_index)
def clear(self):
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.mem_state.fill_(True)
self.can_use_mem_size = self.size