From 519e20cfda4aad594e32c86e844effdec753dcca Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 12 Jul 2024 12:28:09 -0700
Subject: [PATCH] Code clean up: Remove deprecated prefill, move InputMetadata
 to infer_batch.py (#609)

---
 python/sglang/srt/layers/radix_attention.py  |   7 +-
 .../srt/managers/controller/infer_batch.py   | 215 +++++++++++++++-
 .../srt/managers/controller/model_runner.py  | 242 +-----------------
 3 files changed, 219 insertions(+), 245 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index a2d96e9d2..eab16d536 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -8,6 +8,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
         self.scaling = scaling
         self.layer_id = layer_id
 
-        from sglang.srt.managers.controller.model_runner import global_server_args_dict
-
         if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.PREFILL:
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index ec4730061..793262b6f 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
 
 
@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):
 
 
 class Req:
+    """Store all information of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
@@ -74,7 +82,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental decode
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -93,9 +101,8 @@ class Req:
         self.sampling_params = None
         self.stream = False
 
-        self.tokenizer = None
-
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
         # Prefix info
@@ -252,6 +259,8 @@
 
 @dataclass
 class Batch:
+    """Store all information of a batch."""
+
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
         ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all information of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    max_seq_len: int
+    req_pool_indices: torch.Tensor
+    start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    prefix_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # for extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    max_extend_len: int = 0
+
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+
+    other_kv_index: torch.Tensor = None
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # for flashinfer
+    qo_indptr: torch.Tensor = None
+    kv_indptr: torch.Tensor = None
+    kv_indices: torch.Tensor = None
+    kv_last_page_len: torch.Tensor = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        self.kv_indices = torch.cat(
+            [
+                self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(self.batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+
+        if self.forward_mode == ForwardMode.EXTEND:
+            # extend part
+            self.qo_indptr = torch.zeros(
+                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+            )
+            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
+                self.qo_indptr,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+
+    def init_extend_args(self):
+        self.extend_seq_lens = self.seq_lens - self.prefix_lens
+        self.extend_start_loc = torch.zeros_like(self.seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.max_extend_len = int(torch.max(self.extend_seq_lens))
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        tp_size,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        out_cache_cont_start=None,
+        out_cache_cont_end=None,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
+    ):
+        batch_size = len(req_pool_indices)
+        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+        total_num_tokens = int(torch.sum(seq_lens))
+        max_seq_len = int(torch.max(seq_lens))
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            other_kv_index = model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices[0], seq_lens[0] - 1
+            ].item()
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            other_kv_index = None
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            max_seq_len=max_seq_len,
+            req_pool_indices=req_pool_indices,
+            start_loc=start_loc,
+            seq_lens=seq_lens,
+            prefix_lens=prefix_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            out_cache_cont_start=out_cache_cont_start,
+            out_cache_cont_end=out_cache_cont_end,
+            other_kv_index=other_kv_index,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+        )
+
+        if forward_mode == ForwardMode.EXTEND:
+            ret.init_extend_args()
+
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim,
+            )
+
+        return ret
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 21466da03..a439756cf 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -4,11 +4,9 @@
 import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import Optional, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +27,6 @@ from sglang.srt.utils import (
 )
 
 logger = logging.getLogger("srt.model_runner")
 
-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            # extend part
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim,
-            )
-
-        return ret
-
 
 class ModelRunner:
     def __init__(
@@ -245,6 +39,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
        self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,7 +51,6 @@ class ModelRunner:
 
         monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
@@ -287,11 +81,8 @@
         )
 
         # Set some global args
-        global global_server_args_dict
-        global_server_args_dict = {
-            "disable_flashinfer": server_args.disable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
 
         # Load the model and create memory pool
         self.load_model()
@@ -425,27 +216,6 @@
             ) = None
             self.flashinfer_decode_wrapper = None
 
-    @torch.inference_mode()
-    def forward_prefill(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
     @torch.inference_mode()
     def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -523,8 +293,6 @@
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode}")
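
Reviewer note (not part of the patch): a subtle point in this change is that
model_runner.py now fills global_server_args_dict key by key instead of
rebinding a module-level name via a "global" statement. A module that runs
"from sglang.srt.managers.controller.infer_batch import global_server_args_dict"
at import time holds a reference to the dict object itself, so in-place
mutation is visible to every importer, while rebinding would leave importers
with a stale dict (the old code dodged this by delaying the import until
RadixAttention.__init__ ran; that workaround is deleted above). A minimal,
self-contained Python sketch of the difference (the module "infer_batch" is
simulated here; nothing below is code from the patch):

    import sys
    import types

    # Stand-in for infer_batch.py, which owns the module-level dict.
    infer_batch = types.ModuleType("infer_batch")
    infer_batch.global_server_args_dict = {}
    sys.modules["infer_batch"] = infer_batch

    # Stand-in for radix_attention.py, which imports the dict at import time.
    from infer_batch import global_server_args_dict as seen_by_attention

    # What the patched model_runner.py does: mutate the shared object in place.
    # The change is visible through every existing reference.
    infer_batch.global_server_args_dict["disable_flashinfer"] = True
    assert seen_by_attention.get("disable_flashinfer", False) is True

    # What rebinding (the old "global" + assignment pattern) would do: modules
    # that already grabbed the object keep seeing the stale value.
    infer_batch.global_server_args_dict = {"disable_flashinfer": False}
    assert seen_by_attention.get("disable_flashinfer", False) is True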