Rename InputMetadata -> ForwardBatch (#1543)

This commit is contained in:
Lianmin Zheng
2024-09-30 02:41:11 -07:00
committed by GitHub
parent 3f0fe08d37
commit 36d5acfca5
44 changed files with 435 additions and 433 deletions

View File

@@ -225,16 +225,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
forward_batch = batch.get_forward_batch()
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
forward_batch = batch.get_forward_batch()
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits

View File

@@ -16,7 +16,7 @@ import torch.nn as nn
from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip
if TYPE_CHECKING:
@@ -37,7 +37,7 @@ class AttentionBackend(ABC):
"""The base class of attention backends"""
@abstractmethod
def init_forward_metadata(self, input_metadata: InputMetadata):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
raise NotImplementedError()
@@ -61,18 +61,18 @@ class AttentionBackend(ABC):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run forward on an attention layer."""
if input_metadata.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, input_metadata)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch)
else:
return self.forward_extend(q, k, v, layer, input_metadata)
return self.forward_extend(q, k, v, layer, forward_batch)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run a forward for decode."""
raise NotImplementedError()
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run a forward for extend."""
raise NotImplementedError()
@@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend):
self.forward_metadata = None
self.cuda_graph_metadata = {}
def init_forward_metadata(self, input_metadata: InputMetadata):
if input_metadata.forward_mode.is_decode():
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
prefix_lens = None
use_ragged = False
extend_no_prefix = False
total_num_tokens = None
else:
prefix_lens = input_metadata.extend_prefix_lens
prefix_lens = forward_batch.extend_prefix_lens
# Some heuristics to check whether to use ragged forward
use_ragged = False
if (
torch.sum(input_metadata.seq_lens).item() >= 4096
torch.sum(forward_batch.seq_lens).item() >= 4096
and self.model_runner.sliding_window_size is None
):
use_ragged = True
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
update_flashinfer_indices(
input_metadata.forward_mode,
forward_batch.forward_mode,
self.model_runner,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
prefix_lens,
use_ragged=use_ragged,
)
@@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
if not isinstance(self.prefill_wrapper_paged, list):
prefill_wrapper_paged = self.prefill_wrapper_paged
else:
@@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend):
if not use_ragged:
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
@@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
@@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)
@@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
@@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata(self, input_metadata: InputMetadata):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
if input_metadata.forward_mode.is_decode():
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
if forward_batch.forward_mode.is_decode():
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype,
device="cuda",
)
max_seq_len = torch.max(input_metadata.seq_lens).item()
max_seq_len = torch.max(forward_batch.seq_lens).item()
max_extend_len = None
else:
start_loc = attn_logits = max_seq_len = None
prefix_lens = input_metadata.extend_prefix_lens
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
prefix_lens = forward_batch.extend_prefix_lens
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
@@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend):
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_seq_lens,
forward_batch.extend_start_loc,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
start_loc,
input_metadata.seq_lens,
forward_batch.seq_lens,
attn_logits,
max_seq_len,
layer.scaling,

View File

@@ -25,7 +25,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@dataclasses.dataclass
@@ -61,26 +61,26 @@ class LogitsMetadata:
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if input_metadata.forward_mode.is_extend():
def from_forward_batch(cls, forward_batch: ForwardBatch):
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
input_metadata.extend_seq_lens,
input_metadata.extend_logprob_start_lens_cpu,
forward_batch.extend_seq_lens,
forward_batch.extend_logprob_start_lens_cpu,
)
]
else:
extend_logprob_pruned_lens_cpu = None
return cls(
forward_mode=input_metadata.forward_mode,
top_logprobs_nums=input_metadata.top_logprobs_nums,
return_logprob=input_metadata.return_logprob,
forward_mode=forward_batch.forward_mode,
top_logprobs_nums=forward_batch.top_logprobs_nums,
return_logprob=forward_batch.return_logprob,
return_top_logprob=return_top_logprob,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
extend_seq_lens=forward_batch.extend_seq_lens,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
)
@@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module):
input_ids,
hidden_states,
weight,
logits_metadata: Union[LogitsMetadata, InputMetadata],
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
if isinstance(logits_metadata, InputMetadata):
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
assert isinstance(logits_metadata, LogitsMetadata)
# Get the last hidden states and last logits for the next token prediction

View File

@@ -7,7 +7,7 @@ from enum import IntEnum
import torch
import torch.nn as nn
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
self.normalize = normalize
def forward(
self, hidden_states: torch.Tensor, input_metadata: InputMetadata
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> EmbeddingPoolerOutput:
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

View File

@@ -17,7 +17,7 @@ limitations under the License.
from torch import nn
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class RadixAttention(nn.Module):
@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
def forward(self, q, k, v, input_metadata: InputMetadata):
def forward(self, q, k, v, forward_batch: ForwardBatch):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
class BaseLayerWithLoRA(nn.Module):

View File

@@ -23,7 +23,7 @@ import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
# ROCm: flashinfer available later
@@ -207,9 +207,9 @@ class LoRAManager:
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
def prepare_lora_batch(self, input_metadata: InputMetadata):
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(input_metadata.lora_paths)
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
@@ -229,14 +229,14 @@ class LoRAManager:
return
# setup lora in forward modules
bs = input_metadata.batch_size
bs = forward_batch.batch_size
seg_lens = (
input_metadata.extend_seq_lens
if input_metadata.forward_mode.is_extend()
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs)
)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(input_metadata.lora_paths):
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
for module_name, module in self.lora_modules:

View File

@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -511,8 +511,8 @@ class ScheduleBatch:
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
def get_input_metadata(self):
return InputMetadata.from_schedule_batch(self)
def get_forward_batch(self):
return ForwardBatch.from_schedule_batch(self)
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED

View File

@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
class SchedulerPolicy:
class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.

View File

@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
)
from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
from sglang.srt.managers.tp_worker import ModelTpWorker
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -134,7 +134,7 @@ class Scheduler:
)
# Launch a tensor parallel worker
self.tp_worker = ModelTpWorker(
self.tp_worker = TpModelWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
@@ -179,7 +179,7 @@ class Scheduler:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
# Init running status
self.waiting_queue: List[Req] = []
@@ -575,9 +575,9 @@ class Scheduler:
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
input_metadata = batch.get_input_metadata()
forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
@@ -641,8 +641,8 @@ class Scheduler:
)
else:
assert batch.extend_num_tokens != 0
input_metadata = batch.get_input_metadata()
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
forward_batch = batch.get_forward_batch()
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -771,9 +771,9 @@ class Scheduler:
batch.prepare_for_decode()
# Forward and sample the next tokens
input_metadata = batch.get_input_metadata()
forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids

View File

@@ -21,7 +21,7 @@ import logging
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se
logger = logging.getLogger(__name__)
class ModelTpWorker:
class TpModelWorker:
"""A tensor parallel model worker."""
def __init__(
self,
gpu_id: int,
@@ -106,13 +108,13 @@ class ModelTpWorker:
self.random_seed,
)
def forward_batch_generation(self, input_metadata: InputMetadata, batch):
logits_output = self.model_runner.forward(input_metadata)
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
logits_output = self.model_runner.forward(forward_batch)
next_token_ids = self.model_runner.sample(logits_output, batch)
return logits_output, next_token_ids
def forward_batch_embedding(self, input_metadata: InputMetadata):
logits_output = self.model_runner.forward(input_metadata)
def forward_batch_embedding(self, forward_batch: ForwardBatch):
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings.tolist()
return embeddings

View File

@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
@@ -196,7 +196,7 @@ class CudaGraphRunner:
# Run and capture
def run_once():
input_metadata = InputMetadata(
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
@@ -210,7 +210,7 @@ class CudaGraphRunner:
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)
return forward(input_ids, forward_batch.positions, forward_batch)
for _ in range(2):
torch.cuda.synchronize()
@@ -233,9 +233,9 @@ class CudaGraphRunner:
self.graph_memory_pool = graph.pool()
return graph, out
def replay(self, input_metadata: InputMetadata):
assert input_metadata.out_cache_loc is not None
raw_bs = input_metadata.batch_size
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -245,10 +245,10 @@ class CudaGraphRunner:
self.out_cache_loc.zero_()
# Common inputs
self.input_ids[:raw_bs] = input_metadata.input_ids
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
self.seq_lens[:raw_bs] = input_metadata.seq_lens
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
self.input_ids[:raw_bs] = forward_batch.input_ids
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
self.seq_lens[:raw_bs] = forward_batch.seq_lens
self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -271,15 +271,15 @@ class CudaGraphRunner:
)
# Extract logprobs
if input_metadata.return_logprob:
if forward_batch.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=input_metadata.top_logprobs_nums,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata

View File

@@ -18,7 +18,7 @@ limitations under the License.
"""Meta data for a forward pass."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Set
from typing import TYPE_CHECKING, List
import numpy as np
import torch
@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):
@dataclass
class InputMetadata:
"""Store all inforamtion of a forward pass."""
class ForwardBatch:
"""Store all inputs of a forward pass."""
# The forward mode
forward_mode: ForwardMode

View File

@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -466,47 +466,47 @@ class ModelRunner:
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, input_metadata: InputMetadata):
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
input_metadata.batch_size
forward_batch.batch_size
):
return self.cuda_graph_runner.replay(input_metadata)
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
input_metadata.input_ids, input_metadata.positions, input_metadata
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward_extend(self, input_metadata: InputMetadata):
def forward_extend(self, forward_batch: ForwardBatch):
if self.is_generation:
return self.model.forward(
input_metadata.input_ids, input_metadata.positions, input_metadata
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
input_metadata.input_ids,
input_metadata.positions,
input_metadata,
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
get_embedding=True,
)
def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
# Attach attention information
input_metadata.req_to_token_pool = self.req_to_token_pool
input_metadata.token_to_kv_pool = self.token_to_kv_pool
input_metadata.attn_backend = self.attn_backend
input_metadata.attn_backend.init_forward_metadata(input_metadata)
forward_batch.req_to_token_pool = self.req_to_token_pool
forward_batch.token_to_kv_pool = self.token_to_kv_pool
forward_batch.attn_backend = self.attn_backend
forward_batch.attn_backend.init_forward_metadata(forward_batch)
# Attach lora information
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(input_metadata)
self.lora_manager.prepare_lora_batch(forward_batch)
if input_metadata.forward_mode.is_decode():
return self.forward_decode(input_metadata)
elif input_metadata.forward_mode.is_extend():
return self.forward_extend(input_metadata)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
else:
raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo

