Simplify logits penalizer (#2086)
This commit is contained in:
@@ -1019,7 +1019,7 @@ class ScheduleBatch:
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
|
||||
if self.sampling_info is not None:
|
||||
if self.sampling_info:
|
||||
if self.has_grammar:
|
||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||
else:
|
||||
@@ -1063,6 +1063,7 @@ class ScheduleBatch:
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
decoding_reqs=self.decoding_reqs,
|
||||
sampling_info=dataclasses.replace(self.sampling_info),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
@@ -1122,20 +1123,6 @@ class ModelWorkerBatch:
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
def copy(self):
|
||||
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
|
||||
|
||||
def to(self, device: str):
|
||||
self.input_ids = self.input_ids.to(device, non_blocking=True)
|
||||
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
|
||||
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
|
||||
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
|
||||
self.req_to_token_pool_records = [
|
||||
(x, y.to(device, non_blocking=True))
|
||||
for x, y in self.req_to_token_pool_records
|
||||
]
|
||||
self.sampling_info.to(device)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
|
||||
Reference in New Issue
Block a user