From 63ba2f8d7bf895938a3f4039910044ce6912d57e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 30 Sep 2024 06:41:49 -0700 Subject: [PATCH] Clean up batch data structures: Introducing ModelWorkerBatch (#1544) --- python/sglang/bench_latency.py | 23 ++- python/sglang/srt/layers/logits_processor.py | 6 +- python/sglang/srt/managers/schedule_batch.py | 170 +++++++++++++----- python/sglang/srt/managers/scheduler.py | 23 +-- python/sglang/srt/managers/tp_worker.py | 9 +- .../srt/model_executor/forward_batch_info.py | 91 +++++++--- .../sglang/srt/model_executor/model_runner.py | 55 +++--- .../srt/sampling/sampling_batch_info.py | 41 +++-- python/sglang/srt/server.py | 11 -- 9 files changed, 274 insertions(+), 155 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 6fbeb80e8..47aca5059 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -62,11 +62,13 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + allocate_init_ports, configure_logger, kill_child_process, suppress_other_loggers, @@ -125,6 +127,11 @@ def load_model(server_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None + server_args.port, server_args.additional_ports = allocate_init_ports( + server_args.port, + server_args.additional_ports, + server_args.dp_size, + ) model_config = ModelConfig( server_args.model_path, server_args.trust_remote_code, @@ -136,7 +143,7 @@ def load_model(server_args, tp_rank): gpu_id=tp_rank, tp_rank=tp_rank, tp_size=server_args.tp_size, - nccl_port=28888, + nccl_port=server_args.additional_ports[-1], server_args=server_args, ) rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") @@ -225,17 +232,19 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - forward_batch = batch.get_forward_batch() + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) logits_output = model_runner.forward(forward_batch) - next_token_ids = model_runner.sample(logits_output, batch).tolist() + next_token_ids = model_runner.sample(logits_output, forward_batch).tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) - forward_batch = batch.get_forward_batch() + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) logits_output = model_runner.forward(forward_batch) - next_token_ids = model_runner.sample(logits_output, batch).tolist() + next_token_ids = model_runner.sample(logits_output, forward_batch).tolist() return next_token_ids, logits_output.next_token_logits @@ -357,7 +366,6 @@ def latency_test( tp_rank, ): configure_logger(server_args, prefix=f" TP{tp_rank}") - _set_envs_and_config(server_args) rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None # Load the model @@ 
-463,6 +471,7 @@ def plot_latency_test( def main(server_args, bench_args): + _set_envs_and_config(server_args) if server_args.model_path: if bench_args.correctness_test: @@ -513,8 +522,6 @@ if __name__ == "__main__": format="%(message)s", ) - multiprocessing.set_start_method("spawn", force=True) - try: main(server_args, bench_args) except Exception as e: diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 86eec65cc..f0c55af62 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -62,7 +62,11 @@ class LogitsMetadata: @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): - return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) + if forward_batch.return_logprob: + return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) + else: + return_top_logprob = False + if forward_batch.forward_mode.is_extend(): extend_logprob_pruned_lens_cpu = [ extend_len - start_len diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 50d0bbd86..6fcc9616f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Meta data for requests and batches""" +""" +Store information about requests and batches. + +The following is the flow of data structures for a batch: + +ScheduleBatch -> ModelWorkerBatch -> ForwardBatch + +- ScheduleBatch is managed by `scheduler.py::Scheduler`. + It contains high-level scheduling data. Most of the data is on the CPU. +- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`. +- ForwardBatch is managed by `model_runner.py::ModelRunner`. + It contains low-level tensor data. Most of the data consists of GPU tensors. 
+""" import logging from dataclasses import dataclass @@ -29,7 +39,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -105,6 +115,8 @@ class FINISH_ABORT(BaseFinishReason): @dataclass class ImageInputs: + """The image related inputs.""" + pixel_values: torch.Tensor image_hash: int image_sizes: Optional[list] = None @@ -137,7 +149,7 @@ class ImageInputs: class Req: - """Store all inforamtion of a request.""" + """The input and output status of a request.""" def __init__( self, @@ -393,20 +405,20 @@ class ScheduleBatch: sampling_info: SamplingBatchInfo = None # Batched arguments to model runner - input_ids: torch.Tensor = None - req_pool_indices: torch.Tensor = None - seq_lens: torch.Tensor = None - position_ids_offsets: torch.Tensor = None + input_ids: List[int] = None + req_pool_indices: List[int] = None + seq_lens: List[int] = None out_cache_loc: torch.Tensor = None - extend_num_tokens: int = None - - # For mixed chunekd prefill - prefix_lens_cpu: List[int] = None - running_bs: int = None # For processing logprobs return_logprob: bool = False - top_logprobs_nums: List[int] = None + top_logprobs_nums: Optional[List[int]] = None + + # For extend and mixed chunekd prefill + prefix_lens: List[int] = None + extend_lens: List[int] = None + extend_num_tokens: int = None + running_bs: int = None # Stream has_stream: bool = False @@ -466,12 +478,12 @@ class ScheduleBatch: seq_lens = [] # Allocate memory - req_pool_indices_cpu = self.alloc_req_slots(bs) + req_pool_indices = self.alloc_req_slots(bs) out_cache_loc = self.alloc_token_slots(extend_num_tokens) pt = 0 for i, req in enumerate(reqs): - req.req_pool_idx = req_pool_indices_cpu[i] + req.req_pool_idx = req_pool_indices[i] pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) seq_lens.append(seq_len) assert seq_len - pre_len == req.extend_input_len @@ -497,22 +509,19 @@ class ScheduleBatch: pt += req.extend_input_len # Set fields - with torch.device("cuda"): - self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) - self.req_pool_indices = torch.tensor(req_pool_indices_cpu) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) - self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64) + self.input_ids = sum(input_ids, []) + self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda") + self.seq_lens = torch.tensor(seq_lens, device="cuda") self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc - self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] - self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs] - self.extend_lens_cpu = [r.extend_input_len for r in reqs] - self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs] - self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) + if self.return_logprob: + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.prefix_lens = [len(r.prefix_indices) for r in reqs] + self.extend_lens = [r.extend_input_len for r in reqs] + 
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] - def get_forward_batch(self): - return ForwardBatch.from_schedule_batch(self) + self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED @@ -522,24 +531,24 @@ class ScheduleBatch: req.fill_ids = req.origin_input_ids + req.output_ids req.extend_input_len = 1 - input_ids = torch.cat([self.input_ids, running_batch.input_ids]) + input_ids = self.input_ids + running_batch.input_ids out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) extend_num_tokens = self.extend_num_tokens + running_bs - self.merge(running_batch) + self.merge_batch(running_batch) self.input_ids = input_ids self.out_cache_loc = out_cache_loc self.extend_num_tokens = extend_num_tokens # NOTE: prefix_indices is what has been cached, but we don't cache each decode step - self.prefix_lens_cpu.extend( + self.prefix_lens.extend( [ len(r.origin_input_ids) + len(r.output_ids) - 1 for r in running_batch.reqs ] ) - self.extend_lens_cpu.extend([1] * running_bs) - self.extend_logprob_start_lens_cpu.extend([0] * running_bs) + self.extend_lens.extend([1] * running_bs) + self.extend_logprob_start_lens.extend([0] * running_bs) def check_decode_mem(self): bs = len(self.reqs) @@ -631,7 +640,7 @@ class ScheduleBatch: return retracted_reqs, new_estimate_ratio - def check_for_jump_forward(self, model_runner): + def check_for_jump_forward(self, pad_input_ids_func): jump_forward_reqs = [] filter_indices = [i for i in range(len(self.reqs))] @@ -688,7 +697,7 @@ class ScheduleBatch: # re-applying image padding if req.image_inputs is not None: - req.origin_input_ids = model_runner.model.pad_input_ids( + req.origin_input_ids = pad_input_ids_func( req.origin_input_ids_unpadded, req.image_inputs ) @@ -708,7 +717,7 @@ class ScheduleBatch: for r in self.reqs ] - self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") + self.input_ids = input_ids self.seq_lens.add_(1) # Alloc mem @@ -731,32 +740,97 @@ class ScheduleBatch: self.reqs = [self.reqs[i] for i in unfinished_indices] new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda") - self.seq_lens = self.seq_lens[new_indices] - self.input_ids = None self.req_pool_indices = self.req_pool_indices[new_indices] - self.position_ids_offsets = self.position_ids_offsets[new_indices] + self.seq_lens = self.seq_lens[new_indices] self.out_cache_loc = None - self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.return_logprob: + self.top_logprobs_nums = [ + self.top_logprobs_nums[i] for i in unfinished_indices + ] self.has_stream = any(req.stream for req in self.reqs) - self.sampling_info.filter(unfinished_indices, new_indices) + self.sampling_info.filter_batch(unfinished_indices, new_indices) - def merge(self, other: "ScheduleBatch"): + def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. 
- self.sampling_info.merge(other.sampling_info) + self.sampling_info.merge_batch(other.sampling_info) self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) - self.position_ids_offsets = torch.concat( - [self.position_ids_offsets, other.position_ids_offsets] - ) self.out_cache_loc = None - self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.return_logprob and other.return_logprob: + self.top_logprobs_nums.extend(other.top_logprobs_nums) + elif self.return_logprob: + self.top_logprobs_nums.extend([0] * len(other.reqs)) + elif other.return_logprob: + self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.has_stream = any(req.stream for req in self.reqs) + + def get_model_worker_batch(self): + if self.forward_mode.is_decode(): + extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = ( + image_inputs + ) = None + else: + extend_seq_lens = self.extend_lens + extend_prefix_lens = self.prefix_lens + extend_logprob_start_lens = self.extend_logprob_start_lens + image_inputs = [r.image_inputs for r in self.reqs] + + lora_paths = [req.lora_path for req in self.reqs] + self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs] + + return ModelWorkerBatch( + forward_mode=self.forward_mode, + input_ids=self.input_ids, + req_pool_indices=self.req_pool_indices, + seq_lens=self.seq_lens, + out_cache_loc=self.out_cache_loc, + return_logprob=self.return_logprob, + top_logprobs_nums=self.top_logprobs_nums, + extend_seq_lens=extend_seq_lens, + extend_prefix_lens=extend_prefix_lens, + extend_logprob_start_lens=extend_logprob_start_lens, + image_inputs=image_inputs, + lora_paths=lora_paths, + sampling_info=self.sampling_info, + ) + + +@dataclass +class ModelWorkerBatch: + # The forward mode + forward_mode: ForwardMode + # The input ids + input_ids: List[int] + # The indices of requests in the req_to_token_pool + req_pool_indices: torch.Tensor + # The sequence length + seq_lens: torch.Tensor + # The indices of output tokens in the token_to_kv_pool + out_cache_loc: torch.Tensor + + # For logprob + return_logprob: bool + top_logprobs_nums: Optional[List[int]] + + # For extend + extend_seq_lens: Optional[List[int]] + extend_prefix_lens: Optional[List[int]] + extend_logprob_start_lens: Optional[List[int]] + + # For multimodal + image_inputs: Optional[List[ImageInputs]] + + # For LoRA + lora_paths: Optional[List[str]] + + # Sampling info + sampling_info: SamplingBatchInfo diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5a31fd65c..9679bb0ae 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -141,6 +141,9 @@ class Scheduler: nccl_port=port_args.nccl_ports[0], ) self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group + self.pad_input_ids_func = getattr( + self.tp_worker.model_runner.model, "pad_input_ids", None + ) # Get token and memory info from the tp worker ( @@ -292,7 +295,7 @@ class Scheduler: if self.running_batch is None: self.running_batch = new_batch else: - self.running_batch.merge(new_batch) + self.running_batch.merge_batch(new_batch) else: # Run a decode batch if self.running_batch is not None: @@ -370,7 +373,7 @@ class Scheduler: req.image_inputs = ImageInputs.from_dict( recv_req.image_inputs, self.model_config.vocab_size ) - 
req.origin_input_ids = self.tp_worker.model_runner.model.pad_input_ids( + req.origin_input_ids = self.pad_input_ids_func( req.origin_input_ids_unpadded, req.image_inputs ) @@ -575,9 +578,9 @@ class Scheduler: if self.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - forward_batch = batch.get_forward_batch() + model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - forward_batch, batch + model_worker_batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -641,8 +644,8 @@ class Scheduler: ) else: assert batch.extend_num_tokens != 0 - forward_batch = batch.get_forward_batch() - embeddings = self.tp_worker.forward_batch_embedding(forward_batch) + model_worker_batch = batch.get_model_worker_batch() + embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) # Check finish conditions for i, req in enumerate(batch.reqs): @@ -759,9 +762,7 @@ class Scheduler: # Check for jump-forward if not self.disable_regex_jump_forward: - jump_forward_reqs = batch.check_for_jump_forward( - self.tp_worker.model_runner - ) + jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): return @@ -771,9 +772,9 @@ class Scheduler: batch.prepare_for_decode() # Forward and sample the next tokens - forward_batch = batch.get_forward_batch() + model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - forward_batch, batch + model_worker_batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b62651fae..73c4abe08 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -21,6 +21,7 @@ import logging from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs @@ -108,12 +109,14 @@ class TpModelWorker: self.random_seed, ) - def forward_batch_generation(self, forward_batch: ForwardBatch, batch): + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) - next_token_ids = self.model_runner.sample(logits_output, batch) + next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) return logits_output, next_token_ids - def forward_batch_embedding(self, forward_batch: ForwardBatch): + def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) embeddings = logits_output.embeddings.tolist() return embeddings diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e5b8ff34e..6351d54e3 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ 
b/python/sglang/srt/model_executor/forward_batch_info.py @@ -15,18 +15,33 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Meta data for a forward pass.""" +""" +Store information about a forward batch. + +The following is the flow of data structures for a batch: + +ScheduleBatch -> ModelWorkerBatch -> ForwardBatch + +- ScheduleBatch is managed by `scheduler.py::Scheduler`. + It contains high-level scheduling data. Most of the data is on the CPU. +- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`. +- ForwardBatch is managed by `model_runner.py::ModelRunner`. + It contains low-level tensor data. Most of the data consists of GPU tensors. +""" + from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import numpy as np import torch if TYPE_CHECKING: from sglang.srt.layers.attention_backend import AttentionBackend - from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch + from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -69,25 +84,28 @@ class ForwardBatch: # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor + # For logprob + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None + # Position information positions: torch.Tensor = None # For extend - extend_seq_lens: torch.Tensor = None - extend_prefix_lens: torch.Tensor = None - extend_start_loc: torch.Tensor = None - - # For logprob - return_logprob: bool = False - top_logprobs_nums: List[int] = None - extend_seq_lens_cpu: List[int] = None - extend_logprob_start_lens_cpu: List[int] = None + extend_seq_lens: Optional[torch.Tensor] = None + extend_prefix_lens: Optional[torch.Tensor] = None + extend_start_loc: Optional[torch.Tensor] = None + extend_seq_lens_cpu: Optional[List[int]] = None + extend_logprob_start_lens_cpu: Optional[List[int]] = None # For multimodal - image_inputs: List[ImageInputs] = None + image_inputs: Optional[List[ImageInputs]] = None # For LoRA - lora_paths: List[str] = None + lora_paths: Optional[List[str]] = None + + # Sampling info + sampling_info: SamplingBatchInfo = None # Attention backend req_to_token_pool: ReqToTokenPool = None @@ -95,42 +113,61 @@ class ForwardBatch: attn_backend: AttentionBackend = None @classmethod - def from_schedule_batch( + def init_new( cls, - batch: ScheduleBatch, + batch: ModelWorkerBatch, + model_runner: ModelRunner, ): + device = "cuda" + ret = cls( forward_mode=batch.forward_mode, - batch_size=batch.batch_size(), - input_ids=batch.input_ids, + batch_size=len(batch.seq_lens), + input_ids=torch.tensor(batch.input_ids, dtype=torch.int32, device=device), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, - lora_paths=[req.lora_path for req in batch.reqs], + lora_paths=batch.lora_paths, + sampling_info=batch.sampling_info, ) + # Init position information if ret.forward_mode.is_decode(): ret.positions = (ret.seq_lens - 1).to(torch.int64) else: ret.positions = torch.tensor( np.concatenate( [ - np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids)) - for i, 
req in enumerate(batch.reqs) + np.arange(prefix_len, prefix_len + extend_len) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) ], axis=0, ), - device="cuda", + device=device, ).to(torch.int64) - ret.image_inputs = [r.image_inputs for r in batch.reqs] - ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda") - ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + ret.image_inputs = batch.image_inputs + ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) + ret.extend_prefix_lens = torch.tensor( + batch.extend_prefix_lens, device=device + ) ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens) ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0) - ret.extend_seq_lens_cpu = batch.extend_lens_cpu - ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu + ret.extend_seq_lens_cpu = batch.extend_seq_lens + ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens + + # Init attention information + ret.req_to_token_pool = model_runner.req_to_token_pool + ret.token_to_kv_pool = model_runner.token_to_kv_pool + ret.attn_backend = model_runner.attn_backend + model_runner.attn_backend.init_forward_metadata(ret) + + # Init lora information + if model_runner.server_args.lora_paths is not None: + model_runner.lora_manager.prepare_lora_batch(ret) return ret diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7cee4c22a..d4687a0a5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -21,7 +21,7 @@ import importlib.resources import logging import pkgutil from functools import lru_cache -from typing import Optional, Tuple, Type +from typing import Optional, Type import torch import torch.nn as nn @@ -38,11 +38,12 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.constrained import disable_cache from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.lora.lora_manager import LoRAManager -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, @@ -52,6 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + enable_show_time_cost, get_available_gpu_memory, is_generation_model, is_multimodal_model, @@ -102,6 +104,12 @@ class ModelRunner: server_args.chunked_prefill_size = None server_args.mem_fraction_static *= 0.95 + # Global vars + if server_args.show_time_cost: + enable_show_time_cost() + if server_args.disable_disk_cache: + disable_cache() + global_server_args_dict.update( { "attention_backend": server_args.attention_backend, @@ -491,16 +499,6 @@ class ModelRunner: ) def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: - # Attach attention information - forward_batch.req_to_token_pool = self.req_to_token_pool - forward_batch.token_to_kv_pool = 
self.token_to_kv_pool - forward_batch.attn_backend = self.attn_backend - forward_batch.attn_backend.init_forward_metadata(forward_batch) - - # Attach lora information - if self.server_args.lora_paths is not None: - self.lora_manager.prepare_lora_batch(forward_batch) - if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) elif forward_batch.forward_mode.is_extend(): @@ -508,16 +506,27 @@ class ModelRunner: else: raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}") - def _apply_logits_bias( - self, logits: torch.Tensor, sampling_info: SamplingBatchInfo - ): + def sample( + self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch + ) -> torch.Tensor: + # Put CPU-heavy tasks here. They will be overlapped with the forward pass. + sampling_info = forward_batch.sampling_info + sampling_info.update_regex_vocab_mask() + sampling_info.update_penalties() + logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) + + # Sample the next tokens. + next_token_ids = self.sampler(logits, sampling_info) + return next_token_ids + + def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): # Apply logit_bias if sampling_info.logit_bias is not None: logits.add_(sampling_info.logit_bias) # min-token, presence, frequency if sampling_info.linear_penalties is not None: - logits += sampling_info.linear_penalties + logits.add_(sampling_info.linear_penalties) # repetition if sampling_info.scaling_penalties is not None: @@ -533,20 +542,6 @@ class ModelRunner: return logits - def sample( - self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch - ) -> torch.Tensor: - # Put CPU-heavy tasks here. They will be overlapped with the forward pass. - batch.sampling_info.update_regex_vocab_mask(batch) - batch.sampling_info.update_penalties() - logits = self._apply_logits_bias( - logits_output.next_token_logits, batch.sampling_info - ) - - # Sample the next tokens. - next_token_ids = self.sampler(logits, batch.sampling_info) - return next_token_ids - @lru_cache() def import_model_classes(): diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 5af692868..7d4f39e68 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List import torch import sglang.srt.sampling.penaltylib as penaltylib +from sglang.srt.constrained import RegexGuide if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -22,13 +23,17 @@ class SamplingBatchInfo: top_ks: torch.Tensor = None min_ps: torch.Tensor = None - # Dispatch in CUDA graph - need_min_p_sampling: bool = False - # Bias Tensors logit_bias: torch.Tensor = None vocab_mask: torch.Tensor = None + # FSM states + regex_fsms: List[RegexGuide] = None + regex_fsm_states: List[int] = None + + # Dispatch in CUDA graph + need_min_p_sampling: bool = False + # Penalizer penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None linear_penalties: torch.Tensor = None @@ -54,6 +59,8 @@ class SamplingBatchInfo: [r.sampling_params.min_p for r in reqs], dtype=torch.float ) + ret.regex_fsms = [r.regex_fsm for r in reqs] + # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) # Each penalizers will do nothing if they evaluate themselves as not required by looking at @@ -102,24 +109,22 @@ class SamplingBatchInfo: ) self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self, batch: ScheduleBatch): - has_regex = any(req.regex_fsm is not None for req in batch.reqs) - + def update_regex_vocab_mask(self): # Reset the vocab mask self.vocab_mask = None - if has_regex: + if any(regex_fsm is not None for regex_fsm in self.regex_fsms): self.vocab_mask = torch.zeros( - batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda" + len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda" ) - for i, req in enumerate(batch.reqs): - if req.regex_fsm is not None: + for i, regex_fsm in enumerate(self.regex_fsms): + if regex_fsm is not None: self.vocab_mask[i].fill_(1) self.vocab_mask[i][ - req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens + regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens ] = 0 - def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor): + def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) for item in [ @@ -129,9 +134,11 @@ class SamplingBatchInfo: "min_ps", "logit_bias", ]: - self_val = getattr(self, item, None) - if self_val is not None: # logit_bias can be None - setattr(self, item, self_val[new_indices]) + value = getattr(self, item, None) + if value is not None: # logit_bias can be None + setattr(self, item, value[new_indices]) + + self.regex_fsms = [self.regex_fsms[i] for i in new_indices] @staticmethod def merge_bias_tensor( @@ -153,7 +160,7 @@ class SamplingBatchInfo: return None - def merge(self, other: "SamplingBatchInfo"): + def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) for item in [ @@ -169,3 +176,5 @@ class SamplingBatchInfo: self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other) ) + + self.regex_fsms.extend(other.regex_fsms) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 986c90ac0..258ddc303 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -41,7 +41,6 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.constrained import disable_cache from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( @@ -72,8 +71,6 @@ from sglang.srt.utils import ( allocate_init_ports, assert_pkg_version, configure_logger, - enable_show_time_cost, - is_hip, kill_child_process, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, @@ -400,14 +397,6 @@ def _set_envs_and_config(server_args: ServerArgs): # Set ulimit set_ulimit() - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - # Fix triton bugs if server_args.tp_size * server_args.dp_size > 1: # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
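The net effect of this change is a three-stage handoff for batch data: ScheduleBatch (CPU-side scheduling state) -> ModelWorkerBatch (the per-forward subset handed to the TP worker) -> ForwardBatch (GPU tensors plus attention/LoRA metadata). The sketch below restates the extend() helper from bench_latency.py in this patch as a minimal example of the new flow; it assumes a ScheduleBatch has already been constructed with its requests and memory pools (that setup is abbreviated in the hunk above), and the wrapper name extend_step is hypothetical.

    from sglang.srt.model_executor.forward_batch_info import ForwardBatch

    def extend_step(batch, model_runner):
        # Illustrative sketch, not part of the patch; mirrors extend() in
        # bench_latency.py above. The name extend_step is hypothetical.

        # ScheduleBatch: CPU scheduling state, owned by scheduler.py::Scheduler.
        batch.prepare_for_extend(model_runner.model_config.vocab_size)

        # ModelWorkerBatch: the subset of fields tp_worker.py::TpModelWorker needs
        # for one forward pass (input ids, seq lens, cache locations, sampling info).
        model_worker_batch = batch.get_model_worker_batch()

        # ForwardBatch: GPU tensors and attention/LoRA metadata, built by
        # model_runner.py::ModelRunner via ForwardBatch.init_new().
        forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)

        logits_output = model_runner.forward(forward_batch)
        next_token_ids = model_runner.sample(logits_output, forward_batch)
        return next_token_ids.tolist(), logits_output.next_token_logits

The decode path follows the same handoff, with batch.prepare_for_decode(input_token_ids) in place of prepare_for_extend(), as shown in decode() above.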