Improve weight loading and code style (#3174)
This commit is contained in:
@@ -247,6 +247,7 @@ class Req:
|
||||
# Each decode stage's output ids
|
||||
self.output_ids = []
|
||||
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
||||
self.fill_ids = None
|
||||
self.session_id = session_id
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
|
||||
@@ -486,7 +486,7 @@ class Scheduler:
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap(self):
|
||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||
result_queue = deque()
|
||||
self.result_queue = deque()
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
@@ -497,7 +497,7 @@ class Scheduler:
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
@@ -511,7 +511,7 @@ class Scheduler:
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
@@ -642,7 +642,7 @@ class Scheduler:
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
# Handle image inputs
|
||||
# Handle multimodal inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||
@@ -743,7 +743,13 @@ class Scheduler:
|
||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
|
||||
def log_prefill_stats(
|
||||
self,
|
||||
adder: PrefillAdder,
|
||||
can_run_list: List[Req],
|
||||
running_bs: ScheduleBatch,
|
||||
has_being_chunked: bool,
|
||||
):
|
||||
self.tree_cache_metrics["total"] += (
|
||||
adder.log_input_tokens + adder.log_hit_tokens
|
||||
) / 10**9
|
||||
|
||||
Reference in New Issue
Block a user