Simplify batch result resolution (#1735)
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
 
+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
         }
 
 
-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""
 
@@ -407,7 +407,7 @@ class Req:
 bid = 0
 
 
-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
@@ -902,7 +902,7 @@ class ScheduleBatch:
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]
 
     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
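
Note on the new copy(): dataclasses.replace builds a new instance that shares every field of the original and overrides only the keyword arguments it is given, so only sampling_info gets a fresh copy while the tensors stay shared. A minimal standalone sketch of that behavior (the Batch class below is a made-up stand-in, not the real ModelWorkerBatch):

    import dataclasses
    from typing import List


    @dataclasses.dataclass
    class Batch:  # hypothetical stand-in for ModelWorkerBatch
        bid: int
        input_ids: List[int]
        sampling_info: dict

        def copy(self):
            # Shallow-copy all fields; only sampling_info is replaced with a fresh object.
            return dataclasses.replace(self, sampling_info=dict(self.sampling_info))


    b1 = Batch(bid=1, input_ids=[1, 2, 3], sampling_info={"temperature": 0.7})
    b2 = b1.copy()
    assert b2 is not b1 and b2.input_ids is b1.input_ids  # other fields stay shared
    assert b2.sampling_info is not b1.sampling_info       # only the override is fresh
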
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         self.tp_worker = TpWorkerClass(
             server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-            if batch.return_logprob:
-                # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()
 
             # Check finish conditions
             logprob_pt = 0
@@ -825,14 +823,16 @@
         logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
-        # Move logprobs to cpu
-        if batch.return_logprob:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=self.device),
-                next_token_ids,
-            ].tolist()
-        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+        else:
+            # Move next_token_ids and logprobs to cpu
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].tolist()
+            next_token_ids = next_token_ids.tolist()
 
         self.token_to_kv_pool.free_group_begin()
 
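
The prefill and decode handlers above now follow the same pattern: with the overlap scheduler on, the next_token_ids carried in result are placeholders and the real values come from the worker; otherwise the sampled ids are simply moved to the CPU. A condensed sketch of that branch (output_queue stands in for the worker's result channel; logprob handling is omitted):

    import torch
    from queue import Queue

    output_queue = Queue()  # toy stand-in for the worker's output queue


    def resolve_decode_result(result, enable_overlap):
        logits_output, next_token_ids, bid = result
        if enable_overlap:
            # The ids in `result` are placeholders; block until the real ones arrive.
            logits_output, next_token_ids = output_queue.get()
        else:
            # Non-overlap mode: the ids are already real, just move them to the CPU.
            next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids


    output_queue.put((None, [11, 12]))
    print(resolve_decode_result((None, None, 0), enable_overlap=True))                   # (None, [11, 12])
    print(resolve_decode_result((None, torch.tensor([3, 4]), 1), enable_overlap=False))  # (None, [3, 4])
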
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
 
-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()
 
         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
 
             # Resolve future tokens in the input
             tic2 = time.time()
@@ -107,17 +102,22 @@
                 model_worker_batch
             )
 
-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))
 
             if False:
                 tic3 = time.time()
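
The background thread above writes the freshly sampled ids into future_token_ids_map at the slots addressed by the negative placeholder ids handed out earlier (negating future_next_token_ids turns them back into positive indices). A small CPU-only sketch of that bookkeeping (the sizes and token ids are made up; the real map lives on the GPU and is sized from max_running_requests):

    import torch

    future_token_ids_map = torch.zeros((64,), dtype=torch.int32)

    # Placeholders handed out when the batch was launched: -(ct + bs) .. -(ct + 1).
    future_token_ids_ct, bs = 4, 3
    future_next_token_ids = torch.arange(
        -(future_token_ids_ct + bs), -(future_token_ids_ct), dtype=torch.int32
    )  # tensor([-7, -6, -5], dtype=torch.int32)

    # After sampling, store the real ids at the positive counterparts of the placeholders.
    next_token_ids = torch.tensor([101, 102, 103], dtype=torch.int32)
    future_token_ids_map[-future_next_token_ids] = next_token_ids

    # A later batch whose input still references -7/-6/-5 can now look the real ids up.
    print(future_token_ids_map[-future_next_token_ids])  # tensor([101, 102, 103], dtype=torch.int32)
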
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )
 
-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
 
+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
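
Taken together, the reworked TpModelWorkerClient flow is: forward_batch_generation enqueues the batch together with the current placeholder counter and immediately returns negative future token ids, the background thread runs the real forward pass and publishes the result, and resulve_batch_result simply blocks on the output queue. A toy end-to-end sketch of that two-queue handoff (plain Python lists in place of GPU tensors, and a dummy increment in place of the forward pass):

    import queue
    import threading


    class ToyWorkerClient:
        def __init__(self):
            self.input_queue = queue.Queue()
            self.output_queue = queue.Queue()
            threading.Thread(target=self._forward_thread, daemon=True).start()

        def _forward_thread(self):
            while True:
                batch, _future_token_ids_ct = self.input_queue.get()
                next_token_ids = [x + 1 for x in batch]  # dummy stand-in for sampling
                self.output_queue.put((None, next_token_ids))  # logits_output omitted

        def forward_batch_generation(self, batch):
            self.input_queue.put((list(batch), 0))
            placeholders = [-(i + 1) for i in range(len(batch))]  # negative future ids
            return None, placeholders  # returned without waiting for the forward pass

        def resulve_batch_result(self, bid):
            return self.output_queue.get()  # blocks until the thread publishes the result


    client = ToyWorkerClient()
    _, placeholders = client.forward_batch_generation([10, 20, 30])
    print(placeholders)                        # [-1, -2, -3]
    print(client.resulve_batch_result(bid=0))  # (None, [11, 21, 31])
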
@@ -120,7 +120,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True
 
-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(