Simplify batch result resolution (#1735)
This commit is contained in:
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
@@ -407,7 +407,7 @@ class Req:
|
||||
bid = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ScheduleBatch:
|
||||
"""Store all inforamtion of a batch."""
|
||||
|
||||
@@ -902,7 +902,7 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ModelWorkerBatch:
|
||||
# The batch id
|
||||
bid: int
|
||||
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
|
||||
mrope_positions_delta: List[List[int]]
|
||||
|
||||
def copy(self):
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=self.extend_seq_lens,
|
||||
extend_prefix_lens=self.extend_prefix_lens,
|
||||
extend_logprob_start_lens=self.extend_logprob_start_lens,
|
||||
image_inputs=self.image_inputs,
|
||||
lora_paths=self.lora_paths,
|
||||
sampling_info=self.sampling_info.copy(),
|
||||
mrope_positions_delta=self.mrope_positions_delta,
|
||||
)
|
||||
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
|
||||
|
||||
def to(self, device: str):
|
||||
self.input_ids = self.input_ids.to(device, non_blocking=True)
|
||||
|
||||
Reference in New Issue
Block a user