Make constrained decoding work for overlap scheduler (#2095)
@@ -136,6 +136,7 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
@@ -187,11 +188,10 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
-
         self.sampling_params = sampling_params
         self.lora_path = lora_path
 
-        # Memory info
+        # Memory pool info
         self.req_pool_idx = None
 
         # Check finish
@@ -428,7 +428,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch."""
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -438,9 +438,9 @@ class ScheduleBatch:
 
     # For utility
     model_config: ModelConfig = None
-
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
+    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
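
The new `next_batch_sampling_info` field carries the point of this change: with the overlap scheduler, the CPU assembles batch t+1 while the GPU is still running batch t, but constrained decoding can only compute the vocab mask for step t+1 after step t's token has been sampled and the grammar advanced. Staging the metadata on the current batch and filling the mask in late resolves that. Below is a toy sketch of the idea; apart from the field name `next_batch_sampling_info`, every name is hypothetical and heavily simplified:

```python
import threading


class ToySamplingInfo:
    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.vocab_mask = None          # grammar mask, computed later
        self.ready = threading.Event()  # signals the mask is usable


class ToyBatch:
    def __init__(self):
        self.sampling_info = None
        self.next_batch_sampling_info = None


def overlap_step(cur: ToyBatch, nxt: ToyBatch, vocab_size: int = 8) -> None:
    # While the GPU is busy with `cur`, the CPU already creates the metadata
    # object for `nxt`; its grammar mask cannot be filled until `cur`'s token
    # has been sampled and the grammar state advanced.
    cur.next_batch_sampling_info = ToySamplingInfo(vocab_size)

    def finish_current_batch():
        info = cur.next_batch_sampling_info
        info.vocab_mask = [True] * info.vocab_size  # grammar advanced here
        info.ready.set()

    threading.Thread(target=finish_current_batch).start()

    # The next batch blocks only right before sampling, not while it is built.
    cur.next_batch_sampling_info.ready.wait()
    nxt.sampling_info = cur.next_batch_sampling_info


a, b = ToyBatch(), ToyBatch()
overlap_step(a, b)
print(b.sampling_info.vocab_mask)
```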
@@ -509,7 +509,7 @@ class ScheduleBatch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs):
+    def alloc_req_slots(self, num_reqs: int):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
@@ -610,7 +610,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self):
+    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -704,7 +704,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            global_server_args_dict["disable_penalizer"],
+            enable_overlap_schedule=enable_overlap_schedule,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
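
`prepare_for_extend` now forwards an `enable_overlap_schedule` flag into `SamplingBatchInfo.from_schedule_batch`. The real constructor lives in sglang's sampling code and is not shown in this diff; the sketch below only illustrates one plausible way the flag could split the two paths, with all internals assumed:

```python
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class SamplingInfoSketch:
    vocab_size: int
    enable_overlap_schedule: bool = False
    vocab_mask: Optional[List[bool]] = None

    @classmethod
    def from_schedule_batch(cls, batch, vocab_size: int,
                            enable_overlap_schedule: bool = False):
        info = cls(vocab_size, enable_overlap_schedule)
        if not enable_overlap_schedule:
            # Non-overlap path: grammar state is current, build the mask now.
            info.vocab_mask = [True] * vocab_size
        # Overlap path: leave the mask unset; it is filled in only once the
        # previous batch's tokens are known (via next_batch_sampling_info).
        return info
```

A scheduler running in overlap mode would presumably call `batch.prepare_for_extend(enable_overlap_schedule=True)`, while the default `False` preserves the old eager behavior.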
@@ -746,6 +746,7 @@ class ScheduleBatch:
         return False
 
     def retract_decode(self):
+        """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]
 
         # TODO(lsyin): improve retraction policy for radix cache
@@ -886,18 +887,10 @@ class ScheduleBatch:
 
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
-        self.input_ids = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
-            self.device, non_blocking=True
-        )
+        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
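
The `prepare_for_idle` rewrite is behavior-preserving for these zero-length placeholders: `torch.empty(0, ...).to(device, non_blocking=True)` first allocates on the CPU and then issues a (trivial) copy, while passing `device=` allocates directly on the target device. A quick check that the two spellings agree (falls back to CPU if no GPU is present):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

old_style = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)
new_style = torch.empty(0, dtype=torch.int32, device=device)

assert old_style.shape == new_style.shape == (0,)
assert old_style.dtype == new_style.dtype == torch.int32
assert old_style.device == new_style.device
```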
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=self.sampling_info,
         )
 
     def __str__(self):
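
The last hunk drops `sampling_info` from the copied batch. Read together with the new `next_batch_sampling_info` field, a plausible reading is that the snapshot handed off for result processing should not alias the mutable sampling state the scheduler is already rewriting for the next step. A minimal illustration of that hazard, with all names hypothetical:

```python
class LiveBatch:
    def __init__(self, reqs, sampling_info=None):
        self.reqs = reqs
        self.sampling_info = sampling_info

    def copy(self):
        # Only fields needed by result processing; deliberately no sampling_info,
        # so a background consumer cannot observe the next step's mask mid-update.
        return LiveBatch(reqs=self.reqs)


live = LiveBatch(["req0"], sampling_info={"vocab_mask": [True, False]})
snap = live.copy()
live.sampling_info["vocab_mask"] = [False, True]  # prepare next step
assert snap.sampling_info is None                 # snapshot is isolated
```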