Clean up batch data structures: Introducing ModelWorkerBatch (#1544)
This commit is contained in:
@@ -62,11 +62,13 @@ import torch.distributed as dist
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server import _set_envs_and_config
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
allocate_init_ports,
|
||||
configure_logger,
|
||||
kill_child_process,
|
||||
suppress_other_loggers,
|
||||
@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
|
||||
suppress_other_loggers()
|
||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||
|
||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||
server_args.port,
|
||||
server_args.additional_ports,
|
||||
server_args.dp_size,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
server_args.trust_remote_code,
|
||||
@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
|
||||
gpu_id=tp_rank,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=server_args.tp_size,
|
||||
nccl_port=28888,
|
||||
nccl_port=server_args.additional_ports[-1],
|
||||
server_args=server_args,
|
||||
)
|
||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||
@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
forward_batch = batch.get_forward_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
forward_batch = batch.get_forward_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@@ -357,7 +366,6 @@ def latency_test(
|
||||
tp_rank,
|
||||
):
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
_set_envs_and_config(server_args)
|
||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||
|
||||
# Load the model
|
||||
@@ -463,6 +471,7 @@ def plot_latency_test(
|
||||
|
||||
|
||||
def main(server_args, bench_args):
|
||||
_set_envs_and_config(server_args)
|
||||
|
||||
if server_args.model_path:
|
||||
if bench_args.correctness_test:
|
||||
@@ -513,8 +522,6 @@ if __name__ == "__main__":
|
||||
format="%(message)s",
|
||||
)
|
||||
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
try:
|
||||
main(server_args, bench_args)
|
||||
except Exception as e:
|
||||
|
||||
@@ -62,7 +62,11 @@ class LogitsMetadata:
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if forward_batch.return_logprob:
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
else:
|
||||
return_top_logprob = False
|
||||
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
extend_logprob_pruned_lens_cpu = [
|
||||
extend_len - start_len
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Meta data for requests and batches"""
|
||||
"""
|
||||
Store information about requests and batches.
|
||||
|
||||
The following is the flow of data structures for a batch:
|
||||
|
||||
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
@@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason):
|
||||
|
||||
@dataclass
|
||||
class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
image_hash: int
|
||||
image_sizes: Optional[list] = None
|
||||
@@ -137,7 +149,7 @@ class ImageInputs:
|
||||
|
||||
|
||||
class Req:
|
||||
"""Store all inforamtion of a request."""
|
||||
"""The input and output status of a request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -393,20 +405,20 @@ class ScheduleBatch:
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
input_ids: List[int] = None
|
||||
req_pool_indices: List[int] = None
|
||||
seq_lens: List[int] = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
|
||||
# For mixed chunekd prefill
|
||||
prefix_lens_cpu: List[int] = None
|
||||
running_bs: int = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
# For extend and mixed chunekd prefill
|
||||
prefix_lens: List[int] = None
|
||||
extend_lens: List[int] = None
|
||||
extend_num_tokens: int = None
|
||||
running_bs: int = None
|
||||
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
@@ -466,12 +478,12 @@ class ScheduleBatch:
|
||||
seq_lens = []
|
||||
|
||||
# Allocate memory
|
||||
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
||||
req_pool_indices = self.alloc_req_slots(bs)
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||
req.req_pool_idx = req_pool_indices[i]
|
||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
||||
seq_lens.append(seq_len)
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
@@ -497,22 +509,19 @@ class ScheduleBatch:
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Set fields
|
||||
with torch.device("cuda"):
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
|
||||
self.input_ids = sum(input_ids, [])
|
||||
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
|
||||
self.seq_lens = torch.tensor(seq_lens, device="cuda")
|
||||
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
def get_forward_batch(self):
|
||||
return ForwardBatch.from_schedule_batch(self)
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
self.forward_mode = ForwardMode.MIXED
|
||||
@@ -522,24 +531,24 @@ class ScheduleBatch:
|
||||
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||
req.extend_input_len = 1
|
||||
|
||||
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
|
||||
input_ids = self.input_ids + running_batch.input_ids
|
||||
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
|
||||
extend_num_tokens = self.extend_num_tokens + running_bs
|
||||
|
||||
self.merge(running_batch)
|
||||
self.merge_batch(running_batch)
|
||||
self.input_ids = input_ids
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
|
||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||
self.prefix_lens_cpu.extend(
|
||||
self.prefix_lens.extend(
|
||||
[
|
||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
self.extend_lens_cpu.extend([1] * running_bs)
|
||||
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
|
||||
self.extend_lens.extend([1] * running_bs)
|
||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||
|
||||
def check_decode_mem(self):
|
||||
bs = len(self.reqs)
|
||||
@@ -631,7 +640,7 @@ class ScheduleBatch:
|
||||
|
||||
return retracted_reqs, new_estimate_ratio
|
||||
|
||||
def check_for_jump_forward(self, model_runner):
|
||||
def check_for_jump_forward(self, pad_input_ids_func):
|
||||
jump_forward_reqs = []
|
||||
filter_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
@@ -688,7 +697,7 @@ class ScheduleBatch:
|
||||
|
||||
# re-applying image padding
|
||||
if req.image_inputs is not None:
|
||||
req.origin_input_ids = model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids = pad_input_ids_func(
|
||||
req.origin_input_ids_unpadded, req.image_inputs
|
||||
)
|
||||
|
||||
@@ -708,7 +717,7 @@ class ScheduleBatch:
|
||||
for r in self.reqs
|
||||
]
|
||||
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||
self.input_ids = input_ids
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
# Alloc mem
|
||||
@@ -731,32 +740,97 @@ class ScheduleBatch:
|
||||
|
||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.input_ids = None
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [
|
||||
self.top_logprobs_nums[i] for i in unfinished_indices
|
||||
]
|
||||
self.has_stream = any(req.stream for req in self.reqs)
|
||||
|
||||
self.sampling_info.filter(unfinished_indices, new_indices)
|
||||
self.sampling_info.filter_batch(unfinished_indices, new_indices)
|
||||
|
||||
def merge(self, other: "ScheduleBatch"):
|
||||
def merge_batch(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.sampling_info.merge(other.sampling_info)
|
||||
self.sampling_info.merge_batch(other.sampling_info)
|
||||
|
||||
self.reqs.extend(other.reqs)
|
||||
self.req_pool_indices = torch.concat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||
self.position_ids_offsets = torch.concat(
|
||||
[self.position_ids_offsets, other.position_ids_offsets]
|
||||
)
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob and other.return_logprob:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
elif self.return_logprob:
|
||||
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
||||
elif other.return_logprob:
|
||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||
self.has_stream = any(req.stream for req in self.reqs)
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
||||
image_inputs
|
||||
) = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
image_inputs = [r.image_inputs for r in self.reqs]
|
||||
|
||||
lora_paths = [req.lora_path for req in self.reqs]
|
||||
self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
|
||||
|
||||
return ModelWorkerBatch(
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=image_inputs,
|
||||
lora_paths=lora_paths,
|
||||
sampling_info=self.sampling_info,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelWorkerBatch:
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
# The input ids
|
||||
input_ids: List[int]
|
||||
# The indices of requests in the req_to_token_pool
|
||||
req_pool_indices: torch.Tensor
|
||||
# The sequence length
|
||||
seq_lens: torch.Tensor
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: Optional[List[int]]
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
extend_prefix_lens: Optional[List[int]]
|
||||
extend_logprob_start_lens: Optional[List[int]]
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]]
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
@@ -141,6 +141,9 @@ class Scheduler:
|
||||
nccl_port=port_args.nccl_ports[0],
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
self.pad_input_ids_func = getattr(
|
||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||
)
|
||||
|
||||
# Get token and memory info from the tp worker
|
||||
(
|
||||
@@ -292,7 +295,7 @@ class Scheduler:
|
||||
if self.running_batch is None:
|
||||
self.running_batch = new_batch
|
||||
else:
|
||||
self.running_batch.merge(new_batch)
|
||||
self.running_batch.merge_batch(new_batch)
|
||||
else:
|
||||
# Run a decode batch
|
||||
if self.running_batch is not None:
|
||||
@@ -370,7 +373,7 @@ class Scheduler:
|
||||
req.image_inputs = ImageInputs.from_dict(
|
||||
recv_req.image_inputs, self.model_config.vocab_size
|
||||
)
|
||||
req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids = self.pad_input_ids_func(
|
||||
req.origin_input_ids_unpadded, req.image_inputs
|
||||
)
|
||||
|
||||
@@ -575,9 +578,9 @@ class Scheduler:
|
||||
if self.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
forward_batch = batch.get_forward_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
forward_batch, batch
|
||||
model_worker_batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
@@ -641,8 +644,8 @@ class Scheduler:
|
||||
)
|
||||
else:
|
||||
assert batch.extend_num_tokens != 0
|
||||
forward_batch = batch.get_forward_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
@@ -759,9 +762,7 @@ class Scheduler:
|
||||
|
||||
# Check for jump-forward
|
||||
if not self.disable_regex_jump_forward:
|
||||
jump_forward_reqs = batch.check_for_jump_forward(
|
||||
self.tp_worker.model_runner
|
||||
)
|
||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||
self.waiting_queue.extend(jump_forward_reqs)
|
||||
if batch.is_empty():
|
||||
return
|
||||
@@ -771,9 +772,9 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward and sample the next tokens
|
||||
forward_batch = batch.get_forward_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
forward_batch, batch
|
||||
model_worker_batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
|
||||
@@ -21,6 +21,7 @@ import logging
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -108,12 +109,14 @@ class TpModelWorker:
|
||||
self.random_seed,
|
||||
)
|
||||
|
||||
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
next_token_ids = self.model_runner.sample(logits_output, batch)
|
||||
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
||||
return logits_output, next_token_ids
|
||||
|
||||
def forward_batch_embedding(self, forward_batch: ForwardBatch):
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
embeddings = logits_output.embeddings.tolist()
|
||||
return embeddings
|
||||
|
||||
@@ -15,18 +15,33 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Meta data for a forward pass."""
|
||||
"""
|
||||
Store information about a forward batch.
|
||||
|
||||
The following is the flow of data structures for a batch:
|
||||
|
||||
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
|
||||
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
@@ -69,25 +84,28 @@ class ForwardBatch:
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_prefix_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
extend_seq_lens_cpu: List[int] = None
|
||||
extend_logprob_start_lens_cpu: List[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
|
||||
# For multimodal
|
||||
image_inputs: List[ImageInputs] = None
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
|
||||
# For LoRA
|
||||
lora_paths: List[str] = None
|
||||
lora_paths: Optional[List[str]] = None
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
# Attention backend
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
@@ -95,42 +113,61 @@ class ForwardBatch:
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
def init_new(
|
||||
cls,
|
||||
batch: ScheduleBatch,
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
device = "cuda"
|
||||
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=batch.batch_size(),
|
||||
input_ids=batch.input_ids,
|
||||
batch_size=len(batch.seq_lens),
|
||||
input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device),
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
lora_paths=[req.lora_path for req in batch.reqs],
|
||||
lora_paths=batch.lora_paths,
|
||||
sampling_info=batch.sampling_info,
|
||||
)
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
ret.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
|
||||
for i, req in enumerate(batch.reqs)
|
||||
np.arange(prefix_len, prefix_len + extend_len)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
device=device,
|
||||
).to(torch.int64)
|
||||
|
||||
ret.image_inputs = [r.image_inputs for r in batch.reqs]
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||
ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, device=device
|
||||
)
|
||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||
ret.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
# Init attention information
|
||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
ret.attn_backend = model_runner.attn_backend
|
||||
model_runner.attn_backend.init_forward_metadata(ret)
|
||||
|
||||
# Init lora information
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
model_runner.lora_manager.prepare_lora_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -21,7 +21,7 @@ import importlib.resources
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.constrained import disable_cache
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
@@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
is_generation_model,
|
||||
is_multimodal_model,
|
||||
@@ -102,6 +104,12 @@ class ModelRunner:
|
||||
server_args.chunked_prefill_size = None
|
||||
server_args.mem_fraction_static *= 0.95
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
if server_args.disable_disk_cache:
|
||||
disable_cache()
|
||||
|
||||
global_server_args_dict.update(
|
||||
{
|
||||
"attention_backend": server_args.attention_backend,
|
||||
@@ -491,16 +499,6 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
# Attach attention information
|
||||
forward_batch.req_to_token_pool = self.req_to_token_pool
|
||||
forward_batch.token_to_kv_pool = self.token_to_kv_pool
|
||||
forward_batch.attn_backend = self.attn_backend
|
||||
forward_batch.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
# Attach lora information
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
@@ -508,16 +506,27 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info = forward_batch.sampling_info
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = self.sampler(logits, sampling_info)
|
||||
return next_token_ids
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||
# Apply logit_bias
|
||||
if sampling_info.logit_bias is not None:
|
||||
logits.add_(sampling_info.logit_bias)
|
||||
|
||||
# min-token, presence, frequency
|
||||
if sampling_info.linear_penalties is not None:
|
||||
logits += sampling_info.linear_penalties
|
||||
logits.add_(sampling_info.linear_penalties)
|
||||
|
||||
# repetition
|
||||
if sampling_info.scaling_penalties is not None:
|
||||
@@ -533,20 +542,6 @@ class ModelRunner:
|
||||
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
batch.sampling_info.update_regex_vocab_mask(batch)
|
||||
batch.sampling_info.update_penalties()
|
||||
logits = self._apply_logits_bias(
|
||||
logits_output.next_token_logits, batch.sampling_info
|
||||
)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = self.sampler(logits, batch.sampling_info)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
|
||||
import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -22,13 +23,17 @@ class SamplingBatchInfo:
|
||||
top_ks: torch.Tensor = None
|
||||
min_ps: torch.Tensor = None
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool = False
|
||||
|
||||
# Bias Tensors
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: torch.Tensor = None
|
||||
|
||||
# FSM states
|
||||
regex_fsms: List[RegexGuide] = None
|
||||
regex_fsm_states: List[int] = None
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
need_min_p_sampling: bool = False
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
|
||||
linear_penalties: torch.Tensor = None
|
||||
@@ -54,6 +59,8 @@ class SamplingBatchInfo:
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
|
||||
ret.regex_fsms = [r.regex_fsm for r in reqs]
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
@@ -102,24 +109,22 @@ class SamplingBatchInfo:
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
||||
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
# Reset the vocab mask
|
||||
self.vocab_mask = None
|
||||
|
||||
if has_regex:
|
||||
if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
|
||||
self.vocab_mask = torch.zeros(
|
||||
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||
len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||
)
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.regex_fsm is not None:
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
self.vocab_mask[i].fill_(1)
|
||||
self.vocab_mask[i][
|
||||
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
||||
] = 0
|
||||
|
||||
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
|
||||
for item in [
|
||||
@@ -129,9 +134,11 @@ class SamplingBatchInfo:
|
||||
"min_ps",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
if self_val is not None: # logit_bias can be None
|
||||
setattr(self, item, self_val[new_indices])
|
||||
value = getattr(self, item, None)
|
||||
if value is not None: # logit_bias can be None
|
||||
setattr(self, item, value[new_indices])
|
||||
|
||||
self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
|
||||
|
||||
@staticmethod
|
||||
def merge_bias_tensor(
|
||||
@@ -153,7 +160,7 @@ class SamplingBatchInfo:
|
||||
|
||||
return None
|
||||
|
||||
def merge(self, other: "SamplingBatchInfo"):
|
||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
@@ -169,3 +176,5 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||
)
|
||||
|
||||
self.regex_fsms.extend(other.regex_fsms)
|
||||
|
||||
@@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.constrained import disable_cache
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import (
|
||||
@@ -72,8 +71,6 @@ from sglang.srt.utils import (
|
||||
allocate_init_ports,
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
enable_show_time_cost,
|
||||
is_hip,
|
||||
kill_child_process,
|
||||
maybe_set_triton_cache_manager,
|
||||
prepare_model_and_tokenizer,
|
||||
@@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set ulimit
|
||||
set_ulimit()
|
||||
|
||||
# Enable show time cost for debugging
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
|
||||
# Disable disk cache
|
||||
if server_args.disable_disk_cache:
|
||||
disable_cache()
|
||||
|
||||
# Fix triton bugs
|
||||
if server_args.tp_size * server_args.dp_size > 1:
|
||||
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||
|
||||
Reference in New Issue
Block a user