Simplify batch result resolution (#1735)

This commit is contained in:
Lianmin Zheng
2024-10-20 19:47:14 -07:00
committed by GitHub
parent e12358dc91
commit b121bc03a3
5 changed files with 64 additions and 90 deletions

View File

@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors. It contains low-level tensor data. Most of the data consists of GPU tensors.
""" """
import dataclasses
import logging import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
} }
@dataclass @dataclasses.dataclass
class ImageInputs: class ImageInputs:
"""The image related inputs.""" """The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
bid = 0 bid = 0
@dataclass @dataclasses.dataclass
class ScheduleBatch: class ScheduleBatch:
"""Store all inforamtion of a batch.""" """Store all inforamtion of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
) )
@dataclass @dataclasses.dataclass
class ModelWorkerBatch: class ModelWorkerBatch:
# The batch id # The batch id
bid: int bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
mrope_positions_delta: List[List[int]] mrope_positions_delta: List[List[int]]
def copy(self): def copy(self):
return ModelWorkerBatch( return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
bid=self.bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool_records,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
extend_prefix_lens=self.extend_prefix_lens,
extend_logprob_start_lens=self.extend_logprob_start_lens,
image_inputs=self.image_inputs,
lora_paths=self.lora_paths,
sampling_info=self.sampling_info.copy(),
mrope_positions_delta=self.mrope_positions_delta,
)
def to(self, device: str): def to(self, device: str):
self.input_ids = self.input_ids.to(device, non_blocking=True) self.input_ids = self.input_ids.to(device, non_blocking=True)

View File

@@ -149,12 +149,8 @@ class Scheduler:
# Launch a tensor parallel worker # Launch a tensor parallel worker
if self.enable_overlap: if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else: else:
TpWorkerClass = TpModelWorker TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.tp_worker = TpWorkerClass( self.tp_worker = TpWorkerClass(
server_args=server_args, server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result): def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation: if self.is_generation:
logits_output, next_token_ids, bid = result logits_output, next_token_ids, bid = result
if batch.return_logprob:
# Move logprobs to cpu if self.enable_overlap:
if logits_output.next_token_logprobs is not None: logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
logits_output.next_token_logprobs = ( logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[ logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device), torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@ class Scheduler:
logits_output.normalized_prompt_logprobs = ( logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist() logits_output.normalized_prompt_logprobs.tolist()
) )
next_token_ids = next_token_ids.tolist()
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
# Check finish conditions # Check finish conditions
logprob_pt = 0 logprob_pt = 0
@@ -825,14 +823,16 @@ class Scheduler:
logits_output, next_token_ids, bid = result logits_output, next_token_ids, bid = result
self.num_generated_tokens += len(batch.reqs) self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu if self.enable_overlap:
if batch.return_logprob: logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs[ else:
torch.arange(len(next_token_ids), device=self.device), # Move next_token_ids and logprobs to cpu
next_token_ids, if batch.return_logprob:
].tolist() next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids) next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
self.token_to_kv_pool.free_group_begin() self.token_to_kv_pool.free_group_begin()

View File

@@ -48,19 +48,16 @@ class TpModelWorkerClient:
self.max_running_requests = self.worker.max_running_requests self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device self.device = self.worker.device
# Create future mappings # Init future mappings
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0 self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty( self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
) )
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()
# Launch a thread # Launch a thread
self.future_event_map = dict() self.input_queue = Queue()
self.forward_queue = Queue() self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream() self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread( self.forward_thread = threading.Thread(
target=self.forward_thread_func, target=self.forward_thread_func,
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
def forward_thread_func_(self): def forward_thread_func_(self):
while True: while True:
tic1 = time.time() tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = ( model_worker_batch, future_token_ids_ct = self.input_queue.get()
self.forward_queue.get()
)
# Resolve future tokens in the input # Resolve future tokens in the input
tic2 = time.time() tic2 = time.time()
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
model_worker_batch model_worker_batch
) )
# Set future values # Update the future token ids map
if model_worker_batch.return_logprob: bs = len(model_worker_batch.seq_lens)
self.future_logits_output_dict[future_logits_output] = logits_output future_next_token_ids = torch.arange(
-(future_token_ids_ct + bs),
-(future_token_ids_ct),
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to( self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32 torch.int32
) )
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist() # Set the result
) next_token_ids = next_token_ids.tolist()
self.future_event_map[model_worker_batch.bid].set() assert logits_output.next_token_logprobs is None, "Not supported"
self.output_queue.put((None, next_token_ids))
if False: if False:
tic3 = time.time() tic3 = time.time()
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}" f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
) )
def resolve_future_token_ids(self, bid: int): def resulve_batch_result(self, bid: int):
self.future_event_map[bid].wait() logits_output, next_token_ids = self.output_queue.get()
ret = self.future_token_ids_output[bid] return logits_output, next_token_ids
del self.future_event_map[bid]
return ret
def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Allocate output future objects # Push a new batch to the queue
future_logits_output = self.future_logits_output_ct self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
self.future_logits_output_ct += 1
# Allocate output future objects
bs = len(model_worker_batch.seq_lens) bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream): future_next_token_ids = torch.arange(
future_next_token_ids = -torch.arange( -(self.future_token_ids_ct + bs),
self.future_token_ids_ct + 1, -(self.future_token_ids_ct),
self.future_token_ids_ct + 1 + bs, dtype=torch.int32,
dtype=torch.int32, device=self.device,
device=self.device, )
)
self.future_token_ids_ct = ( self.future_token_ids_ct = (
self.future_token_ids_ct + bs self.future_token_ids_ct + bs
) % self.future_token_ids_limit ) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids return None, future_next_token_ids
self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

View File

@@ -120,7 +120,7 @@ class ModelRunner:
) )
if self.is_multimodal_model: if self.is_multimodal_model:
logger.info( logger.warning(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
) )
server_args.chunked_prefill_size = None server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
]: ]:
server_args.disable_cuda_graph = True server_args.disable_cuda_graph = True
if self.server_args.enable_overlap_schedule:
logger.warning(
"Overlap scheduler is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results."
)
# Global vars # Global vars
if server_args.show_time_cost: if server_args.show_time_cost:
enable_show_time_cost() enable_show_time_cost()

View File

@@ -177,6 +177,16 @@ class ServerArgs:
if self.sampling_backend is None: if self.sampling_backend is None:
self.sampling_backend = "flashinfer" self.sampling_backend = "flashinfer"
if self.enable_overlap_schedule:
logger.warning(
"Overlap scheduler mode is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results. "
"The NaN detection is also disabled."
)
self.disable_penalizer = True
self.disable_nan_detection = True
# Model-specific patches # Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info( logger.info(