Simplify logits penalizer (#2086)

This commit is contained in:
Lianmin Zheng
2024-11-18 17:48:28 -08:00
committed by GitHub
parent 3b44bbeecf
commit b110453802
18 changed files with 125 additions and 190 deletions

View File

@@ -1019,7 +1019,7 @@ class ScheduleBatch:
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.sampling_info is not None:
if self.sampling_info:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
@@ -1063,6 +1063,7 @@ class ScheduleBatch:
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
sampling_info=dataclasses.replace(self.sampling_info),
)
def __str__(self):
@@ -1122,20 +1123,6 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# Build and return a new object via dataclasses.replace(): all fields are
# carried over from self unchanged (same references), except sampling_info,
# which is set to the result of self.sampling_info.copy().
# NOTE(review): what sampling_info.copy() duplicates is defined elsewhere —
# confirm how deep that copy is before relying on it for isolation.
def copy(self):
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
# Move this batch's data to `device`, reassigning the attributes on self.
# The method has no return statement; callers use it for its side effects.
#
# device: target device string, passed straight through to each `.to(...)`.
#
# NOTE(review): the `.to(device, non_blocking=True)` calls look like the
# torch.Tensor API — presumably these attributes are tensors; confirm.
def to(self, device: str):
# Each attribute is replaced by the result of its own `.to(...)` call.
self.input_ids = self.input_ids.to(device, non_blocking=True)
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
# Rebuild the records list: the first element of each pair is kept as-is,
# only the second element is moved to `device`.
self.req_to_token_pool_records = [
(x, y.to(device, non_blocking=True))
for x, y in self.req_to_token_pool_records
]
# The return value is discarded here, so this presumably updates
# sampling_info in place — verify against SamplingBatchInfo.to.
self.sampling_info.to(device)
@triton.jit
def write_req_to_token_pool_triton(