diff --git a/docs/en/model_support.md b/docs/en/model_support.md index ece2425a6..f05775074 100644 --- a/docs/en/model_support.md +++ b/docs/en/model_support.md @@ -30,6 +30,6 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`). - Remove `Sample`. - - Change `forward()` functions, and add `input_metadata`. + - Change `forward()` functions, and add `forward_batch`. - Add `EntryClass` at the end. diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 354559a08..6fbeb80e8 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -225,16 +225,16 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - input_metadata = batch.get_input_metadata() - logits_output = model_runner.forward(input_metadata) + forward_batch = batch.get_forward_batch() + logits_output = model_runner.forward(forward_batch) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) - input_metadata = batch.get_input_metadata() - logits_output = model_runner.forward(input_metadata) + forward_batch = batch.get_forward_batch() + logits_output = model_runner.forward(forward_batch) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index c8fe52ed3..cff0a707a 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -16,7 +16,7 @@ import torch.nn as nn from sglang.global_config import global_config from 
sglang.srt.layers.flashinfer_utils import update_flashinfer_indices from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_hip if TYPE_CHECKING: @@ -37,7 +37,7 @@ class AttentionBackend(ABC): """The base class of attention backends""" @abstractmethod - def init_forward_metadata(self, input_metadata: InputMetadata): + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" raise NotImplementedError() @@ -61,18 +61,18 @@ class AttentionBackend(ABC): """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" raise NotImplementedError() - def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): """Run forward on an attention layer.""" - if input_metadata.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, input_metadata) + if forward_batch.forward_mode.is_decode(): + return self.forward_decode(q, k, v, layer, forward_batch) else: - return self.forward_extend(q, k, v, layer, input_metadata) + return self.forward_extend(q, k, v, layer, forward_batch) - def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): """Run a forward for decode.""" raise NotImplementedError() - def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): """Run a forward for extend.""" raise NotImplementedError() @@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend): self.forward_metadata = None self.cuda_graph_metadata = {} - def init_forward_metadata(self, input_metadata: 
InputMetadata): - if input_metadata.forward_mode.is_decode(): + def init_forward_metadata(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode(): prefix_lens = None use_ragged = False extend_no_prefix = False total_num_tokens = None else: - prefix_lens = input_metadata.extend_prefix_lens + prefix_lens = forward_batch.extend_prefix_lens # Some heuristics to check whether to use ragged forward use_ragged = False if ( - torch.sum(input_metadata.seq_lens).item() >= 4096 + torch.sum(forward_batch.seq_lens).item() >= 4096 and self.model_runner.sliding_window_size is None ): use_ragged = True - total_num_tokens = torch.sum(input_metadata.seq_lens).item() - extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item() + total_num_tokens = torch.sum(forward_batch.seq_lens).item() + extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() update_flashinfer_indices( - input_metadata.forward_mode, + forward_batch.forward_mode, self.model_runner, - input_metadata.req_pool_indices, - input_metadata.seq_lens, + forward_batch.req_pool_indices, + forward_batch.seq_lens, prefix_lens, use_ragged=use_ragged, ) @@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): return 0 - def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): if not isinstance(self.prefill_wrapper_paged, list): prefill_wrapper_paged = self.prefill_wrapper_paged else: @@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend): if not use_ragged: if k is not None: assert v is not None - input_metadata.token_to_kv_pool.set_kv_buffer( - layer.layer_id, input_metadata.out_cache_loc, k, v + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v ) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - 
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), causal=True, sm_scale=layer.scaling, window_left=layer.sliding_window_size, @@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend): else: o2, s2 = prefill_wrapper_paged.forward_return_lse( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), causal=False, sm_scale=layer.scaling, logits_soft_cap=layer.logit_cap, @@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend): o, _ = merge_state(o1, s1, o2, s2) - input_metadata.token_to_kv_pool.set_kv_buffer( - layer.layer_id, input_metadata.out_cache_loc, k, v + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = ( self.forward_metadata ) @@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend): if k is not None: assert v is not None - input_metadata.token_to_kv_pool.set_kv_buffer( - layer.layer_id, input_metadata.out_cache_loc, k, v + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v ) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), sm_scale=layer.scaling, logits_soft_cap=layer.logit_cap, ) @@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_max_seq_len = model_runner.model_config.context_len - def init_forward_metadata(self, input_metadata: 
InputMetadata): + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" - if input_metadata.forward_mode.is_decode(): - start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0) + if forward_batch.forward_mode.is_decode(): + start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) + start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - total_num_tokens = torch.sum(input_metadata.seq_lens).item() + total_num_tokens = torch.sum(forward_batch.seq_lens).item() attn_logits = torch.empty( (self.num_head, total_num_tokens), dtype=self.reduce_dtype, device="cuda", ) - max_seq_len = torch.max(input_metadata.seq_lens).item() + max_seq_len = torch.max(forward_batch.seq_lens).item() max_extend_len = None else: start_loc = attn_logits = max_seq_len = None - prefix_lens = input_metadata.extend_prefix_lens - max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item() + prefix_lens = forward_batch.extend_prefix_lens + max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item() self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len @@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): return 1 - def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) else: o = torch.empty_like(q) - input_metadata.token_to_kv_pool.set_kv_buffer( - layer.layer_id, input_metadata.out_cache_loc, k, v + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v ) start_loc, attn_logits, max_seq_len, max_extend_len = 
self.forward_metadata @@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend): k.contiguous(), v.contiguous(), o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, - input_metadata.seq_lens, - input_metadata.extend_seq_lens, - input_metadata.extend_start_loc, + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.extend_seq_lens, + forward_batch.extend_start_loc, max_extend_len, layer.scaling, layer.logit_cap, ) return o - def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. 
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) @@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend): start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata - input_metadata.token_to_kv_pool.set_kv_buffer( - layer.layer_id, input_metadata.out_cache_loc, k, v + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, forward_batch.out_cache_loc, k, v ) self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, start_loc, - input_metadata.seq_lens, + forward_batch.seq_lens, attn_logits, max_seq_len, layer.scaling, diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 440c96392..86eec65cc 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -25,7 +25,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @dataclasses.dataclass @@ -61,26 +61,26 @@ class LogitsMetadata: extend_logprob_pruned_lens_cpu: Optional[List[int]] = None @classmethod - def from_input_metadata(cls, input_metadata: InputMetadata): - return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) - if input_metadata.forward_mode.is_extend(): + def from_forward_batch(cls, forward_batch: ForwardBatch): + return_top_logprob = any(x > 0 for x in 
forward_batch.top_logprobs_nums) + if forward_batch.forward_mode.is_extend(): extend_logprob_pruned_lens_cpu = [ extend_len - start_len for extend_len, start_len in zip( - input_metadata.extend_seq_lens, - input_metadata.extend_logprob_start_lens_cpu, + forward_batch.extend_seq_lens, + forward_batch.extend_logprob_start_lens_cpu, ) ] else: extend_logprob_pruned_lens_cpu = None return cls( - forward_mode=input_metadata.forward_mode, - top_logprobs_nums=input_metadata.top_logprobs_nums, - return_logprob=input_metadata.return_logprob, + forward_mode=forward_batch.forward_mode, + top_logprobs_nums=forward_batch.top_logprobs_nums, + return_logprob=forward_batch.return_logprob, return_top_logprob=return_top_logprob, - extend_seq_lens=input_metadata.extend_seq_lens, - extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu, - extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu, + extend_seq_lens=forward_batch.extend_seq_lens, + extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, ) @@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module): input_ids, hidden_states, weight, - logits_metadata: Union[LogitsMetadata, InputMetadata], + logits_metadata: Union[LogitsMetadata, ForwardBatch], ): - if isinstance(logits_metadata, InputMetadata): - logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata) + if isinstance(logits_metadata, ForwardBatch): + logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py index 21752366a..751f09fdd 100644 --- a/python/sglang/srt/layers/pooler.py +++ b/python/sglang/srt/layers/pooler.py @@ -7,7 +7,7 @@ from enum import IntEnum import torch import torch.nn as 
nn -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import ForwardBatch class PoolingType(IntEnum): @@ -36,10 +36,10 @@ class Pooler(nn.Module): self.normalize = normalize def forward( - self, hidden_states: torch.Tensor, input_metadata: InputMetadata + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> EmbeddingPoolerOutput: if self.pooling_type == PoolingType.LAST: - last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1 + last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1 pooled_data = hidden_states[last_token_indices] else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 8454d2928..25432660e 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -17,7 +17,7 @@ limitations under the License. 
from torch import nn -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class RadixAttention(nn.Module): @@ -48,11 +48,11 @@ class RadixAttention(nn.Module): self.logit_cap = logit_cap self.sliding_window_size = sliding_window_size or -1 - def forward(self, q, k, v, input_metadata: InputMetadata): + def forward(self, q, k, v, forward_batch: ForwardBatch): if k is not None: # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - return input_metadata.attn_backend.forward(q, k, v, self, input_metadata) + return forward_batch.attn_backend.forward(q, k, v, self, forward_batch) diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 6cc1f0348..dcf6450a0 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.loader import DefaultModelLoader -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode class BaseLayerWithLoRA(nn.Module): diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 04f29fead..84f55082a 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -23,7 +23,7 @@ import torch from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_hip, replace_submodule # ROCm: flashinfer available later @@ -207,9 +207,9 @@ class LoRAManager: if lora_weight_name: 
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) - def prepare_lora_batch(self, input_metadata: InputMetadata): + def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool - cur_uids = set(input_metadata.lora_paths) + cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch i = 0 evictable_uids = list(self.active_uids) @@ -229,14 +229,14 @@ class LoRAManager: return # setup lora in forward modules - bs = input_metadata.batch_size + bs = forward_batch.batch_size seg_lens = ( - input_metadata.extend_seq_lens - if input_metadata.forward_mode.is_extend() + forward_batch.extend_seq_lens + if forward_batch.forward_mode.is_extend() else torch.ones(bs) ) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") - for i, lora_path in enumerate(input_metadata.lora_paths): + for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.buffer_id[lora_path] for module_name, module in self.lora_modules: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6cf870ad7..50d0bbd86 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -511,8 +511,8 @@ class ScheduleBatch: self.extend_logprob_start_lens_cpu = 
[r.extend_logprob_start_len for r in reqs] self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) - def get_input_metadata(self): - return InputMetadata.from_schedule_batch(self) + def get_forward_batch(self): + return ForwardBatch.from_schedule_batch(self) def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED diff --git a/python/sglang/srt/managers/scheduler_policy.py b/python/sglang/srt/managers/schedule_policy.py similarity index 99% rename from python/sglang/srt/managers/scheduler_policy.py rename to python/sglang/srt/managers/schedule_policy.py index 344c86278..9fb7a27aa 100644 --- a/python/sglang/srt/managers/scheduler_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096")) -class SchedulerPolicy: +class SchedulePolicy: def __init__(self, policy: str, tree_cache: BasePrefixCache): if tree_cache.disable and policy in ["lpm", "dfs-weight"]: # LPM and DFS-weight is meaningless when the tree cache is disabled. 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 093bcbe05..5a31fd65c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import ( Req, ScheduleBatch, ) -from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy -from sglang.srt.managers.tp_worker import ModelTpWorker +from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy +from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.server_args import PortArgs, ServerArgs @@ -134,7 +134,7 @@ class Scheduler: ) # Launch a tensor parallel worker - self.tp_worker = ModelTpWorker( + self.tp_worker = TpModelWorker( gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, @@ -179,7 +179,7 @@ class Scheduler: disable=server_args.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} - self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache) + self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) # Init running status self.waiting_queue: List[Req] = [] @@ -575,9 +575,9 @@ class Scheduler: if self.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - input_metadata = batch.get_input_metadata() + forward_batch = batch.get_forward_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - input_metadata, batch + forward_batch, batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -641,8 +641,8 @@ class Scheduler: ) else: assert batch.extend_num_tokens != 0 - input_metadata = batch.get_input_metadata() - embeddings = self.tp_worker.forward_batch_embedding(input_metadata) + forward_batch = batch.get_forward_batch() + embeddings = 
self.tp_worker.forward_batch_embedding(forward_batch) # Check finish conditions for i, req in enumerate(batch.reqs): @@ -771,9 +771,9 @@ class Scheduler: batch.prepare_for_decode() # Forward and sample the next tokens - input_metadata = batch.get_input_metadata() + forward_batch = batch.get_forward_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - input_metadata, batch + forward_batch, batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b5d8b4f7f..b62651fae 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -21,7 +21,7 @@ import logging from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import UpdateWeightReqInput -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed @@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se logger = logging.getLogger(__name__) -class ModelTpWorker: +class TpModelWorker: + """A tensor parallel model worker.""" + def __init__( self, gpu_id: int, @@ -106,13 +108,13 @@ class ModelTpWorker: self.random_seed, ) - def forward_batch_generation(self, input_metadata: InputMetadata, batch): - logits_output = self.model_runner.forward(input_metadata) + def forward_batch_generation(self, forward_batch: ForwardBatch, batch): + logits_output = self.model_runner.forward(forward_batch) next_token_ids = self.model_runner.sample(logits_output, batch) return logits_output, next_token_ids - def 
forward_batch_embedding(self, input_metadata: InputMetadata): - logits_output = self.model_runner.forward(input_metadata) + def forward_batch_embedding(self, forward_batch: ForwardBatch): + logits_output = self.model_runner.forward(forward_batch) embeddings = logits_output.embeddings.tolist() return embeddings diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 763ee0e15..cdf3a77c9 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import ( LogitsProcessor, LogitsProcessorOutput, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import monkey_patch_vllm_all_gather if TYPE_CHECKING: @@ -196,7 +196,7 @@ class CudaGraphRunner: # Run and capture def run_once(): - input_metadata = InputMetadata( + forward_batch = ForwardBatch( forward_mode=ForwardMode.DECODE, batch_size=bs, input_ids=input_ids, @@ -210,7 +210,7 @@ class CudaGraphRunner: top_logprobs_nums=[0] * bs, positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), ) - return forward(input_ids, input_metadata.positions, input_metadata) + return forward(input_ids, forward_batch.positions, forward_batch) for _ in range(2): torch.cuda.synchronize() @@ -233,9 +233,9 @@ class CudaGraphRunner: self.graph_memory_pool = graph.pool() return graph, out - def replay(self, input_metadata: InputMetadata): - assert input_metadata.out_cache_loc is not None - raw_bs = input_metadata.batch_size + def replay(self, forward_batch: ForwardBatch): + assert forward_batch.out_cache_loc is not None + raw_bs = forward_batch.batch_size # Pad index = bisect.bisect_left(self.capture_bs, raw_bs) @@ -245,10 +245,10 @@ class CudaGraphRunner: self.out_cache_loc.zero_() # Common inputs - 
self.input_ids[:raw_bs] = input_metadata.input_ids - self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices - self.seq_lens[:raw_bs] = input_metadata.seq_lens - self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc + self.input_ids[:raw_bs] = forward_batch.input_ids + self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices + self.seq_lens[:raw_bs] = forward_batch.seq_lens + self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( @@ -271,15 +271,15 @@ class CudaGraphRunner: ) # Extract logprobs - if input_metadata.return_logprob: + if forward_batch.return_logprob: logits_output.next_token_logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logits, dim=-1 ) - return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) + return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) if return_top_logprob: logits_metadata = LogitsMetadata( forward_mode=ForwardMode.DECODE, - top_logprobs_nums=input_metadata.top_logprobs_nums, + top_logprobs_nums=forward_batch.top_logprobs_nums, ) logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( logits_output.next_token_logprobs, logits_metadata diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c5b218a1b..e5b8ff34e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -18,7 +18,7 @@ limitations under the License. 
"""Meta data for a forward pass.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List, Set +from typing import TYPE_CHECKING, List import numpy as np import torch @@ -53,8 +53,8 @@ class ForwardMode(IntEnum): @dataclass -class InputMetadata: - """Store all inforamtion of a forward pass.""" +class ForwardBatch: + """Store all inputs of a forward pass.""" # The forward mode forward_mode: ForwardMode diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9f65e5817..7cee4c22a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import ( MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -466,47 +466,47 @@ class ModelRunner: logger.info("Capture cuda graph begin. 
This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - def forward_decode(self, input_metadata: InputMetadata): + def forward_decode(self, forward_batch: ForwardBatch): if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - input_metadata.batch_size + forward_batch.batch_size ): - return self.cuda_graph_runner.replay(input_metadata) + return self.cuda_graph_runner.replay(forward_batch) return self.model.forward( - input_metadata.input_ids, input_metadata.positions, input_metadata + forward_batch.input_ids, forward_batch.positions, forward_batch ) - def forward_extend(self, input_metadata: InputMetadata): + def forward_extend(self, forward_batch: ForwardBatch): if self.is_generation: return self.model.forward( - input_metadata.input_ids, input_metadata.positions, input_metadata + forward_batch.input_ids, forward_batch.positions, forward_batch ) else: # Only embedding models have get_embedding parameter return self.model.forward( - input_metadata.input_ids, - input_metadata.positions, - input_metadata, + forward_batch.input_ids, + forward_batch.positions, + forward_batch, get_embedding=True, ) - def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput: + def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: # Attach attention information - input_metadata.req_to_token_pool = self.req_to_token_pool - input_metadata.token_to_kv_pool = self.token_to_kv_pool - input_metadata.attn_backend = self.attn_backend - input_metadata.attn_backend.init_forward_metadata(input_metadata) + forward_batch.req_to_token_pool = self.req_to_token_pool + forward_batch.token_to_kv_pool = self.token_to_kv_pool + forward_batch.attn_backend = self.attn_backend + forward_batch.attn_backend.init_forward_metadata(forward_batch) # Attach lora information if self.server_args.lora_paths is not None: - self.lora_manager.prepare_lora_batch(input_metadata) + self.lora_manager.prepare_lora_batch(forward_batch) - if 
input_metadata.forward_mode.is_decode(): - return self.forward_decode(input_metadata) - elif input_metadata.forward_mode.is_extend(): - return self.forward_extend(input_metadata) + if forward_batch.forward_mode.is_decode(): + return self.forward_decode(forward_batch) + elif forward_batch.forward_mode.is_extend(): + return self.forward_extend(forward_batch) else: - raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}") + raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") def _apply_logits_bias( self, logits: torch.Tensor, sampling_info: SamplingBatchInfo diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index f42627ad6..2c04e5aeb 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.postion_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], )
-> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata) + hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 7e8dd2adc..3d1319e40 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -42,7 +42,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch LoraConfig = None @@ -118,7 +118,7 @@ class GLMAttention(nn.Module): self, 
hidden_states: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -127,7 +127,7 @@ class GLMAttention(nn.Module): q, k, v, - input_metadata, + forward_batch, ) attn_output, _ = self.dense(context_layer) return attn_output @@ -220,7 +220,7 @@ class GLMBlock(nn.Module): self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: # hidden_states: [num_tokens, h] # Layer norm at the beginning of the transformer layer. @@ -229,7 +229,7 @@ class GLMBlock(nn.Module): attention_output = self.self_attention( hidden_states=layernorm_output, position_ids=position_ids, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Residual connection. @@ -288,14 +288,14 @@ class GLMTransformer(nn.Module): self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: for i in range(self.num_layers): layer = self.layers[i] hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Final layer norm. 
if self.post_layer_norm: @@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) @@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module): hidden_states = self.encoder( hidden_states=inputs_embeds, position_ids=position_ids, - input_metadata=input_metadata, + forward_batch=forward_batch, ) return hidden_states @@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, input_metadata) + hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 9c93d1e41..f2ad52963 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -63,7 +63,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import set_weight_attrs @@ -220,14 +220,14 @@ class CohereAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, 
self.kv_size], dim=-1) if self.use_qk_norm: q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module): hidden_states_attention = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states_mlp = self.mlp(hidden_states) # Add everything together @@ -299,7 +299,7 @@ class CohereModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -308,7 +308,7 @@ class CohereModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.model( input_ids, positions, - input_metadata, + forward_batch, ) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 9fd0f335d..9d4fafd63 100644 --- a/python/sglang/srt/models/dbrx.py +++ 
b/python/sglang/srt/models/dbrx.py @@ -44,7 +44,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import set_weight_attrs @@ -249,14 +249,14 @@ class DbrxAttention(nn.Module): self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(position_ids, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) hidden_states, _ = self.out_proj(attn_output) return hidden_states @@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module): self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm_1(hidden_states) x = self.attn( position_ids=position_ids, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = residual + x residual = hidden_states @@ -310,12 +310,12 @@ class DbrxBlock(nn.Module): self, position_ids: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states, residual = self.norm_attn_norm( position_ids=position_ids, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = self.ffn(hidden_states) hidden_states = hidden_states + residual @@ -349,7 
+349,7 @@ class DbrxModel(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -358,7 +358,7 @@ class DbrxModel(nn.Module): hidden_states = input_embeds for i in range(len(self.blocks)): block = self.blocks[i] - hidden_states = block(position_ids, hidden_states, input_metadata) + hidden_states = block(position_ids, hidden_states, forward_batch) hidden_states = self.norm_f(hidden_states) return hidden_states @@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, input_metadata) + hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 590f8f0bf..b320b5167 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,7 +46,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class DeepseekMLP(nn.Module): @@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = 
self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -356,14 +356,14 @@ class DeepseekModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata) + hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2d46f3e8a..8524be22b 100644 --- 
a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_hip # ROCm: flashinfer available later @@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module): v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view( -1, self.num_local_heads * 256 ) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 256)[ ..., : self.v_head_dim ].reshape(-1, self.num_local_heads * self.v_head_dim) @@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( @@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module): q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - attn_output = self.attn(q_input, k_input, v_input, input_metadata) + attn_output = self.attn(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.w_vc.dtype == torch.float8_e4m3fn: @@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module): 
self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata) + hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 4af7a9f47..f0b47786a 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -40,7 +40,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from 
sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class ExaoneGatedMLP(nn.Module): @@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.out_proj(attn_output) return output @@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -270,7 +270,7 @@ class ExaoneModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -283,7 +283,7 @@ class ExaoneModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.transformer( - input_ids, positions, input_metadata, input_embeds + input_ids, positions, forward_batch, 
input_embeds ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 50cbca1a1..5c3456161 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,7 +37,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class GemmaMLP(nn.Module): @@ -137,12 +137,12 @@ class GemmaAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -226,7 +226,7 @@ class GemmaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + 
forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -243,7 +243,7 @@ class GemmaModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index b6a282bd8..59205416a 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -37,7 +37,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch # Aligned with HF's implementation, using sliding window inclusive with the last token @@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - 
attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -286,7 +286,7 @@ class Gemma2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -302,7 +302,7 @@ class Gemma2Model(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def get_attention_sliding_window_size(self): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index f063c732f..ad61b742f 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,7 
+35,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class GPTBigCodeAttention(nn.Module): @@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module): def forward( self, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.split( @@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module): ], dim=-1, ) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module): def forward( self, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_output = self.attn( - hidden_states=hidden_states, input_metadata=input_metadata + hidden_states=hidden_states, forward_batch=forward_batch ) # residual connection hidden_states = attn_output + residual @@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) @@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module): for i in range(len(self.h)): layer = self.h[i] - hidden_states = layer(hidden_states, input_metadata) + hidden_states = layer(hidden_states, forward_batch) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module): self, input_ids: 
torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, input_metadata) + hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 2eda52983..80bdc2c4c 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -46,7 +46,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class Grok1MoE(nn.Module): @@ -173,12 +173,12 @@ class Grok1Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: # Self Attention hidden_states = ( @@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module): self.self_attn( positions=positions, hidden_states=self.pre_attn_norm(hidden_states), - 
input_metadata=input_metadata, + forward_batch=forward_batch, ) ) + hidden_states @@ -268,7 +268,7 @@ class Grok1Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -278,7 +278,7 @@ class Grok1Model(nn.Module): hidden_states = input_embeds for i in range(len(self.layers)): - hidden_states = self.layers[i](positions, hidden_states, input_metadata) + hidden_states = self.layers[i](positions, hidden_states, forward_batch) hidden_states = self.norm(hidden_states) hidden_states.mul_(self.config.output_multiplier_scale) return hidden_states @@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 9eb73f1fc..087793afc 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,7 +40,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class InternLM2MLP(nn.Module): @@ -137,12 +137,12 @@ class 
InternLM2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.wqkv(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.wo(attn_output) return output @@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module): hidden_states = self.attention( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -229,7 +229,7 @@ class InternLM2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -242,7 +242,7 @@ class InternLM2Model(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.output.weight, input_metadata + input_ids, hidden_states, self.output.weight, forward_batch ) def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index b63aaf16f..ce7eed969 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class LlamaMLP(nn.Module): @@ -162,12 +162,12 @@ class LlamaAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -270,7 +270,7 @@ class LlamaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -283,7 +283,7 @@ class 
LlamaModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def get_hidden_dim(self, module_name): diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 536bec2f1..58ec09ba0 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) is_eos_token = input_ids == self.eos_token_id hidden_states = hidden_states[is_eos_token] scores = 
self.classification_head(hidden_states) - if scores.shape[0] != input_metadata.batch_size: + if scores.shape[0] != forward_batch.batch_size: print("Warning: the EOS tokens are missing in some sentences.") scores = torch.ones( - (input_metadata.batch_size, self.config.classification_out_size) + (forward_batch.batch_size, self.config.classification_out_size) ).to(input_ids.device) logits_output = LogitsProcessorOutput( diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index fe407b29f..19e324f92 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -6,7 +6,7 @@ from transformers import LlamaConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import ForwardBatch from sglang.srt.models.llama import LlamaModel @@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, get_embedding: bool = True, ) -> EmbeddingPoolerOutput: assert ( get_embedding ), "LlamaEmbeddingModel / MistralModel is only used for embedding" - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.pooler(hidden_states, input_metadata) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + return self.pooler(hidden_states, forward_batch) def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 519d9a0d2..fd868a62c 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ 
-24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> EmbeddingPoolerOutput: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) scores = self.score(hidden_states) - return self.pooler(scores, input_metadata) + return self.pooler(scores, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) @@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, get_embedding: bool = True, ) -> EmbeddingPoolerOutput: assert ( get_embedding ), "LlamaForSequenceClassification is only used for embedding" - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) logits = self.score(hidden_states) weights = self.weights(hidden_states) - pooled_logits = self.pooler(logits, input_metadata).embeddings - pooled_weights = self.pooler(weights, input_metadata).embeddings + pooled_logits = self.pooler(logits, 
forward_batch).embeddings + pooled_weights = self.pooler(weights, forward_batch).embeddings rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view( -1, self.num_labels // 2 diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 1d8a3f40f..b9c1fa0aa 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -41,7 +41,7 @@ from sglang.srt.mm_utils import ( unpad_image, unpad_image_shape, ) -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM @@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module): self, input_ids: torch.LongTensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - image_inputs = input_metadata.image_inputs + image_inputs = forward_batch.image_inputs - if input_metadata.forward_mode.is_extend(): - bs = input_metadata.batch_size + if forward_batch.forward_mode.is_extend(): + bs = forward_batch.batch_size # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module): # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) if need_vision.any(): @@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module): image_features = new_image_features # Fill in the placeholder for the image - extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() - prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() 
+ extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: @@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module): pt += 1 return self.language_model( - input_ids, positions, input_metadata, input_embeds=input_embeds + input_ids, positions, forward_batch, input_embeds=input_embeds ) - elif input_metadata.forward_mode.is_decode(): - return self.language_model(input_ids, positions, input_metadata) + elif forward_batch.forward_mode.is_decode(): + return self.language_model(input_ids, positions, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Load clip vision model by cfg['mm_vision_tower']: diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 4613c208f..82aa7c15d 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.schedule_batch import ImageInputs -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM @@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module): self, input_ids: torch.LongTensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - image_inputs = input_metadata.image_inputs - if input_metadata.forward_mode.is_extend(): - bs = input_metadata.batch_size + image_inputs = forward_batch.image_inputs + if forward_batch.forward_mode.is_extend(): + bs = forward_batch.batch_size # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) @@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module): 
max_image_offset.append(max(im.image_offsets)) else: max_image_offset.append(-1) - start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) if need_vision.any(): @@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module): image_features = new_image_features # Fill in the placeholder for the image - extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() - prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() + extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: @@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module): pt += 1 return self.language_model( - input_ids, positions, input_metadata, input_embeds=input_embeds + input_ids, positions, forward_batch, input_embeds=input_embeds ) - elif input_metadata.forward_mode.is_decode(): - return self.language_model(input_ids, positions, input_metadata) + elif forward_batch.forward_mode.is_decode(): + return self.language_model(input_ids, positions, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Load clip vision model by cfg['mm_vision_tower']: diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index f796e881f..777d572f9 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,7 +39,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class MiniCPMMLP(nn.Module): @@ -148,7 
+148,7 @@ class MiniCPMAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module): q, k = q.float(), k.float() q, k = self.rotary_emb(positions, q, k) q, k = q.to(orig_dtype), k.to(orig_dtype) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = residual + hidden_states * ( self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) @@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states = self.norm(hidden_states) @@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is not None: input_embeds = input_embeds * self.config.scale_emb - hidden_states = 
self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight return self.logits_processor( - input_ids, hidden_states, lm_head_weight, input_metadata + input_ids, hidden_states, lm_head_weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 3bc094e5a..0e29eb357 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_hip # ROCm: flashinfer available later @@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module): v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view( -1, self.num_local_heads * 128 ) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, 128)[ ..., : self.v_head_dim ].reshape(-1, self.num_local_heads * self.v_head_dim) @@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module): self, positions: 
torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( @@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module): q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - attn_output = self.attn(q_input, k_input, v_input, input_metadata) + attn_output = self.attn(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.w_vc.dtype == torch.float8_e4m3fn: @@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = residual + hidden_states * ( self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) @@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states = self.norm(hidden_states) @@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is not None: input_embeds = input_embeds * self.config.scale_emb - hidden_states = self.model(input_ids, positions, 
input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight return self.logits_processor( - input_ids, hidden_states, lm_head_weight, input_metadata + input_ids, hidden_states, lm_head_weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 73a7b8686..b72220fe5 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class MixtralMoE(nn.Module): @@ -171,12 +171,12 @@ class MixtralAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module): hidden_states = 
self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -270,7 +270,7 @@ class MixtralModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -281,7 +281,7 @@ class MixtralModel(nn.Module): for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 2e7483771..f69d2ea59 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,7 +45,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class 
MixtralMLP(nn.Module): @@ -216,12 +216,12 @@ class MixtralAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -303,7 +303,7 @@ class MixtralModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -314,7 +314,7 @@ class MixtralModel(nn.Module): for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return 
self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 28eda6f4f..3e851268d 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class OlmoeMoE(nn.Module): @@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -274,7 +274,7 @@ class OlmoeModel(nn.Module): self, input_ids: torch.Tensor, 
positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -285,7 +285,7 @@ class OlmoeModel(nn.Module): for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 3a2d31001..614311c17 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,7 +39,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class QWenMLP(nn.Module): @@ -133,12 +133,12 @@ class QWenAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, 
k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.c_proj(attn_output) return output @@ -177,7 +177,7 @@ class QWenBlock(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -185,7 +185,7 @@ class QWenBlock(nn.Module): hidden_states = self.attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = residual + hidden_states @@ -224,7 +224,7 @@ class QWenModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.wte(input_ids) for i in range(len(self.h)): @@ -232,7 +232,7 @@ class QWenModel(nn.Module): hidden_states = layer( positions, hidden_states, - input_metadata, + forward_batch, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ): - hidden_states = self.transformer(input_ids, positions, input_metadata) + hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 5aac3f79c..d4beae4b3 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch Qwen2Config = None @@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -243,7 +243,7 @@ class Qwen2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -256,7 +256,7 @@ class Qwen2Model(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, get_embedding: bool = False, ) -> 
torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) else: - return self.pooler(hidden_states, input_metadata) + return self.pooler(hidden_states, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 27fcf6321..2eed8ff45 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class Qwen2MoeMLP(nn.Module): @@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ 
-293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module): for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 2e344ca50..b211037fc 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,7 +40,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch class StablelmMLP(nn.Module): @@ -145,12 +145,12 @@ class StablelmAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) hidden_states = residual + hidden_states @@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, ) hidden_states = self.norm(hidden_states) return hidden_states @@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return 
self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 88d3ea2c7..bd10606b5 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import ForwardBatch class XverseMLP(nn.Module): @@ -160,12 +160,12 @@ class XverseAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -271,7 +271,7 @@ class XverseModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: 
InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: if input_embeds is None: @@ -284,7 +284,7 @@ class XverseModel(nn.Module): hidden_states, residual = layer( positions, hidden_states, - input_metadata, + forward_batch, residual, ) # print(f"layer[{i}].hidden_states: {hidden_states}") @@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights( diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 666bf1d3d..7ff25b340 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class XverseMLP(nn.Module): @@ -244,12 +244,12 @@ class XverseAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, input_metadata) + 
attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention @@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module): hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, - input_metadata=input_metadata, + forward_batch=forward_batch, ) # Fully Connected @@ -353,14 +353,14 @@ class XverseModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, input_metadata, residual + positions, hidden_states, forward_batch, residual ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata, + forward_batch: ForwardBatch, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, input_metadata) + hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):