Fix ngram spec with page size > 1 (#11135)
This commit is contained in:
@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
|
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
)
|
)
|
||||||
@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
prefix_lens_tensor,
|
prefix_lens_tensor,
|
||||||
prefix_lens_cpu_tensor,
|
prefix_lens_cpu_tensor,
|
||||||
seq_lens_tensor,
|
seq_lens_tensor,
|
||||||
seq_lens_cpu_tensor,
|
seq_lens_cpu,
|
||||||
last_loc,
|
last_loc,
|
||||||
extend_num_tokens,
|
extend_num_tokens,
|
||||||
)
|
)
|
||||||
@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.input_ids = input_ids_tensor
|
self.input_ids = input_ids_tensor
|
||||||
self.req_pool_indices = req_pool_indices_tensor
|
self.req_pool_indices = req_pool_indices_tensor
|
||||||
self.seq_lens = seq_lens_tensor
|
self.seq_lens = seq_lens_tensor
|
||||||
self.seq_lens_cpu = seq_lens_cpu_tensor
|
self.seq_lens_cpu = seq_lens_cpu
|
||||||
self.orig_seq_lens = orig_seq_lens_tensor
|
self.orig_seq_lens = orig_seq_lens_tensor
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.input_embeds = (
|
self.input_embeds = (
|
||||||
|
|||||||
@@ -1087,7 +1087,10 @@ class ServerArgs:
|
|||||||
and self.attention_backend != "flashinfer"
|
and self.attention_backend != "flashinfer"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
f"speculative_eagle_topk({self.speculative_eagle_topk}) > 1 "
|
||||||
|
f"with page_size({self.page_size}) > 1 is unstable "
|
||||||
|
"and produces incorrect results for paged attention backends. "
|
||||||
|
"This combination is only supported for the 'flashinfer' backend."
|
||||||
)
|
)
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
# TODO: support dp attention for ngram speculative decoding
|
# TODO: support dp attention for ngram speculative decoding
|
||||||
|
|||||||
@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
|
|||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
evict_mask[accept_index] = False
|
evict_mask[accept_index] = False
|
||||||
accept_length_cpu = accept_length.cpu()
|
accept_length_cpu = accept_length.cpu()
|
||||||
|
# FIXME: this `tolist()` fixes the numerical calculation consistency
|
||||||
|
# try to unify the tensor representation and list representation
|
||||||
accept_length_list = accept_length_cpu.tolist()
|
accept_length_list = accept_length_cpu.tolist()
|
||||||
|
|
||||||
if page_size == 1:
|
if page_size == 1:
|
||||||
|
|||||||
@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
|
|||||||
else:
|
else:
|
||||||
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
||||||
prefix_lens = batch.seq_lens
|
prefix_lens = batch.seq_lens
|
||||||
|
prefix_lens_cpu = batch.seq_lens_cpu
|
||||||
end_offset = prefix_lens + self.draft_token_num
|
end_offset = prefix_lens + self.draft_token_num
|
||||||
|
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
|
||||||
last_loc = get_last_loc(
|
last_loc = get_last_loc(
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
)
|
)
|
||||||
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
||||||
prefix_lens, end_offset, last_loc, len(batch.input_ids)
|
prefix_lens,
|
||||||
|
prefix_lens_cpu,
|
||||||
|
end_offset,
|
||||||
|
end_offset_cpu,
|
||||||
|
last_loc,
|
||||||
|
len(batch.input_ids),
|
||||||
)
|
)
|
||||||
self.last_loc = last_loc
|
self.last_loc = last_loc
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
|
class TestNgramSpeculativeDecodingBase(CustomTestCase):
|
||||||
|
|
||||||
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
|
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
|
||||||
base_url = DEFAULT_URL_FOR_TEST
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
|
|||||||
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
||||||
|
|
||||||
|
|
||||||
class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
|
class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_server_args(cls):
|
def get_server_args(cls):
|
||||||
return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
|
return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
|
||||||
|
|
||||||
|
|
||||||
class TestStandaloneSpeculativeDecodingFlashinfer(
|
class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
|
||||||
TestStandaloneSpeculativeDecodingBase
|
|
||||||
):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_server_args(cls):
|
def get_server_args(cls):
|
||||||
return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
|
return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
return DEFAULT_SERVER_ARGS + [
|
||||||
|
"--attention-backend",
|
||||||
|
"flashinfer",
|
||||||
|
"--page-size",
|
||||||
|
"64",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user