Fix bugs of logprobs_nums (#1548)
This commit is contained in:
@@ -748,6 +748,8 @@ class ScheduleBatch:
|
|||||||
self.top_logprobs_nums = [
|
self.top_logprobs_nums = [
|
||||||
self.top_logprobs_nums[i] for i in unfinished_indices
|
self.top_logprobs_nums[i] for i in unfinished_indices
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
self.top_logprobs_nums = None
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
|
||||||
self.sampling_info.filter_batch(unfinished_indices, new_indices)
|
self.sampling_info.filter_batch(unfinished_indices, new_indices)
|
||||||
@@ -758,13 +760,11 @@ class ScheduleBatch:
|
|||||||
# needs to be called with pre-merged Batch.reqs.
|
# needs to be called with pre-merged Batch.reqs.
|
||||||
self.sampling_info.merge_batch(other.sampling_info)
|
self.sampling_info.merge_batch(other.sampling_info)
|
||||||
|
|
||||||
self.reqs.extend(other.reqs)
|
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
|
||||||
if self.return_logprob and other.return_logprob:
|
if self.return_logprob and other.return_logprob:
|
||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
elif self.return_logprob:
|
elif self.return_logprob:
|
||||||
@@ -772,6 +772,8 @@ class ScheduleBatch:
|
|||||||
elif other.return_logprob:
|
elif other.return_logprob:
|
||||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
|
self.reqs.extend(other.reqs)
|
||||||
|
self.return_logprob = self.return_logprob or other.return_logprob
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
|
|||||||
Reference in New Issue
Block a user