diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 611349577..be6795846 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -60,7 +60,6 @@ import torch.distributed as dist from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_config import ModelConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -208,14 +207,14 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) + sample_output, logits_output = model_runner.forward(batch) next_token_ids = sample_output.batch_next_token_ids.tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) + sample_output, logits_output = model_runner.forward(batch) next_token_ids = sample_output.batch_next_token_ids.tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index b81f3d2a0..72a926cab 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module): @staticmethod def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): output_top_logprobs = [] max_k = max(logits_metadata.top_logprobs_nums) ret = 
all_logprobs.topk(max_k, dim=1) @@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module): assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): last_index = None last_hidden = hidden_states else: @@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module): ) else: # When logprob is requested, compute the logits for all tokens. - if logits_metadata.forward_mode == ForwardMode.DECODE: + if logits_metadata.forward_mode.is_decode(): last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) # Get the logprob of top-k tokens diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 91735a1b8..1a2feacd3 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -197,9 +197,9 @@ class RadixAttention(nn.Module): k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): return self.extend_forward(q, k, v, input_metadata) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f126cc9f3..6c6b7f842 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, 
ReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo if TYPE_CHECKING: @@ -334,6 +335,8 @@ class ScheduleBatch: token_to_kv_pool: BaseTokenToKVPool tree_cache: BasePrefixCache + forward_mode: ForwardMode = None + # Batched arguments to model runner input_ids: torch.Tensor = None req_pool_indices: torch.Tensor = None @@ -397,6 +400,8 @@ class ScheduleBatch: return out_cache_loc def prepare_for_extend(self, vocab_size: int): + self.forward_mode = ForwardMode.EXTEND + bs = self.batch_size() reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] @@ -626,6 +631,8 @@ class ScheduleBatch: return jump_forward_reqs def prepare_for_decode(self, input_ids=None): + self.forward_mode = ForwardMode.DECODE + if input_ids is None: input_ids = [ r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1] diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 513bc517f..736929a65 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import ( from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -521,9 +520,7 @@ class ModelTpServer: if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.EXTEND - ) + sample_output, logits_output = self.model_runner.forward(batch) next_token_ids = batch.check_sample_results(sample_output) 
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -588,7 +585,7 @@ class ModelTpServer: pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) + logits_output = self.model_runner.forward(batch) embeddings = logits_output.embeddings.tolist() # Check finish conditions @@ -699,9 +696,7 @@ class ModelTpServer: batch.prepare_for_decode() # Forward and sample the next tokens - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.DECODE - ) + sample_output, logits_output = self.model_runner.forward(batch) next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 75f9136d3..a6ad63ce1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -25,10 +25,9 @@ import torch import triton import triton.language as tl -from sglang.srt.managers.schedule_batch import ScheduleBatch -from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool - if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -41,6 +40,15 @@ class ForwardMode(IntEnum): # Decode one token. 
DECODE = auto() + def is_prefill(self): + return self == ForwardMode.PREFILL + + def is_extend(self): + return self == ForwardMode.EXTEND + + def is_decode(self): + return self == ForwardMode.DECODE + @dataclass class InputMetadata: @@ -102,7 +110,7 @@ class InputMetadata: def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): if True: self.positions = self.seq_lens - 1 else: @@ -141,7 +149,7 @@ class InputMetadata: self.positions = self.positions.to(torch.int64) def compute_extend_infos(self, batch: ScheduleBatch): - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None else: @@ -173,10 +181,9 @@ class InputMetadata: cls, model_runner: "ModelRunner", batch: ScheduleBatch, - forward_mode: ForwardMode, ): ret = cls( - forward_mode=forward_mode, + forward_mode=batch.forward_mode, sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, @@ -194,13 +201,11 @@ class InputMetadata: ret.compute_extend_infos(batch) - if ( - forward_mode != ForwardMode.DECODE - or model_runner.server_args.disable_flashinfer - ): + fm = batch.forward_mode + if not fm.is_decode() or model_runner.server_args.disable_flashinfer: ret.total_num_tokens = int(torch.sum(ret.seq_lens)) - if forward_mode != ForwardMode.DECODE: + if not fm.is_decode(): ret.init_multimuldal_info(batch) if model_runner.server_args.disable_flashinfer: @@ -209,7 +214,7 @@ class InputMetadata: flashinfer_use_ragged = False if not model_runner.server_args.disable_flashinfer: if ( - forward_mode != ForwardMode.DECODE + not fm.is_decode() and int(torch.sum(ret.seq_lens)) > 4096 and model_runner.sliding_window_size is None ): @@ -226,7 +231,7 @@ class InputMetadata: self.triton_start_loc = 
torch.zeros_like(self.seq_lens, dtype=torch.int32) self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): self.triton_max_extend_len = None else: self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") @@ -239,7 +244,7 @@ class InputMetadata: prefix_lens_cpu, flashinfer_use_ragged, ): - if self.forward_mode == ForwardMode.DECODE: + if self.forward_mode.is_decode(): prefix_lens = None else: prefix_lens = self.extend_prefix_lens @@ -339,7 +344,7 @@ def update_flashinfer_indices( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): # CUDA graph uses different flashinfer_decode_wrapper if flashinfer_decode_wrapper is None: flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper @@ -388,7 +393,7 @@ def update_flashinfer_indices( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") for wrapper_id in range(2): if wrapper_id == 0: - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): paged_kernel_lens = torch.minimum( seq_lens, torch.tensor(model_runner.sliding_window_size + 1) ) @@ -418,7 +423,7 @@ def update_flashinfer_indices( kv_indices, ) - if forward_mode == ForwardMode.DECODE: + if forward_mode.is_decode(): # CUDA graph uses different flashinfer_decode_wrapper if flashinfer_decode_wrapper is None: flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 78f99dcd6..3cb123c48 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -530,11 +530,7 @@ class ModelRunner: ): return self.cuda_graph_runner.replay(batch) - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - ForwardMode.DECODE, - ) + 
input_metadata = InputMetadata.from_schedule_batch(self, batch) return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -542,11 +538,7 @@ class ModelRunner: @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - forward_mode=ForwardMode.EXTEND, - ) + input_metadata = InputMetadata.from_schedule_batch(self, batch) if self.is_generation: return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -562,11 +554,7 @@ class ModelRunner: @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch( - self, - batch, - forward_mode=ForwardMode.EXTEND, - ) + input_metadata = InputMetadata.from_schedule_batch(self, batch) return self.model.forward( batch.input_ids, input_metadata.positions, @@ -577,16 +565,18 @@ class ModelRunner: ) def forward( - self, batch: ScheduleBatch, forward_mode: ForwardMode + self, batch: ScheduleBatch ) -> Tuple[SampleOutput, LogitsProcessorOutput]: - if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: + assert batch.forward_mode is not None + + if self.is_multimodal_model and batch.forward_mode.is_extend(): return self.forward_extend_multi_modal(batch) - elif forward_mode == ForwardMode.DECODE: + elif batch.forward_mode.is_decode(): return self.forward_decode(batch) - elif forward_mode == ForwardMode.EXTEND: + elif batch.forward_mode.is_extend(): return self.forward_extend(batch) else: - raise ValueError(f"Invaid forward mode: {forward_mode}") + raise ValueError(f"Invalid forward mode: {batch.forward_mode}") @lru_cache() diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 62041a895..9e20a726a 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module): image_sizes: Optional[List[List[int]]] 
= None, image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size @@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module): return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds ) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index f268ecbbc..45f47cffc 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module): image_sizes: Optional[List[List[int]]] = None, image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: - if input_metadata.forward_mode == ForwardMode.EXTEND: + if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size # Embed text inputs @@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module): return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds ) - elif input_metadata.forward_mode == ForwardMode.DECODE: + elif input_metadata.forward_mode.is_decode(): return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):