Simplify logits penalizer (#2086)
This commit is contained in:
@@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
"""A tensor parallel worker."""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -107,7 +108,7 @@ class TpModelWorkerClient:
|
||||
|
||||
# Run forward
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
model_worker_batch, self.launch_event
|
||||
)
|
||||
|
||||
# Update the future token ids map
|
||||
@@ -134,7 +135,6 @@ class TpModelWorkerClient:
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event.record()
|
||||
|
||||
self.launch_event.set()
|
||||
self.output_queue.put((copy_event, logits_output, next_token_ids))
|
||||
|
||||
def resolve_batch_result(self, bid: int):
|
||||
@@ -159,7 +159,10 @@ class TpModelWorkerClient:
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
# Push a new batch to the queue
|
||||
self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
|
||||
model_worker_batch.sampling_info = dataclasses.replace(
|
||||
model_worker_batch.sampling_info
|
||||
)
|
||||
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
||||
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
|
||||
Reference in New Issue
Block a user