Fix ngram spec with page size > 1 (#11135)
This commit is contained in:
@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens_tensor,
|
||||
prefix_lens_cpu_tensor,
|
||||
seq_lens_tensor,
|
||||
seq_lens_cpu_tensor,
|
||||
seq_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu
|
||||
self.orig_seq_lens = orig_seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
|
||||
@@ -1087,7 +1087,10 @@ class ServerArgs:
|
||||
and self.attention_backend != "flashinfer"
|
||||
):
|
||||
raise ValueError(
|
||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
||||
f"speculative_eagle_topk({self.speculative_eagle_topk}) > 1 "
|
||||
f"with page_size({self.page_size}) > 1 is unstable "
|
||||
"and produces incorrect results for paged attention backends. "
|
||||
"This combination is only supported for the 'flashinfer' backend."
|
||||
)
|
||||
if self.enable_dp_attention:
|
||||
# TODO: support dp attention for ngram speculative decoding
|
||||
|
||||
@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
|
||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||
evict_mask[accept_index] = False
|
||||
accept_length_cpu = accept_length.cpu()
|
||||
# FIXME: this `tolist()` fixes the numerical calculation consistency
|
||||
# try to unify the tensor representation and list representation
|
||||
accept_length_list = accept_length_cpu.tolist()
|
||||
|
||||
if page_size == 1:
|
||||
|
||||
@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
|
||||
else:
|
||||
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
||||
prefix_lens = batch.seq_lens
|
||||
prefix_lens_cpu = batch.seq_lens_cpu
|
||||
end_offset = prefix_lens + self.draft_token_num
|
||||
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens, end_offset, last_loc, len(batch.input_ids)
|
||||
prefix_lens,
|
||||
prefix_lens_cpu,
|
||||
end_offset,
|
||||
end_offset_cpu,
|
||||
last_loc,
|
||||
len(batch.input_ids),
|
||||
)
|
||||
self.last_loc = last_loc
|
||||
|
||||
|
||||
Reference in New Issue
Block a user