From 87e8c090e910c20f9619808179d6e38ba10e2034 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 6 Aug 2024 20:50:32 -0700 Subject: [PATCH] Organize code (rename, movement) (#953) --- python/sglang/bench_latency.py | 5 +- python/sglang/srt/layers/logits_processor.py | 2 +- python/sglang/srt/layers/radix_attention.py | 7 +- python/sglang/srt/managers/schedule_batch.py | 237 +--------------- python/sglang/srt/managers/tp_worker.py | 20 +- .../srt/model_executor/cuda_graph_runner.py | 6 +- .../srt/model_executor/forward_batch_info.py | 256 ++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 16 +- python/sglang/srt/models/chatglm.py | 2 +- python/sglang/srt/models/commandr.py | 2 +- python/sglang/srt/models/dbrx.py | 2 +- python/sglang/srt/models/deepseek.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/gemma.py | 2 +- python/sglang/srt/models/gemma2.py | 2 +- python/sglang/srt/models/gpt_bigcode.py | 2 +- python/sglang/srt/models/grok.py | 2 +- python/sglang/srt/models/internlm2.py | 2 +- python/sglang/srt/models/llama2.py | 2 +- .../sglang/srt/models/llama_classification.py | 2 +- python/sglang/srt/models/llava.py | 3 +- python/sglang/srt/models/llavavid.py | 3 +- python/sglang/srt/models/minicpm.py | 2 +- python/sglang/srt/models/mixtral.py | 2 +- python/sglang/srt/models/mixtral_quant.py | 2 +- python/sglang/srt/models/qwen.py | 2 +- python/sglang/srt/models/qwen2.py | 2 +- python/sglang/srt/models/qwen2_moe.py | 2 +- python/sglang/srt/models/stablelm.py | 2 +- 29 files changed, 304 insertions(+), 289 deletions(-) create mode 100644 python/sglang/srt/model_executor/forward_batch_info.py diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 3000b0bb9..ffd6b24f0 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -50,8 +50,9 @@ import torch import torch.distributed as dist from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -188,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): def extend(reqs, model_runner): - batch = Batch.init_new( + batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index c50f61f37..5584d01ad 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -25,7 +25,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) -from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata @dataclasses.dataclass diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 784f0df34..2afd329f9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -22,11 +22,8 @@ from torch import nn from sglang.global_config import global_config from 
sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.model_executor.model_runner import ( - ForwardMode, - InputMetadata, - global_server_args_dict, -) +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.model_runner import global_server_args_dict class RadixAttention(nn.Module): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5ebf12e30..4e9b9eb2f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -18,7 +18,6 @@ limitations under the License. import logging import warnings from dataclasses import dataclass -from enum import IntEnum, auto from typing import List, Union import numpy as np @@ -46,15 +45,6 @@ global_server_args_dict = { logger = logging.getLogger(__name__) -class ForwardMode(IntEnum): - # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. - PREFILL = auto() - # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt). - EXTEND = auto() - # Decode one token. - DECODE = auto() - - class BaseFinishReason: def __init__(self, is_error: bool = False): self.is_error = is_error @@ -284,7 +274,7 @@ class Req: @dataclass -class Batch: +class ScheduleBatch: """Store all inforamtion of a batch.""" # Request, memory pool, and cache @@ -673,7 +663,7 @@ class Batch: if self_val is not None: # logit_bias can be None setattr(self, item, self_val[new_indices]) - def merge(self, other: "Batch"): + def merge(self, other: "ScheduleBatch"): self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( @@ -770,229 +760,6 @@ class Batch: return batch_next_token_ids -@dataclass -class InputMetadata: - """Store all inforamtion of a forward pass.""" - - forward_mode: ForwardMode - batch_size: int - total_num_tokens: int - req_pool_indices: torch.Tensor - seq_lens: torch.Tensor - positions: torch.Tensor - req_to_token_pool: ReqToTokenPool - token_to_kv_pool: BaseTokenToKVPool - - # For extend - extend_seq_lens: torch.Tensor - extend_start_loc: torch.Tensor - extend_no_prefix: bool - - # Output location of the KV cache - out_cache_loc: torch.Tensor = None - - # Output options - return_logprob: bool = False - top_logprobs_nums: List[int] = None - - # Trition attention backend - triton_max_seq_len: int = 0 - triton_max_extend_len: int = 0 - triton_start_loc: torch.Tensor = None - triton_prefix_lens: torch.Tensor = None - - # FlashInfer attention backend - flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None - flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None - flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None - flashinfer_use_ragged: bool = False - - @classmethod - def create( - cls, - model_runner, - forward_mode, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - top_logprobs_nums=None, - return_logprob=False, - skip_flashinfer_init=False, - ): - flashinfer_use_ragged = False - if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer: - if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096: - flashinfer_use_ragged = True - init_flashinfer_args( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged, - ) - 
- batch_size = len(req_pool_indices) - - if forward_mode == ForwardMode.DECODE: - positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64) - extend_seq_lens = extend_start_loc = extend_no_prefix = None - if not model_runner.server_args.disable_flashinfer: - # This variable is not needed in this case, - # we do not compute it to make it compatbile with cuda graph. - total_num_tokens = None - else: - total_num_tokens = int(torch.sum(seq_lens)) - else: - seq_lens_cpu = seq_lens.cpu().numpy() - prefix_lens_cpu = prefix_lens.cpu().numpy() - position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() - positions = torch.tensor( - np.concatenate( - [ - np.arange( - prefix_lens_cpu[i] + position_ids_offsets_cpu[i], - seq_lens_cpu[i] + position_ids_offsets_cpu[i], - ) - for i in range(batch_size) - ], - axis=0, - ), - device="cuda", - ) - extend_seq_lens = seq_lens - prefix_lens - extend_start_loc = torch.zeros_like(seq_lens) - extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) - extend_no_prefix = torch.all(prefix_lens == 0) - total_num_tokens = int(torch.sum(seq_lens)) - - ret = cls( - forward_mode=forward_mode, - batch_size=batch_size, - total_num_tokens=total_num_tokens, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - positions=positions, - req_to_token_pool=model_runner.req_to_token_pool, - token_to_kv_pool=model_runner.token_to_kv_pool, - out_cache_loc=out_cache_loc, - extend_seq_lens=extend_seq_lens, - extend_start_loc=extend_start_loc, - extend_no_prefix=extend_no_prefix, - return_logprob=return_logprob, - top_logprobs_nums=top_logprobs_nums, - flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, - flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged, - flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged=flashinfer_use_ragged, - ) - - if model_runner.server_args.disable_flashinfer: - ( - ret.triton_max_seq_len, - ret.triton_max_extend_len, - ret.triton_start_loc, - ret.triton_prefix_lens, - ) = init_triton_args(forward_mode, seq_lens, prefix_lens) - - return ret - - -def init_flashinfer_args( - forward_mode, - model_runner, - req_pool_indices, - seq_lens, - prefix_lens, - flashinfer_decode_wrapper, - flashinfer_use_ragged=False, -): - """Init auxiliary variables for FlashInfer attention backend.""" - num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size - num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) - head_dim = model_runner.model_config.head_dim - batch_size = len(req_pool_indices) - total_num_tokens = int(torch.sum(seq_lens)) - - if flashinfer_use_ragged: - paged_kernel_lens = prefix_lens - else: - paged_kernel_lens = seq_lens - - kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - req_pool_indices_cpu = req_pool_indices.cpu().numpy() - paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() - kv_indices = torch.cat( - [ - model_runner.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] - ] - for i in range(batch_size) - ], - dim=0, - ).contiguous() - kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - - if forward_mode == ForwardMode.DECODE: - flashinfer_decode_wrapper.end_forward() - flashinfer_decode_wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) - else: - # extend part - 
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) - - if flashinfer_use_ragged: - model_runner.flashinfer_prefill_wrapper_ragged.end_forward() - model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( - qo_indptr, - qo_indptr, - num_qo_heads, - num_kv_heads, - head_dim, - ) - - # cached part - model_runner.flashinfer_prefill_wrapper_paged.end_forward() - model_runner.flashinfer_prefill_wrapper_paged.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - 1, - ) - - -def init_triton_args(forward_mode, seq_lens, prefix_lens): - """Init auxiliary variables for triton attention backend.""" - batch_size = len(seq_lens) - max_seq_len = int(torch.max(seq_lens)) - start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) - - if forward_mode == ForwardMode.DECODE: - max_extend_len = None - else: - extend_seq_lens = seq_lens - prefix_lens - max_extend_len = int(torch.max(extend_seq_lens)) - - return max_seq_len, max_extend_len, start_loc, prefix_lens - - def top_k_top_p_sampling_from_probs_torch( probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor ): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index d7dedc29d..54d6805d8 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, BaseFinishReason, - Batch, - ForwardMode, Req, + ScheduleBatch, ) from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -172,7 +172,7 @@ class ModelTpServer: # Init running status self.waiting_queue: List[Req] = [] - self.running_batch: Batch = None + self.running_batch: ScheduleBatch = None self.out_pyobjs = [] self.decode_forward_ct = 0 self.stream_interval = server_args.stream_interval @@ -353,7 +353,7 @@ class ModelTpServer: ) self.waiting_queue.append(req) - def get_new_prefill_batch(self) -> Optional[Batch]: + def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: # TODO(lsyin): organize this function running_bs = ( len(self.running_batch.reqs) if self.running_batch is not None else 0 @@ -526,7 +526,7 @@ class ModelTpServer: ) # Return the new batch - new_batch = Batch.init_new( + new_batch = ScheduleBatch.init_new( can_run_list, self.req_to_token_pool, self.token_to_kv_pool, @@ -535,7 +535,7 @@ class ModelTpServer: self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] return new_batch - def forward_prefill_batch(self, batch: Batch): + def forward_prefill_batch(self, batch: ScheduleBatch): # Build batch tensors batch.prepare_for_extend( self.model_config.vocab_size, self.int_token_logit_bias @@ -624,7 +624,7 @@ class ModelTpServer: ) req.output_top_logprobs.append(output.output_top_logprobs[i]) - def cache_filled_batch(self, batch: Batch): + def cache_filled_batch(self, batch: ScheduleBatch): req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, 
new_last_node = self.tree_cache.cache_req( @@ -641,7 +641,7 @@ class ModelTpServer: # inflight request would get a new req idx self.req_to_token_pool.free(int(req_pool_indices_cpu[i])) - def forward_decode_batch(self, batch: Batch): + def forward_decode_batch(self, batch: ScheduleBatch): # Check if decode out of memory if not batch.check_decode_mem(): old_ratio = self.new_token_ratio @@ -700,7 +700,7 @@ class ModelTpServer: self.handle_finished_requests(batch) - def handle_finished_requests(self, batch: Batch): + def handle_finished_requests(self, batch: ScheduleBatch): output_rids = [] output_vids = [] decoded_texts = [] @@ -800,7 +800,7 @@ class ModelTpServer: else: batch.reqs = [] - def filter_out_inflight(self, batch: Batch): + def filter_out_inflight(self, batch: ScheduleBatch): # TODO(lsyin): reduce the overhead, make a special version for this if self.current_inflight_req is None: return diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 458395e73..e81d3b10e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import ( LogitsMetadata, LogitsProcessor, ) -from sglang.srt.managers.schedule_batch import ( - Batch, +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, init_flashinfer_args, @@ -202,7 +202,7 @@ class CudaGraphRunner: self.graph_memory_pool = graph.pool() return graph, None, out, flashinfer_decode_wrapper - def replay(self, batch: Batch): + def replay(self, batch: ScheduleBatch): assert batch.out_cache_loc is not None raw_bs = len(batch.reqs) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py new file mode 100644 index 000000000..1b91bcb91 --- /dev/null +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -0,0 +1,256 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""ModelRunner runs the forward passes of the models.""" +from dataclasses import dataclass +from enum import IntEnum, auto +from typing import List + +import numpy as np +import torch + +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool + + +class ForwardMode(IntEnum): + # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. + PREFILL = auto() + # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt). + EXTEND = auto() + # Decode one token. 
+ DECODE = auto() + + +@dataclass +class InputMetadata: + """Store all inforamtion of a forward pass.""" + + forward_mode: ForwardMode + batch_size: int + total_num_tokens: int + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + positions: torch.Tensor + req_to_token_pool: ReqToTokenPool + token_to_kv_pool: BaseTokenToKVPool + + # For extend + extend_seq_lens: torch.Tensor + extend_start_loc: torch.Tensor + extend_no_prefix: bool + + # Output location of the KV cache + out_cache_loc: torch.Tensor = None + + # Output options + return_logprob: bool = False + top_logprobs_nums: List[int] = None + + # Trition attention backend + triton_max_seq_len: int = 0 + triton_max_extend_len: int = 0 + triton_start_loc: torch.Tensor = None + triton_prefix_lens: torch.Tensor = None + + # FlashInfer attention backend + flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None + flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None + flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None + flashinfer_use_ragged: bool = False + + @classmethod + def create( + cls, + model_runner, + forward_mode, + req_pool_indices, + seq_lens, + prefix_lens, + position_ids_offsets, + out_cache_loc, + top_logprobs_nums=None, + return_logprob=False, + skip_flashinfer_init=False, + ): + flashinfer_use_ragged = False + if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer: + if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096: + flashinfer_use_ragged = True + init_flashinfer_args( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + model_runner.flashinfer_decode_wrapper, + flashinfer_use_ragged, + ) + + batch_size = len(req_pool_indices) + + if forward_mode == ForwardMode.DECODE: + positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64) + extend_seq_lens = extend_start_loc = extend_no_prefix = None + if not model_runner.server_args.disable_flashinfer: + # This variable is not needed in this case, + # we do not compute it to make it compatbile with cuda graph. 
+ total_num_tokens = None + else: + total_num_tokens = int(torch.sum(seq_lens)) + else: + seq_lens_cpu = seq_lens.cpu().numpy() + prefix_lens_cpu = prefix_lens.cpu().numpy() + position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() + positions = torch.tensor( + np.concatenate( + [ + np.arange( + prefix_lens_cpu[i] + position_ids_offsets_cpu[i], + seq_lens_cpu[i] + position_ids_offsets_cpu[i], + ) + for i in range(batch_size) + ], + axis=0, + ), + device="cuda", + ) + extend_seq_lens = seq_lens - prefix_lens + extend_start_loc = torch.zeros_like(seq_lens) + extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) + extend_no_prefix = torch.all(prefix_lens == 0) + total_num_tokens = int(torch.sum(seq_lens)) + + ret = cls( + forward_mode=forward_mode, + batch_size=batch_size, + total_num_tokens=total_num_tokens, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + positions=positions, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + extend_seq_lens=extend_seq_lens, + extend_start_loc=extend_start_loc, + extend_no_prefix=extend_no_prefix, + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, + flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, + flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged, + flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, + flashinfer_use_ragged=flashinfer_use_ragged, + ) + + if model_runner.server_args.disable_flashinfer: + ( + ret.triton_max_seq_len, + ret.triton_max_extend_len, + ret.triton_start_loc, + ret.triton_prefix_lens, + ) = init_triton_args(forward_mode, seq_lens, prefix_lens) + + return ret + + +def init_flashinfer_args( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + flashinfer_decode_wrapper, + flashinfer_use_ragged=False, +): + """Init auxiliary variables for FlashInfer attention backend.""" + num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size + num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) + head_dim = model_runner.model_config.head_dim + batch_size = len(req_pool_indices) + total_num_tokens = int(torch.sum(seq_lens)) + + if flashinfer_use_ragged: + paged_kernel_lens = prefix_lens + else: + paged_kernel_lens = seq_lens + + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices = torch.cat( + [ + model_runner.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] + ] + for i in range(batch_size) + ], + dim=0, + ).contiguous() + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + if forward_mode == ForwardMode.DECODE: + flashinfer_decode_wrapper.end_forward() + flashinfer_decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + else: + # extend part + qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + if flashinfer_use_ragged: + model_runner.flashinfer_prefill_wrapper_ragged.end_forward() + model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + + # 
cached part + model_runner.flashinfer_prefill_wrapper_paged.end_forward() + model_runner.flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + + +def init_triton_args(forward_mode, seq_lens, prefix_lens): + """Init auxiliary variables for triton attention backend.""" + batch_size = len(seq_lens) + max_seq_len = int(torch.max(seq_lens)) + start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) + + if forward_mode == ForwardMode.DECODE: + max_extend_len = None + else: + extend_seq_lens = seq_lens - prefix_lens + max_extend_len = int(torch.max(extend_seq_lens)) + + return max_seq_len, max_extend_len, start_loc, prefix_lens diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5aa6de550..6426c8e69 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -41,18 +41,14 @@ from vllm.distributed import ( from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config -from sglang.srt.managers.schedule_batch import ( - Batch, - ForwardMode, - InputMetadata, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, ) from sglang.srt.model_config import AttentionArch +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, @@ -350,7 +346,7 @@ class ModelRunner: ) @torch.inference_mode() - def forward_decode(self, batch: Batch): + def forward_decode(self, batch: ScheduleBatch): if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): return self.cuda_graph_runner.replay(batch) @@ -370,7 +366,7 @@ class ModelRunner: ) @torch.inference_mode() - def forward_extend(self, batch: Batch): + def forward_extend(self, batch: ScheduleBatch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.EXTEND, @@ -387,7 +383,7 @@ class ModelRunner: ) @torch.inference_mode() - def forward_extend_multi_modal(self, batch: Batch): + def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.EXTEND, @@ -408,7 +404,7 @@ class ModelRunner: batch.image_offsets, ) - def forward(self, batch: Batch, forward_mode: ForwardMode): + def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: return self.forward_extend_multi_modal(batch) elif forward_mode == ForwardMode.DECODE: diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 4589a14ac..d2ad02fbf 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 
671746bf7..1259285c4 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 1d0f40bd3..39ac4aefa 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class DbrxRouter(nn.Module): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 09481e71b..98dcfd28d 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.schedule_batch import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class DeepseekMLP(nn.Module): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bc31d89ae..739562730 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 843bc5d28..ce3973115 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class GemmaMLP(nn.Module): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 4c77e0c69..539554fa8 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class GemmaRMSNorm(CustomOp): diff --git a/python/sglang/srt/models/gpt_bigcode.py 
b/python/sglang/srt/models/gpt_bigcode.py index eee7f6483..9a9e2aec3 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.schedule_batch import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class GPTBigCodeAttention(nn.Module): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index b989c4e79..38297b7d6 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -52,7 +52,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata use_fused = True diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 35f81f8a9..394d00504 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class InternLM2MLP(nn.Module): diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 9fcbb794b..7a6d570a4 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class LlamaMLP(nn.Module): diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 3ffb256dd..02224971d 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitProcessorOutput -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama2 import LlamaModel diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index f89a9b618..a885a6e59 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -32,13 +32,12 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.schedule_batch import ForwardMode from sglang.srt.mm_utils import ( 
get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 3f88d41a1..8b81251d6 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,13 +26,12 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.schedule_batch import ForwardMode from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index ab2a08325..bf572855e 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MiniCPMMLP(nn.Module): diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index a7d45d455..63053ac50 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -50,7 +50,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MixtralMoE(nn.Module): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index d643db33f..07caf3833 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class MixtralMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 52edd28bc..ffc512b1c 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class QWenMLP(nn.Module): diff --git 
a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 2df91814e..dec962bf0 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 7475d8f62..f3105ad45 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class Qwen2MoeMLP(nn.Module): diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 76f40437a..aeaa46ab1 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata class StablelmMLP(nn.Module):
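Note for downstream callers: this change is a rename plus a module move. Batch becomes ScheduleBatch and stays in python/sglang/srt/managers/schedule_batch.py, while ForwardMode, InputMetadata, init_flashinfer_args, and init_triton_args move to the new python/sglang/srt/model_executor/forward_batch_info.py; global_server_args_dict is still defined in schedule_batch.py and re-exported through model_runner.py. The sketch below shows what an updated call site can look like after the patch; the wrapper function and the batch attributes it reads are assumptions condensed from the hunks above, not verbatim code from this commit.

# Sketch of a call site after #953 (illustrative; helper name and attribute
# accesses are assumptions, only the import paths are taken from the patch).
from sglang.srt.managers.schedule_batch import ScheduleBatch  # formerly Batch
from sglang.srt.model_executor.forward_batch_info import (  # new module in this patch
    ForwardMode,
    InputMetadata,
)


def run_extend_step(model_runner, batch: ScheduleBatch):
    """Illustrative condensation of an extend forward pass; the keyword
    arguments follow InputMetadata.create's signature in the new file."""
    input_metadata = InputMetadata.create(
        model_runner,
        forward_mode=ForwardMode.EXTEND,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        top_logprobs_nums=batch.top_logprobs_nums,
        return_logprob=batch.return_logprob,
    )
    # Only the import paths and the batch type annotation change in this patch;
    # the model forward call itself is untouched.
    return model_runner.model.forward(
        batch.input_ids, input_metadata.positions, input_metadata
    )

One visible effect of the split is that the model definitions (chatglm.py, llama2.py, qwen2.py, and the other files touched above) now import InputMetadata from forward_batch_info instead of from model_runner, so they no longer depend on the model runner module just for that type.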