Fix memory leak during abort (#1674)
This commit is contained in:
@@ -17,7 +17,7 @@ import json
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|||||||
@@ -775,7 +775,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
self.stream_output(batch)
|
self.stream_output(batch.reqs)
|
||||||
|
|
||||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||||
logits_output, next_token_ids = result
|
logits_output, next_token_ids = result
|
||||||
@@ -815,7 +815,7 @@ class Scheduler:
|
|||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
self.stream_output(batch)
|
self.stream_output(batch.reqs)
|
||||||
|
|
||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||||
@@ -894,7 +894,7 @@ class Scheduler:
|
|||||||
|
|
||||||
return num_input_logprobs
|
return num_input_logprobs
|
||||||
|
|
||||||
def stream_output(self, batch: ScheduleBatch):
|
def stream_output(self, reqs: List[Req]):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
output_meta_info = []
|
output_meta_info = []
|
||||||
output_finished_reason: List[BaseFinishReason] = []
|
output_finished_reason: List[BaseFinishReason] = []
|
||||||
@@ -911,7 +911,7 @@ class Scheduler:
|
|||||||
|
|
||||||
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
||||||
|
|
||||||
for req in batch.reqs:
|
for req in reqs:
|
||||||
if req.finished() or (
|
if req.finished() or (
|
||||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||||
):
|
):
|
||||||
@@ -1025,8 +1025,9 @@ class Scheduler:
|
|||||||
# Delete requests in the running batch
|
# Delete requests in the running batch
|
||||||
if self.running_batch:
|
if self.running_batch:
|
||||||
for req in self.running_batch.reqs:
|
for req in self.running_batch.reqs:
|
||||||
if req.rid == recv_req.rid:
|
if req.rid == recv_req.rid and not req.finished():
|
||||||
req.finished_reason = FINISH_ABORT()
|
req.finished_reason = FINISH_ABORT()
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
break
|
break
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||||
|
|||||||
Reference in New Issue
Block a user