Fix ngram spec with page size > 1 (#11135)

This commit is contained in:
Liangsheng Yin
2025-10-02 12:34:23 +08:00
committed by GitHub
parent 0b2aa8a70c
commit 25e7dbe8af
5 changed files with 32 additions and 10 deletions

View File

@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
-seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
+seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor,
prefix_lens_cpu_tensor,
seq_lens_tensor,
-seq_lens_cpu_tensor,
+seq_lens_cpu,
last_loc,
extend_num_tokens,
)
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
-self.seq_lens_cpu = seq_lens_cpu_tensor
+self.seq_lens_cpu = seq_lens_cpu
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (