Make constrained decoding work for overlap scheduler (#2095)

This commit is contained in:
Lianmin Zheng
2024-11-19 15:04:43 -08:00
committed by GitHub
parent 55bd97f3e5
commit ffd20fcd03
8 changed files with 119 additions and 95 deletions

View File

@@ -136,6 +136,7 @@ class ImageInputs:
image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
@@ -187,11 +188,10 @@ class Req:
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.sampling_params = sampling_params
self.lora_path = lora_path
# Memory info
# Memory pool info
self.req_pool_idx = None
# Check finish
@@ -428,7 +428,7 @@ bid = 0
@dataclasses.dataclass
class ScheduleBatch:
"""Store all inforamtion of a batch."""
"""Store all inforamtion of a batch on the scheduler."""
# Request, memory pool, and cache
reqs: List[Req]
@@ -438,9 +438,9 @@ class ScheduleBatch:
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None
@@ -509,7 +509,7 @@ class ScheduleBatch:
def is_empty(self) -> bool:
    """Return True when this batch holds no requests.

    ``self.reqs`` is the batch's list of requests; an empty list is falsy,
    so the truthiness test replaces the explicit ``len(...) == 0`` check.
    """
    return not self.reqs
def alloc_req_slots(self, num_reqs):
def alloc_req_slots(self, num_reqs: int):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
@@ -610,7 +610,7 @@ class ScheduleBatch:
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
def prepare_for_extend(self, enable_overlap_schedule: bool = False):
self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs)
@@ -704,7 +704,7 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
enable_overlap_schedule=enable_overlap_schedule,
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -746,6 +746,7 @@ class ScheduleBatch:
return False
def retract_decode(self):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve retraction policy for radix cache
@@ -886,18 +887,10 @@ class ScheduleBatch:
def prepare_for_idle(self):
    """Configure this batch as an empty batch in IDLE forward mode.

    Sets ``forward_mode`` to ``ForwardMode.IDLE``, replaces the batched
    tensors (``input_ids``, ``seq_lens``, ``out_cache_loc``,
    ``req_pool_indices``) with zero-length int32 tensors on ``self.device``,
    and resets the token counters to 0.
    """
    self.forward_mode = ForwardMode.IDLE

    # Create each zero-length tensor directly on the target device, exactly
    # once. Building it on the host and copying it over with ``.to(...)``
    # before re-creating it on the device was redundant work.
    self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
    self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
    self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
    self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)

    # An idle batch carries no tokens.
    self.seq_lens_sum = 0
    self.extend_num_tokens = 0
@@ -1063,7 +1056,6 @@ class ScheduleBatch:
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
sampling_info=self.sampling_info,
)
def __str__(self):