perf: optimize qwen-vl with symm mem allreduce (#11381)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -1766,7 +1766,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
if isinstance(self.seq_lens_cpu, torch.Tensor):
|
||||
# CPU tensor
|
||||
self.seq_lens_sum = int(self.seq_lens_cpu.sum().item())
|
||||
else:
|
||||
self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum())
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
|
||||
Reference in New Issue
Block a user