Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -575,8 +575,8 @@ class ScheduleBatch:
|
||||
device: str = "cuda"
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
@@ -587,7 +587,7 @@ class ScheduleBatch:
|
||||
tree_cache: BasePrefixCache,
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -600,7 +600,7 @@ class ScheduleBatch:
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_grammar=any(req.grammar for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=speculative_algorithm,
|
||||
spec_algorithm=spec_algorithm,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1010,6 +1010,8 @@ class ScheduleBatch:
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
if self.spec_algorithm.is_eagle():
|
||||
return
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
self.output_ids = None
|
||||
@@ -1172,6 +1174,7 @@ class ScheduleBatch:
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
decoding_reqs=self.decoding_reqs,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
@@ -1232,8 +1235,8 @@ class ModelWorkerBatch:
|
||||
input_embeds: Optional[torch.tensor] = None
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_algorithm: Optional[SpeculativeAlgorithm] = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user