Simplify batch result resolution (#1735)
This commit is contained in:
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
@@ -407,7 +407,7 @@ class Req:
|
||||
bid = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ScheduleBatch:
|
||||
"""Store all inforamtion of a batch."""
|
||||
|
||||
@@ -902,7 +902,7 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class ModelWorkerBatch:
|
||||
# The batch id
|
||||
bid: int
|
||||
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
|
||||
mrope_positions_delta: List[List[int]]
|
||||
|
||||
def copy(self):
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=self.extend_seq_lens,
|
||||
extend_prefix_lens=self.extend_prefix_lens,
|
||||
extend_logprob_start_lens=self.extend_logprob_start_lens,
|
||||
image_inputs=self.image_inputs,
|
||||
lora_paths=self.lora_paths,
|
||||
sampling_info=self.sampling_info.copy(),
|
||||
mrope_positions_delta=self.mrope_positions_delta,
|
||||
)
|
||||
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
|
||||
|
||||
def to(self, device: str):
|
||||
self.input_ids = self.input_ids.to(device, non_blocking=True)
|
||||
|
||||
@@ -149,12 +149,8 @@ class Scheduler:
|
||||
# Launch a tensor parallel worker
|
||||
if self.enable_overlap:
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
else:
|
||||
TpWorkerClass = TpModelWorker
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
|
||||
self.tp_worker = TpWorkerClass(
|
||||
server_args=server_args,
|
||||
@@ -756,9 +752,12 @@ class Scheduler:
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
if batch.return_logprob:
|
||||
# Move logprobs to cpu
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
@@ -771,8 +770,7 @@ class Scheduler:
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
|
||||
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
@@ -825,14 +823,16 @@ class Scheduler:
|
||||
logits_output, next_token_ids, bid = result
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
|
||||
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
self.token_to_kv_pool.free_group_begin()
|
||||
|
||||
|
||||
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
|
||||
self.max_running_requests = self.worker.max_running_requests
|
||||
self.device = self.worker.device
|
||||
|
||||
# Create future mappings
|
||||
self.future_logits_output_dict = dict()
|
||||
self.future_logits_output_ct = 0
|
||||
# Init future mappings
|
||||
self.future_token_ids_ct = 0
|
||||
self.future_token_ids_limit = self.max_running_requests * 3
|
||||
self.future_token_ids_map = torch.empty(
|
||||
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.future_token_ids_limit = self.max_running_requests * 3
|
||||
self.future_token_ids_output = dict()
|
||||
|
||||
# Launch a thread
|
||||
self.future_event_map = dict()
|
||||
self.forward_queue = Queue()
|
||||
self.input_queue = Queue()
|
||||
self.output_queue = Queue()
|
||||
self.forward_stream = torch.cuda.Stream()
|
||||
self.forward_thread = threading.Thread(
|
||||
target=self.forward_thread_func,
|
||||
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
|
||||
def forward_thread_func_(self):
|
||||
while True:
|
||||
tic1 = time.time()
|
||||
model_worker_batch, future_logits_output, future_next_token_ids = (
|
||||
self.forward_queue.get()
|
||||
)
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
tic2 = time.time()
|
||||
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
|
||||
model_worker_batch
|
||||
)
|
||||
|
||||
# Set future values
|
||||
if model_worker_batch.return_logprob:
|
||||
self.future_logits_output_dict[future_logits_output] = logits_output
|
||||
|
||||
# Update the future token ids map
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(future_token_ids_ct + bs),
|
||||
-(future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
|
||||
torch.int32
|
||||
)
|
||||
self.future_token_ids_output[model_worker_batch.bid] = (
|
||||
next_token_ids.tolist()
|
||||
)
|
||||
self.future_event_map[model_worker_batch.bid].set()
|
||||
|
||||
# Set the result
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
assert logits_output.next_token_logprobs is None, "Not supported"
|
||||
self.output_queue.put((None, next_token_ids))
|
||||
|
||||
if False:
|
||||
tic3 = time.time()
|
||||
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
|
||||
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
|
||||
)
|
||||
|
||||
def resolve_future_token_ids(self, bid: int):
|
||||
self.future_event_map[bid].wait()
|
||||
ret = self.future_token_ids_output[bid]
|
||||
del self.future_event_map[bid]
|
||||
return ret
|
||||
|
||||
def resolve_future_logits_output(self, future_obj):
|
||||
return self.future_logits_output_dict.pop(future_obj)
|
||||
def resulve_batch_result(self, bid: int):
|
||||
logits_output, next_token_ids = self.output_queue.get()
|
||||
return logits_output, next_token_ids
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
# Allocate output future objects
|
||||
future_logits_output = self.future_logits_output_ct
|
||||
self.future_logits_output_ct += 1
|
||||
# Push a new batch to the queue
|
||||
self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
|
||||
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
with torch.cuda.stream(self.forward_stream):
|
||||
future_next_token_ids = -torch.arange(
|
||||
self.future_token_ids_ct + 1,
|
||||
self.future_token_ids_ct + 1 + bs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(self.future_token_ids_ct + bs),
|
||||
-(self.future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_ct = (
|
||||
self.future_token_ids_ct + bs
|
||||
) % self.future_token_ids_limit
|
||||
ret = future_logits_output, future_next_token_ids
|
||||
|
||||
self.future_event_map[model_worker_batch.bid] = threading.Event()
|
||||
self.forward_queue.put(
|
||||
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
|
||||
)
|
||||
return ret
|
||||
return None, future_next_token_ids
|
||||
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
|
||||
@@ -120,7 +120,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
logger.info(
|
||||
logger.warning(
|
||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||
)
|
||||
server_args.chunked_prefill_size = None
|
||||
@@ -131,13 +131,6 @@ class ModelRunner:
|
||||
]:
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
logger.warning(
|
||||
"Overlap scheduler is enabled. This is an experimental feature. "
|
||||
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
|
||||
"and embedding APIs are not supported and will lead to wrong results."
|
||||
)
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
|
||||
@@ -177,6 +177,16 @@ class ServerArgs:
|
||||
if self.sampling_backend is None:
|
||||
self.sampling_backend = "flashinfer"
|
||||
|
||||
if self.enable_overlap_schedule:
|
||||
logger.warning(
|
||||
"Overlap scheduler mode is enabled. This is an experimental feature. "
|
||||
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
|
||||
"and embedding APIs are not supported and will lead to wrong results. "
|
||||
"The NaN detection is also disabled."
|
||||
)
|
||||
self.disable_penalizer = True
|
||||
self.disable_nan_detection = True
|
||||
|
||||
# Model-specific patches
|
||||
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user