Simplify batch result resolution (#1735)

This commit is contained in:
Lianmin Zheng
2024-10-20 19:47:14 -07:00
committed by GitHub
parent e12358dc91
commit b121bc03a3
5 changed files with 64 additions and 90 deletions

View File

@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
import dataclasses
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
}
@dataclass
@dataclasses.dataclass
class ImageInputs:
"""The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
bid = 0
@dataclass
@dataclasses.dataclass
class ScheduleBatch:
"""Store all information of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
)
@dataclass
@dataclasses.dataclass
class ModelWorkerBatch:
# The batch id
bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
mrope_positions_delta: List[List[int]]
def copy(self):
    """Return a shallow copy of this batch with its own ``sampling_info``.

    Every field is carried over to the new object by reference, except
    ``sampling_info``, which is replaced by the result of calling its own
    ``copy()`` method.

    Returns:
        A new instance of the same dataclass type as ``self``.
    """
    # dataclasses.replace re-runs __init__ with all current field values, so
    # newly added fields are copied automatically instead of having to be
    # listed by hand here.
    return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
def to(self, device: str):
self.input_ids = self.input_ids.to(device, non_blocking=True)