[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -595,6 +595,9 @@ class ScheduleBatch:
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -605,6 +608,7 @@ class ScheduleBatch:
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
spec_algorithm: SpeculativeAlgorithm,
|
||||
enable_custom_logit_processor: bool,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -618,6 +622,7 @@ class ScheduleBatch:
|
||||
has_grammar=any(req.grammar for req in reqs),
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||
)
|
||||
|
||||
def batch_size(self):
|
||||
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
|
||||
return_logprob=self.return_logprob,
|
||||
decoding_reqs=self.decoding_reqs,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
enable_custom_logit_processor=self.enable_custom_logit_processor,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
Reference in New Issue
Block a user