perf: optimize qwen-vl with symm mem allreduce (#11381)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-10-10 22:24:45 +08:00
committed by GitHub
parent a1a20b4c7c
commit 3b9d97f335
5 changed files with 82 additions and 17 deletions

View File

@@ -1766,7 +1766,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
if isinstance(self.seq_lens_cpu, torch.Tensor):
# CPU tensor
self.seq_lens_sum = int(self.seq_lens_cpu.sum().item())
else:
self.seq_lens_sum = int(np.asarray(self.seq_lens_cpu).sum())
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob: