Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -225,16 +225,16 @@ def extend(reqs, model_runner):
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output = model_runner.forward(input_metadata)
|
||||
forward_batch = batch.get_forward_batch()
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output = model_runner.forward(input_metadata)
|
||||
forward_batch = batch.get_forward_batch()
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import torch.nn as nn
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -37,7 +37,7 @@ class AttentionBackend(ABC):
|
||||
"""The base class of attention backends"""
|
||||
|
||||
@abstractmethod
|
||||
def init_forward_metadata(self, input_metadata: InputMetadata):
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -61,18 +61,18 @@ class AttentionBackend(ABC):
|
||||
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
"""Run forward on an attention layer."""
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
return self.forward_decode(q, k, v, layer, input_metadata)
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(q, k, v, layer, forward_batch)
|
||||
else:
|
||||
return self.forward_extend(q, k, v, layer, input_metadata)
|
||||
return self.forward_extend(q, k, v, layer, forward_batch)
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
"""Run a forward for decode."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
"""Run a forward for extend."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
self.forward_metadata = None
|
||||
self.cuda_graph_metadata = {}
|
||||
|
||||
def init_forward_metadata(self, input_metadata: InputMetadata):
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
prefix_lens = None
|
||||
use_ragged = False
|
||||
extend_no_prefix = False
|
||||
total_num_tokens = None
|
||||
else:
|
||||
prefix_lens = input_metadata.extend_prefix_lens
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
|
||||
# Some heuristics to check whether to use ragged forward
|
||||
use_ragged = False
|
||||
if (
|
||||
torch.sum(input_metadata.seq_lens).item() >= 4096
|
||||
torch.sum(forward_batch.seq_lens).item() >= 4096
|
||||
and self.model_runner.sliding_window_size is None
|
||||
):
|
||||
use_ragged = True
|
||||
|
||||
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
|
||||
extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()
|
||||
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
|
||||
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
||||
|
||||
update_flashinfer_indices(
|
||||
input_metadata.forward_mode,
|
||||
forward_batch.forward_mode,
|
||||
self.model_runner,
|
||||
input_metadata.req_pool_indices,
|
||||
input_metadata.seq_lens,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
prefix_lens,
|
||||
use_ragged=use_ragged,
|
||||
)
|
||||
@@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
if not isinstance(self.prefill_wrapper_paged, list):
|
||||
prefill_wrapper_paged = self.prefill_wrapper_paged
|
||||
else:
|
||||
@@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if not use_ragged:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=True,
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
@@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
else:
|
||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=False,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
@@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
|
||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
|
||||
self.forward_metadata
|
||||
)
|
||||
@@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
)
|
||||
@@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||
|
||||
def init_forward_metadata(self, input_metadata: InputMetadata):
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
|
||||
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
|
||||
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
|
||||
|
||||
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
|
||||
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
|
||||
attn_logits = torch.empty(
|
||||
(self.num_head, total_num_tokens),
|
||||
dtype=self.reduce_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
max_seq_len = torch.max(input_metadata.seq_lens).item()
|
||||
max_seq_len = torch.max(forward_batch.seq_lens).item()
|
||||
max_extend_len = None
|
||||
else:
|
||||
start_loc = attn_logits = max_seq_len = None
|
||||
prefix_lens = input_metadata.extend_prefix_lens
|
||||
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
|
||||
|
||||
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
|
||||
|
||||
@@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||
@@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend):
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
input_metadata.req_pool_indices,
|
||||
input_metadata.seq_lens,
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.extend_start_loc,
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.extend_seq_lens,
|
||||
forward_batch.extend_start_loc,
|
||||
max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
)
|
||||
return o
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
@@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||
|
||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
self.decode_attention_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
input_metadata.req_pool_indices,
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
start_loc,
|
||||
input_metadata.seq_lens,
|
||||
forward_batch.seq_lens,
|
||||
attn_logits,
|
||||
max_seq_len,
|
||||
layer.scaling,
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -61,26 +61,26 @@ class LogitsMetadata:
|
||||
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||
|
||||
@classmethod
|
||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
extend_logprob_pruned_lens_cpu = [
|
||||
extend_len - start_len
|
||||
for extend_len, start_len in zip(
|
||||
input_metadata.extend_seq_lens,
|
||||
input_metadata.extend_logprob_start_lens_cpu,
|
||||
forward_batch.extend_seq_lens,
|
||||
forward_batch.extend_logprob_start_lens_cpu,
|
||||
)
|
||||
]
|
||||
else:
|
||||
extend_logprob_pruned_lens_cpu = None
|
||||
return cls(
|
||||
forward_mode=input_metadata.forward_mode,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
return_logprob=input_metadata.return_logprob,
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
return_logprob=forward_batch.return_logprob,
|
||||
return_top_logprob=return_top_logprob,
|
||||
extend_seq_lens=input_metadata.extend_seq_lens,
|
||||
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
|
||||
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
|
||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||
)
|
||||
|
||||
@@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module):
|
||||
input_ids,
|
||||
hidden_states,
|
||||
weight,
|
||||
logits_metadata: Union[LogitsMetadata, InputMetadata],
|
||||
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
||||
):
|
||||
if isinstance(logits_metadata, InputMetadata):
|
||||
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
|
||||
if isinstance(logits_metadata, ForwardBatch):
|
||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||
assert isinstance(logits_metadata, LogitsMetadata)
|
||||
|
||||
# Get the last hidden states and last logits for the next token prediction
|
||||
|
||||
@@ -7,7 +7,7 @@ from enum import IntEnum
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
|
||||
self.normalize = normalize
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, input_metadata: InputMetadata
|
||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||
) -> EmbeddingPoolerOutput:
|
||||
if self.pooling_type == PoolingType.LAST:
|
||||
last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
|
||||
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
|
||||
pooled_data = hidden_states[last_token_indices]
|
||||
else:
|
||||
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
||||
|
||||
@@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class RadixAttention(nn.Module):
|
||||
@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
|
||||
self.logit_cap = logit_cap
|
||||
self.sliding_window_size = sliding_window_size or -1
|
||||
|
||||
def forward(self, q, k, v, input_metadata: InputMetadata):
|
||||
def forward(self, q, k, v, forward_batch: ForwardBatch):
|
||||
if k is not None:
|
||||
# For cross-layer sharing, kv can be None
|
||||
assert v is not None
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
|
||||
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
|
||||
|
||||
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
|
||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip, replace_submodule
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
@@ -207,9 +207,9 @@ class LoRAManager:
|
||||
if lora_weight_name:
|
||||
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
|
||||
|
||||
def prepare_lora_batch(self, input_metadata: InputMetadata):
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# load active loras into lora memory pool
|
||||
cur_uids = set(input_metadata.lora_paths)
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
i = 0
|
||||
evictable_uids = list(self.active_uids)
|
||||
@@ -229,14 +229,14 @@ class LoRAManager:
|
||||
return
|
||||
|
||||
# setup lora in forward modules
|
||||
bs = input_metadata.batch_size
|
||||
bs = forward_batch.batch_size
|
||||
seg_lens = (
|
||||
input_metadata.extend_seq_lens
|
||||
if input_metadata.forward_mode.is_extend()
|
||||
forward_batch.extend_seq_lens
|
||||
if forward_batch.forward_mode.is_extend()
|
||||
else torch.ones(bs)
|
||||
)
|
||||
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
|
||||
for i, lora_path in enumerate(input_metadata.lora_paths):
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.buffer_id[lora_path]
|
||||
|
||||
for module_name, module in self.lora_modules:
|
||||
|
||||
@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -511,8 +511,8 @@ class ScheduleBatch:
|
||||
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
|
||||
|
||||
def get_input_metadata(self):
|
||||
return InputMetadata.from_schedule_batch(self)
|
||||
def get_forward_batch(self):
|
||||
return ForwardBatch.from_schedule_batch(self)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
self.forward_mode = ForwardMode.MIXED
|
||||
|
||||
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
|
||||
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
|
||||
|
||||
|
||||
class SchedulerPolicy:
|
||||
class SchedulePolicy:
|
||||
def __init__(self, policy: str, tree_cache: BasePrefixCache):
|
||||
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
|
||||
# LPM and DFS-weight is meaningless when the tree cache is disabled.
|
||||
@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
)
|
||||
from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
|
||||
from sglang.srt.managers.tp_worker import ModelTpWorker
|
||||
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -134,7 +134,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
self.tp_worker = ModelTpWorker(
|
||||
self.tp_worker = TpModelWorker(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
@@ -179,7 +179,7 @@ class Scheduler:
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
|
||||
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
@@ -575,9 +575,9 @@ class Scheduler:
|
||||
if self.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
input_metadata = batch.get_input_metadata()
|
||||
forward_batch = batch.get_forward_batch()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
input_metadata, batch
|
||||
forward_batch, batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
@@ -641,8 +641,8 @@ class Scheduler:
|
||||
)
|
||||
else:
|
||||
assert batch.extend_num_tokens != 0
|
||||
input_metadata = batch.get_input_metadata()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
|
||||
forward_batch = batch.get_forward_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
@@ -771,9 +771,9 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward and sample the next tokens
|
||||
input_metadata = batch.get_input_metadata()
|
||||
forward_batch = batch.get_forward_batch()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
input_metadata, batch
|
||||
forward_batch, batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
|
||||
@@ -21,7 +21,7 @@ import logging
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
|
||||
@@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelTpWorker:
|
||||
class TpModelWorker:
|
||||
"""A tensor parallel model worker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gpu_id: int,
|
||||
@@ -106,13 +108,13 @@ class ModelTpWorker:
|
||||
self.random_seed,
|
||||
)
|
||||
|
||||
def forward_batch_generation(self, input_metadata: InputMetadata, batch):
|
||||
logits_output = self.model_runner.forward(input_metadata)
|
||||
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
next_token_ids = self.model_runner.sample(logits_output, batch)
|
||||
return logits_output, next_token_ids
|
||||
|
||||
def forward_batch_embedding(self, input_metadata: InputMetadata):
|
||||
logits_output = self.model_runner.forward(input_metadata)
|
||||
def forward_batch_embedding(self, forward_batch: ForwardBatch):
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
embeddings = logits_output.embeddings.tolist()
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
|
||||
LogitsProcessor,
|
||||
LogitsProcessorOutput,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -196,7 +196,7 @@ class CudaGraphRunner:
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
input_metadata = InputMetadata(
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
batch_size=bs,
|
||||
input_ids=input_ids,
|
||||
@@ -210,7 +210,7 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
)
|
||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
|
||||
for _ in range(2):
|
||||
torch.cuda.synchronize()
|
||||
@@ -233,9 +233,9 @@ class CudaGraphRunner:
|
||||
self.graph_memory_pool = graph.pool()
|
||||
return graph, out
|
||||
|
||||
def replay(self, input_metadata: InputMetadata):
|
||||
assert input_metadata.out_cache_loc is not None
|
||||
raw_bs = input_metadata.batch_size
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
assert forward_batch.out_cache_loc is not None
|
||||
raw_bs = forward_batch.batch_size
|
||||
|
||||
# Pad
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
@@ -245,10 +245,10 @@ class CudaGraphRunner:
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:raw_bs] = input_metadata.input_ids
|
||||
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = input_metadata.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
|
||||
self.input_ids[:raw_bs] = forward_batch.input_ids
|
||||
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = forward_batch.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -271,15 +271,15 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
# Extract logprobs
|
||||
if input_metadata.return_logprob:
|
||||
if forward_batch.return_logprob:
|
||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
logits_metadata = LogitsMetadata(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
)
|
||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
logits_output.next_token_logprobs, logits_metadata
|
||||
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
"""Meta data for a forward pass."""
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Set
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
"""Store all inforamtion of a forward pass."""
|
||||
class ForwardBatch:
|
||||
"""Store all inputs of a forward pass."""
|
||||
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
|
||||
@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
@@ -466,47 +466,47 @@ class ModelRunner:
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def forward_decode(self, input_metadata: InputMetadata):
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||
input_metadata.batch_size
|
||||
forward_batch.batch_size
|
||||
):
|
||||
return self.cuda_graph_runner.replay(input_metadata)
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward_extend(self, input_metadata: InputMetadata):
|
||||
def forward_extend(self, forward_batch: ForwardBatch):
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
else:
|
||||
# Only embedding models have get_embedding parameter
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
# Attach attention information
|
||||
input_metadata.req_to_token_pool = self.req_to_token_pool
|
||||
input_metadata.token_to_kv_pool = self.token_to_kv_pool
|
||||
input_metadata.attn_backend = self.attn_backend
|
||||
input_metadata.attn_backend.init_forward_metadata(input_metadata)
|
||||
forward_batch.req_to_token_pool = self.req_to_token_pool
|
||||
forward_batch.token_to_kv_pool = self.token_to_kv_pool
|
||||
forward_batch.attn_backend = self.attn_backend
|
||||
forward_batch.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
# Attach lora information
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(input_metadata)
|
||||
self.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
return self.forward_decode(input_metadata)
|
||||
elif input_metadata.forward_mode.is_extend():
|
||||
return self.forward_extend(input_metadata)
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
return self.forward_extend(forward_batch)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
|
||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
LoraConfig = None
|
||||
|
||||
@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
)
|
||||
attn_output, _ = self.dense(context_layer)
|
||||
return attn_output
|
||||
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [num_tokens, h]
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
|
||||
attention_output = self.self_attention(
|
||||
hidden_states=layernorm_output,
|
||||
position_ids=position_ids,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
|
||||
hidden_states = self.encoder(
|
||||
hidden_states=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
|
||||
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
|
||||
hidden_states_attention = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
# Add everything together
|
||||
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
|
||||
@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
hidden_states, _ = self.out_proj(attn_output)
|
||||
return hidden_states
|
||||
|
||||
@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + x
|
||||
residual = hidden_states
|
||||
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, residual = self.norm_attn_norm(
|
||||
position_ids=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
|
||||
hidden_states = input_embeds
|
||||
for i in range(len(self.blocks)):
|
||||
block = self.blocks[i]
|
||||
hidden_states = block(position_ids, hidden_states, input_metadata)
|
||||
hidden_states = block(position_ids, hidden_states, forward_batch)
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
@@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
@@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
|
||||
-1, self.num_local_heads * 256
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, 256)[
|
||||
..., : self.v_head_dim
|
||||
].reshape(-1, self.num_local_heads * self.v_head_dim)
|
||||
@@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
@@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
|
||||
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
@@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class ExaoneGatedMLP(nn.Module):
|
||||
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.transformer(
|
||||
input_ids, positions, input_metadata, input_embeds
|
||||
input_ids, positions, forward_batch, input_embeds
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
# Aligned with HF's implementation, using sliding window inclusive with the last token
|
||||
@@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def get_attention_sliding_window_size(self):
|
||||
|
||||
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states, input_metadata=input_metadata
|
||||
hidden_states=hidden_states, forward_batch=forward_batch
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
|
||||
for i in range(len(self.h)):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, input_metadata)
|
||||
hidden_states = layer(hidden_states, forward_batch)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class Grok1MoE(nn.Module):
|
||||
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
hidden_states = (
|
||||
@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=self.pre_attn_norm(hidden_states),
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
)
|
||||
+ hidden_states
|
||||
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
|
||||
hidden_states = input_embeds
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||
hidden_states = self.layers[i](positions, hidden_states, forward_batch)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states.mul_(self.config.output_multiplier_scale)
|
||||
return hidden_states
|
||||
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class InternLM2MLP(nn.Module):
|
||||
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.wqkv(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.wo(attn_output)
|
||||
return output
|
||||
|
||||
@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
hidden_states = self.attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.output.weight, input_metadata
|
||||
input_ids, hidden_states, self.output.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def get_hidden_dim(self, module_name):
|
||||
|
||||
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
is_eos_token = input_ids == self.eos_token_id
|
||||
hidden_states = hidden_states[is_eos_token]
|
||||
scores = self.classification_head(hidden_states)
|
||||
|
||||
if scores.shape[0] != input_metadata.batch_size:
|
||||
if scores.shape[0] != forward_batch.batch_size:
|
||||
print("Warning: the EOS tokens are missing in some sentences.")
|
||||
scores = torch.ones(
|
||||
(input_metadata.batch_size, self.config.classification_out_size)
|
||||
(forward_batch.batch_size, self.config.classification_out_size)
|
||||
).to(input_ids.device)
|
||||
|
||||
logits_output = LogitsProcessorOutput(
|
||||
|
||||
@@ -6,7 +6,7 @@ from transformers import LlamaConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaModel
|
||||
|
||||
|
||||
@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert (
|
||||
get_embedding
|
||||
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
return self.pooler(hidden_states, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
|
||||
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
scores = self.score(hidden_states)
|
||||
|
||||
return self.pooler(scores, input_metadata)
|
||||
return self.pooler(scores, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters())
|
||||
@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = True,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
assert (
|
||||
get_embedding
|
||||
), "LlamaForSequenceClassification is only used for embedding"
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
logits = self.score(hidden_states)
|
||||
weights = self.weights(hidden_states)
|
||||
|
||||
pooled_logits = self.pooler(logits, input_metadata).embeddings
|
||||
pooled_weights = self.pooler(weights, input_metadata).embeddings
|
||||
pooled_logits = self.pooler(logits, forward_batch).embeddings
|
||||
pooled_weights = self.pooler(weights, forward_batch).embeddings
|
||||
|
||||
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
|
||||
-1, self.num_labels // 2
|
||||
|
||||
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image,
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = input_metadata.image_inputs
|
||||
image_inputs = forward_batch.image_inputs
|
||||
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
bs = forward_batch.batch_size
|
||||
# Got List[List[str]] extend it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
modalities_list = []
|
||||
@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= np.array(max_image_offset)
|
||||
|
||||
if need_vision.any():
|
||||
@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_features = new_image_features
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
input_ids, positions, forward_batch, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Load clip vision model by cfg['mm_vision_tower']:
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
@@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
image_inputs = input_metadata.image_inputs
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
image_inputs = forward_batch.image_inputs
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
@@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
max_image_offset.append(max(im.image_offsets))
|
||||
else:
|
||||
max_image_offset.append(-1)
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= np.array(max_image_offset)
|
||||
|
||||
if need_vision.any():
|
||||
@@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
image_features = new_image_features
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
@@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
input_ids, positions, forward_batch, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Load clip vision model by cfg['mm_vision_tower']:
|
||||
|
||||
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class MiniCPMMLP(nn.Module):
|
||||
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
|
||||
q, k = q.float(), k.float()
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q, k = q.to(orig_dtype), k.to(orig_dtype)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + hidden_states * (
|
||||
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
|
||||
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is not None:
|
||||
input_embeds = input_embeds * self.config.scale_emb
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head_weight = self.lm_head.weight
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
|
||||
-1, self.num_local_heads * 128
|
||||
)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
|
||||
..., : self.v_head_dim
|
||||
].reshape(-1, self.num_local_heads * self.v_head_dim)
|
||||
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
|
||||
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + hidden_states * (
|
||||
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
|
||||
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is not None:
|
||||
input_embeds = input_embeds * self.config.scale_emb
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head_weight = self.lm_head.weight
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class OlmoeMoE(nn.Module):
|
||||
@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.c_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
Qwen2Config = None
|
||||
|
||||
@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, input_metadata)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class StablelmMLP(nn.Module):
|
||||
@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
# print(f"layer[{i}].hidden_states: {hidden_states}")
|
||||
@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
|
||||
@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user