From 1ac304eeb483c4ce3435dd1673426ddd7271d02c Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 8 Aug 2024 01:11:22 -0700 Subject: [PATCH] Adjust `InputeMetadata` and `ScheduleBatch` (#981) --- python/sglang/srt/managers/schedule_batch.py | 62 ++-- .../srt/model_executor/cuda_graph_runner.py | 18 +- .../srt/model_executor/forward_batch_info.py | 272 +++++++++++------- .../sglang/srt/model_executor/model_runner.py | 43 +-- 4 files changed, 203 insertions(+), 192 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 66ff020cc..5f026812a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -307,7 +307,6 @@ class ScheduleBatch: input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None seq_lens: torch.Tensor = None - prefix_lens: torch.Tensor = None position_ids_offsets: torch.Tensor = None out_cache_loc: torch.Tensor = None extend_num_tokens: int = None @@ -316,11 +315,6 @@ class ScheduleBatch: return_logprob: bool = False top_logprobs_nums: List[int] = None - # For multimodal - pixel_values: List[torch.Tensor] = None - image_sizes: List[List[int]] = None - image_offsets: List[int] = None - # Batched sampling params temperatures: torch.Tensor = None top_ps: torch.Tensor = None @@ -412,59 +406,40 @@ class ScheduleBatch: self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): - device = "cuda" bs = self.batch_size() reqs = self.reqs input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] - prefix_indices = [r.prefix_indices for r in reqs] - - # Handle prefix - extend_lens = [] - prefix_lens = [] + extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [] - req_pool_indices_cpu = self.alloc_req_slots(bs) - - for i, req in enumerate(reqs): - req.req_pool_idx = req_pool_indices_cpu[i] - extend_lens.append(len(input_ids[i])) - - if len(prefix_indices[i]) == 0: - prefix_lens.append(0) - else: - prefix_lens.append(len(prefix_indices[i])) - self.req_to_token_pool.req_to_token[req.req_pool_idx][ - : len(prefix_indices[i]) - ] = prefix_indices[i] - - seq_lens.append(prefix_lens[-1] + extend_lens[-1]) - # Allocate memory - seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens) - extend_num_tokens = seq_lens.sum() - prefix_lens.sum() + req_pool_indices_cpu = self.alloc_req_slots(bs) out_cache_loc = self.alloc_token_slots(extend_num_tokens) pt = 0 for i, req in enumerate(reqs): - self.req_to_token_pool.req_to_token[req.req_pool_idx][ - prefix_lens[i] : prefix_lens[i] + extend_lens[i] - ] = out_cache_loc[pt : pt + extend_lens[i]] - pt += extend_lens[i] + req.req_pool_idx = req_pool_indices_cpu[i] + pre_len, seq_len = len(req.prefix_indices), len(req.input_ids) + ext_len = seq_len - pre_len + seq_lens.append(seq_len) + + if pre_len > 0: + self.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + out_cache_loc[pt : pt + ext_len] + ) + pt += ext_len # Set fields with torch.device("cuda"): self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) self.req_pool_indices = torch.tensor(req_pool_indices_cpu) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) - self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32) + self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64) - self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - (r.image_offset - p_len) if r.image_offset is not None else 0 - for r, p_len in zip(reqs, prefix_lens) - ] - self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) self.extend_num_tokens = extend_num_tokens self.out_cache_loc = out_cache_loc self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] @@ -642,7 +617,6 @@ class ScheduleBatch: ] self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) - self.prefix_lens = None # Alloc mem bs = self.batch_size() @@ -667,7 +641,6 @@ class ScheduleBatch: self.seq_lens = self.seq_lens[new_indices] self.input_ids = None self.req_pool_indices = self.req_pool_indices[new_indices] - self.prefix_lens = None self.position_ids_offsets = self.position_ids_offsets[new_indices] self.out_cache_loc = None self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] @@ -692,7 +665,6 @@ class ScheduleBatch: [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) - self.prefix_lens = None self.position_ids_offsets = torch.concat( [self.position_ids_offsets, other.position_ids_offsets] ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e81d3b10e..ae6fe83c5 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, - init_flashinfer_args, + update_flashinfer_indices, ) from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -165,7 +165,7 @@ class CudaGraphRunner: paged_kv_indices_buffer=self.flashinfer_kv_indices, paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], ) - init_flashinfer_args( + update_flashinfer_indices( ForwardMode.DECODE, self.model_runner, req_pool_indices, @@ -176,19 +176,19 @@ class CudaGraphRunner: # Run and capture def run_once(): - input_metadata = InputMetadata.create( - self.model_runner, + input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, + batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, - prefix_lens=None, - position_ids_offsets=position_ids_offsets, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, - skip_flashinfer_init=True, + positions=(seq_lens - 1).to(torch.int64), + flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) - input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper return forward(input_ids, input_metadata.positions, input_metadata) @@ -222,7 +222,7 @@ class CudaGraphRunner: self.out_cache_loc[:raw_bs] = batch.out_cache_loc # FlashInfer inputs - init_flashinfer_args( + update_flashinfer_indices( ForwardMode.DECODE, self.model_runner, self.req_pool_indices[:bs], diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 1b91bcb91..686e7ed86 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -16,13 +16,17 @@ limitations under the License. """ModelRunner runs the forward passes of the models.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import List +from typing import TYPE_CHECKING, List import numpy as np import torch +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. @@ -39,25 +43,33 @@ class InputMetadata: forward_mode: ForwardMode batch_size: int - total_num_tokens: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor - positions: torch.Tensor req_to_token_pool: ReqToTokenPool token_to_kv_pool: BaseTokenToKVPool - # For extend - extend_seq_lens: torch.Tensor - extend_start_loc: torch.Tensor - extend_no_prefix: bool - # Output location of the KV cache - out_cache_loc: torch.Tensor = None + out_cache_loc: torch.Tensor + + total_num_tokens: int = None + + # Position information + positions: torch.Tensor = None + + # For extend + extend_seq_lens: torch.Tensor = None + extend_start_loc: torch.Tensor = None + extend_no_prefix: bool = None # Output options return_logprob: bool = False top_logprobs_nums: List[int] = None + # For multimodal + pixel_values: List[torch.Tensor] = None + image_sizes: List[List[int]] = None + image_offsets: List[int] = None + # Trition attention backend triton_max_seq_len: int = 0 triton_max_extend_len: int = 0 @@ -70,107 +82,170 @@ class InputMetadata: flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None flashinfer_use_ragged: bool = False - @classmethod - def create( - cls, - model_runner, - forward_mode, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - top_logprobs_nums=None, - return_logprob=False, - skip_flashinfer_init=False, - ): - flashinfer_use_ragged = False - if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer: - if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096: - flashinfer_use_ragged = True - init_flashinfer_args( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged, + def init_multimuldal_info(self, batch: ScheduleBatch): + reqs = batch.reqs + self.pixel_values = [r.pixel_values for r in reqs] + self.image_sizes = [r.image_size for r in reqs] + self.image_offsets = [ + ( + (r.image_offset - len(r.prefix_indices)) + if r.image_offset is not None + else 0 ) + for r in reqs + ] - batch_size = len(req_pool_indices) + def compute_positions(self, batch: ScheduleBatch): + position_ids_offsets = batch.position_ids_offsets - if forward_mode == ForwardMode.DECODE: - positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64) - extend_seq_lens = extend_start_loc = extend_no_prefix = None - if not model_runner.server_args.disable_flashinfer: - # This variable is not needed in this case, - # we do not compute it to make it compatbile with cuda graph. - total_num_tokens = None + if self.forward_mode == ForwardMode.DECODE: + if True: + self.positions = self.seq_lens - 1 else: - total_num_tokens = int(torch.sum(seq_lens)) + # Deprecated + self.positions = (self.seq_lens - 1) + position_ids_offsets else: - seq_lens_cpu = seq_lens.cpu().numpy() - prefix_lens_cpu = prefix_lens.cpu().numpy() - position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() - positions = torch.tensor( - np.concatenate( - [ - np.arange( - prefix_lens_cpu[i] + position_ids_offsets_cpu[i], - seq_lens_cpu[i] + position_ids_offsets_cpu[i], - ) - for i in range(batch_size) - ], - axis=0, - ), - device="cuda", - ) - extend_seq_lens = seq_lens - prefix_lens - extend_start_loc = torch.zeros_like(seq_lens) - extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) - extend_no_prefix = torch.all(prefix_lens == 0) - total_num_tokens = int(torch.sum(seq_lens)) + if True: + self.positions = torch.tensor( + np.concatenate( + [ + np.arange(len(req.prefix_indices), len(req.input_ids)) + for req in batch.reqs + ], + axis=0, + ), + device="cuda", + ) + else: + # Deprecated + position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() + self.positions = torch.tensor( + np.concatenate( + [ + np.arange( + len(req.prefix_indices) + position_ids_offsets_cpu[i], + len(req.input_ids) + position_ids_offsets_cpu[i], + ) + for i, req in enumerate(batch.reqs) + ], + axis=0, + ), + device="cuda", + ) + # Positions should be in long type + self.positions = self.positions.to(torch.int64) + + def compute_extend_infos(self, batch: ScheduleBatch): + if self.forward_mode == ForwardMode.DECODE: + self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None + else: + prefix_lens_cpu = [ + len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs + ] + self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda") + self.extend_start_loc = torch.zeros_like(self.seq_lens) + self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) + self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu) + + def init_total_num_tokens(self, batch: ScheduleBatch): + self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs) + + @classmethod + def from_schedule_batch( + cls, + model_runner: "ModelRunner", + batch: ScheduleBatch, + forward_mode: ForwardMode, + ): ret = cls( forward_mode=forward_mode, - batch_size=batch_size, - total_num_tokens=total_num_tokens, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - positions=positions, + batch_size=batch.batch_size(), + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, - out_cache_loc=out_cache_loc, - extend_seq_lens=extend_seq_lens, - extend_start_loc=extend_start_loc, - extend_no_prefix=extend_no_prefix, - return_logprob=return_logprob, - top_logprobs_nums=top_logprobs_nums, - flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, - flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged, - flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged=flashinfer_use_ragged, + out_cache_loc=batch.out_cache_loc, + return_logprob=batch.return_logprob, + top_logprobs_nums=batch.top_logprobs_nums, ) + ret.compute_positions(batch) + + ret.compute_extend_infos(batch) + + ret.init_total_num_tokens(batch) + + if forward_mode != ForwardMode.DECODE: + ret.init_multimuldal_info(batch) + + prefix_lens = None + if forward_mode != ForwardMode.DECODE: + prefix_lens = torch.tensor( + [len(r.prefix_indices) for r in batch.reqs], device="cuda" + ) + if model_runner.server_args.disable_flashinfer: - ( - ret.triton_max_seq_len, - ret.triton_max_extend_len, - ret.triton_start_loc, - ret.triton_prefix_lens, - ) = init_triton_args(forward_mode, seq_lens, prefix_lens) + ret.init_triton_args(batch, prefix_lens) + + flashinfer_use_ragged = False + if not model_runner.server_args.disable_flashinfer: + if ( + forward_mode != ForwardMode.DECODE + and int(torch.sum(ret.seq_lens)) > 4096 + ): + flashinfer_use_ragged = True + ret.init_flashinfer_handlers( + model_runner, prefix_lens, flashinfer_use_ragged + ) return ret + def init_triton_args(self, batch: ScheduleBatch, prefix_lens): + """Init auxiliary variables for triton attention backend.""" + self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs) + self.triton_prefix_lens = prefix_lens + self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) + self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) -def init_flashinfer_args( + if self.forward_mode == ForwardMode.DECODE: + self.triton_max_extend_len = None + else: + extend_seq_lens = self.seq_lens - prefix_lens + self.triton_max_extend_len = int(torch.max(extend_seq_lens)) + + def init_flashinfer_handlers( + self, model_runner, prefix_lens, flashinfer_use_ragged + ): + update_flashinfer_indices( + self.forward_mode, + model_runner, + self.req_pool_indices, + self.seq_lens, + prefix_lens, + flashinfer_use_ragged=flashinfer_use_ragged, + ) + + ( + self.flashinfer_prefill_wrapper_ragged, + self.flashinfer_prefill_wrapper_paged, + self.flashinfer_decode_wrapper, + self.flashinfer_use_ragged, + ) = ( + model_runner.flashinfer_prefill_wrapper_ragged, + model_runner.flashinfer_prefill_wrapper_paged, + model_runner.flashinfer_decode_wrapper, + flashinfer_use_ragged, + ) + + +def update_flashinfer_indices( forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper, + flashinfer_decode_wrapper=None, flashinfer_use_ragged=False, ): """Init auxiliary variables for FlashInfer attention backend.""" @@ -178,7 +253,6 @@ def init_flashinfer_args( num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) head_dim = model_runner.model_config.head_dim batch_size = len(req_pool_indices) - total_num_tokens = int(torch.sum(seq_lens)) if flashinfer_use_ragged: paged_kernel_lens = prefix_lens @@ -201,6 +275,10 @@ def init_flashinfer_args( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") if forward_mode == ForwardMode.DECODE: + # CUDA graph uses different flashinfer_decode_wrapper + if flashinfer_decode_wrapper is None: + flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper + flashinfer_decode_wrapper.end_forward() flashinfer_decode_wrapper.begin_forward( kv_indptr, @@ -238,19 +316,3 @@ def init_flashinfer_args( head_dim, 1, ) - - -def init_triton_args(forward_mode, seq_lens, prefix_lens): - """Init auxiliary variables for triton attention backend.""" - batch_size = len(seq_lens) - max_seq_len = int(torch.max(seq_lens)) - start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) - - if forward_mode == ForwardMode.DECODE: - max_extend_len = None - else: - extend_seq_lens = seq_lens - prefix_lens - max_extend_len = int(torch.max(extend_seq_lens)) - - return max_seq_len, max_extend_len, start_loc, prefix_lens diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9a285b337..17ce5edf7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -350,33 +350,18 @@ class ModelRunner: if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): return self.cuda_graph_runner.replay(batch) - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.DECODE, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - top_logprobs_nums=batch.top_logprobs_nums, - return_logprob=batch.return_logprob, + input_metadata = InputMetadata.from_schedule_batch( + self, batch, ForwardMode.DECODE ) + return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata ) @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.EXTEND, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - top_logprobs_nums=batch.top_logprobs_nums, - return_logprob=batch.return_logprob, + input_metadata = InputMetadata.from_schedule_batch( + self, batch, forward_mode=ForwardMode.EXTEND ) return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -384,24 +369,16 @@ class ModelRunner: @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): - input_metadata = InputMetadata.create( - self, - forward_mode=ForwardMode.EXTEND, - req_pool_indices=batch.req_pool_indices, - seq_lens=batch.seq_lens, - prefix_lens=batch.prefix_lens, - position_ids_offsets=batch.position_ids_offsets, - out_cache_loc=batch.out_cache_loc, - return_logprob=batch.return_logprob, - top_logprobs_nums=batch.top_logprobs_nums, + input_metadata = InputMetadata.from_schedule_batch( + self, batch, forward_mode=ForwardMode.EXTEND ) return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata, - batch.pixel_values, - batch.image_sizes, - batch.image_offsets, + input_metadata.pixel_values, + input_metadata.image_sizes, + input_metadata.image_offsets, ) def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):