View File

@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
LoraConfig = None
@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
q,
k,
v,
input_metadata,
forward_batch,
)
attn_output, _ = self.dense(context_layer)
return attn_output
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Residual connection.
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Final layer norm.
if self.post_layer_norm:
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
return hidden_states
@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
input_metadata,
forward_batch,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + x
residual = hidden_states
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
hidden_states = input_embeds
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(position_ids, hidden_states, input_metadata)
hidden_states = block(position_ids, hidden_states, forward_batch)
hidden_states = self.norm_f(hidden_states)
return hidden_states
@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class DeepseekMLP(nn.Module):
@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
@@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module):
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 256
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 256)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class ExaoneGatedMLP(nn.Module):
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.out_proj(attn_output)
return output
@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.transformer(
input_ids, positions, input_metadata, input_embeds
input_ids, positions, forward_batch, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class GemmaMLP(nn.Module):
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
# Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def get_attention_sliding_window_size(self):

View File

@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class GPTBigCodeAttention(nn.Module):
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
],
dim=-1,
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states, input_metadata=input_metadata
hidden_states=hidden_states, forward_batch=forward_batch
)
# residual connection
hidden_states = attn_output + residual
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, input_metadata)
hidden_states = layer(hidden_states, forward_batch)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class Grok1MoE(nn.Module):
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
hidden_states = (
@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
self.self_attn(
positions=positions,
hidden_states=self.pre_attn_norm(hidden_states),
input_metadata=input_metadata,
forward_batch=forward_batch,
)
)
+ hidden_states
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
hidden_states = input_embeds
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
hidden_states = self.layers[i](positions, hidden_states, forward_batch)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
return hidden_states
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class InternLM2MLP(nn.Module):
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.wo(attn_output)
return output
@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata
input_ids, hidden_states, self.output.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class LlamaMLP(nn.Module):
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def get_hidden_dim(self, module_name):

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
is_eos_token = input_ids == self.eos_token_id
hidden_states = hidden_states[is_eos_token]
scores = self.classification_head(hidden_states)
if scores.shape[0] != input_metadata.batch_size:
if scores.shape[0] != forward_batch.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones(
(input_metadata.batch_size, self.config.classification_out_size)
(forward_batch.batch_size, self.config.classification_out_size)
).to(input_ids.device)
logits_output = LogitsProcessorOutput(

View File

@@ -6,7 +6,7 @@ from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.models.llama import LlamaModel
@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.pooler(hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.pooler(hidden_states, forward_batch)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None

View File

@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> EmbeddingPoolerOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
scores = self.score(hidden_states)
return self.pooler(scores, input_metadata)
return self.pooler(scores, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
weights = self.weights(hidden_states)
pooled_logits = self.pooler(logits, input_metadata).embeddings
pooled_weights = self.pooler(weights, input_metadata).embeddings
pooled_logits = self.pooler(logits, forward_batch).embeddings
pooled_weights = self.pooler(weights, forward_batch).embeddings
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
-1, self.num_labels // 2

View File

@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
unpad_image,
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = input_metadata.image_inputs
image_inputs = forward_batch.image_inputs
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
if forward_batch.forward_mode.is_extend():
bs = forward_batch.batch_size
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
modalities_list = []
@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
input_ids, positions, forward_batch, input_embeds=input_embeds
)
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Load clip vision model by cfg['mm_vision_tower']:

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM
@@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module):
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = input_metadata.image_inputs
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
image_inputs = forward_batch.image_inputs
if forward_batch.forward_mode.is_extend():
bs = forward_batch.batch_size
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module):
max_image_offset.append(max(im.image_offsets))
else:
max_image_offset.append(-1)
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
@@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
@@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
input_ids, positions, forward_batch, input_embeds=input_embeds
)
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Load clip vision model by cfg['mm_vision_tower']:

View File

@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MiniCPMMLP(nn.Module):
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states * (
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states = self.norm(hidden_states)
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is not None:
input_embeds = input_embeds * self.config.scale_emb
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states * (
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states = self.norm(hidden_states)
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is not None:
input_embeds = input_embeds * self.config.scale_emb
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMoE(nn.Module):
@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMLP(nn.Module):
@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class OlmoeMoE(nn.Module):
@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class QWenMLP(nn.Module):
@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.c_proj(attn_output)
return output
@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states
@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
hidden_states = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
):
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Qwen2Config = None
@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:
return self.pooler(hidden_states, input_metadata)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class Qwen2MoeMLP(nn.Module):
@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class StablelmMLP(nn.Module):
@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states
@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
)
hidden_states = self.norm(hidden_states)
return hidden_states
@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
class XverseMLP(nn.Module):
@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
# print(f"layer[{i}].hidden_states: {hidden_states}")
@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(

View File

@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class XverseMLP(nn.Module):
@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):