Fix ngram spec with page size > 1 (#11135)
This commit is contained in:
@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens_tensor,
|
||||
prefix_lens_cpu_tensor,
|
||||
seq_lens_tensor,
|
||||
seq_lens_cpu_tensor,
|
||||
seq_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu
|
||||
self.orig_seq_lens = orig_seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
|
||||
Reference in New Issue
Block a user