Fix memory leak during abort (#1674)

This commit is contained in:
Lianmin Zheng
2024-10-15 08:15:08 -07:00
committed by GitHub
parent 175afed370
commit f1088e0fc8
2 changed files with 7 additions and 6 deletions

View File

@@ -17,7 +17,7 @@ import json
import multiprocessing import multiprocessing
import os import os
import time import time
from typing import Optional, Tuple from typing import Tuple
import numpy as np import numpy as np
import requests import requests

View File

@@ -775,7 +775,7 @@ class Scheduler:
else: else:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
self.stream_output(batch) self.stream_output(batch.reqs)
def process_batch_result_decode(self, batch: ScheduleBatch, result): def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result logits_output, next_token_ids = result
@@ -815,7 +815,7 @@ class Scheduler:
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.stream_output(batch) self.stream_output(batch.reqs)
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -894,7 +894,7 @@ class Scheduler:
return num_input_logprobs return num_input_logprobs
def stream_output(self, batch: ScheduleBatch): def stream_output(self, reqs: List[Req]):
output_rids = [] output_rids = []
output_meta_info = [] output_meta_info = []
output_finished_reason: List[BaseFinishReason] = [] output_finished_reason: List[BaseFinishReason] = []
@@ -911,7 +911,7 @@ class Scheduler:
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0 is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
for req in batch.reqs: for req in reqs:
if req.finished() or ( if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1) req.stream and (is_stream_iter or len(req.output_ids) == 1)
): ):
@@ -1025,8 +1025,9 @@ class Scheduler:
# Delete requests in the running batch # Delete requests in the running batch
if self.running_batch: if self.running_batch:
for req in self.running_batch.reqs: for req in self.running_batch.reqs:
if req.rid == recv_req.rid: if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT() req.finished_reason = FINISH_ABORT()
self.tree_cache.cache_finished_req(req)
break break
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights(self, recv_req: UpdateWeightReqInput):