Organize code (rename, movement) (#953)
This commit is contained in:
@@ -50,8 +50,9 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -188,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
|||||||
|
|
||||||
|
|
||||||
def extend(reqs, model_runner):
|
def extend(reqs, model_runner):
|
||||||
batch = Batch.init_new(
|
batch = ScheduleBatch.init_new(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@@ -22,11 +22,8 @@ from torch import nn
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
from sglang.srt.model_executor.model_runner import (
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
ForwardMode,
|
from sglang.srt.model_executor.model_runner import global_server_args_dict
|
||||||
InputMetadata,
|
|
||||||
global_server_args_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -46,15 +45,6 @@ global_server_args_dict = {
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
|
||||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
|
||||||
PREFILL = auto()
|
|
||||||
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
|
||||||
EXTEND = auto()
|
|
||||||
# Decode one token.
|
|
||||||
DECODE = auto()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFinishReason:
|
class BaseFinishReason:
|
||||||
def __init__(self, is_error: bool = False):
|
def __init__(self, is_error: bool = False):
|
||||||
self.is_error = is_error
|
self.is_error = is_error
|
||||||
@@ -284,7 +274,7 @@ class Req:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Batch:
|
class ScheduleBatch:
|
||||||
"""Store all inforamtion of a batch."""
|
"""Store all inforamtion of a batch."""
|
||||||
|
|
||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
@@ -673,7 +663,7 @@ class Batch:
|
|||||||
if self_val is not None: # logit_bias can be None
|
if self_val is not None: # logit_bias can be None
|
||||||
setattr(self, item, self_val[new_indices])
|
setattr(self, item, self_val[new_indices])
|
||||||
|
|
||||||
def merge(self, other: "Batch"):
|
def merge(self, other: "ScheduleBatch"):
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
|
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
@@ -770,229 +760,6 @@ class Batch:
|
|||||||
return batch_next_token_ids
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InputMetadata:
|
|
||||||
"""Store all inforamtion of a forward pass."""
|
|
||||||
|
|
||||||
forward_mode: ForwardMode
|
|
||||||
batch_size: int
|
|
||||||
total_num_tokens: int
|
|
||||||
req_pool_indices: torch.Tensor
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
positions: torch.Tensor
|
|
||||||
req_to_token_pool: ReqToTokenPool
|
|
||||||
token_to_kv_pool: BaseTokenToKVPool
|
|
||||||
|
|
||||||
# For extend
|
|
||||||
extend_seq_lens: torch.Tensor
|
|
||||||
extend_start_loc: torch.Tensor
|
|
||||||
extend_no_prefix: bool
|
|
||||||
|
|
||||||
# Output location of the KV cache
|
|
||||||
out_cache_loc: torch.Tensor = None
|
|
||||||
|
|
||||||
# Output options
|
|
||||||
return_logprob: bool = False
|
|
||||||
top_logprobs_nums: List[int] = None
|
|
||||||
|
|
||||||
# Trition attention backend
|
|
||||||
triton_max_seq_len: int = 0
|
|
||||||
triton_max_extend_len: int = 0
|
|
||||||
triton_start_loc: torch.Tensor = None
|
|
||||||
triton_prefix_lens: torch.Tensor = None
|
|
||||||
|
|
||||||
# FlashInfer attention backend
|
|
||||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
|
||||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
|
||||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
|
||||||
flashinfer_use_ragged: bool = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
model_runner,
|
|
||||||
forward_mode,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
top_logprobs_nums=None,
|
|
||||||
return_logprob=False,
|
|
||||||
skip_flashinfer_init=False,
|
|
||||||
):
|
|
||||||
flashinfer_use_ragged = False
|
|
||||||
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
|
||||||
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
|
|
||||||
flashinfer_use_ragged = True
|
|
||||||
init_flashinfer_args(
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
model_runner.flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = len(req_pool_indices)
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
|
||||||
extend_seq_lens = extend_start_loc = extend_no_prefix = None
|
|
||||||
if not model_runner.server_args.disable_flashinfer:
|
|
||||||
# This variable is not needed in this case,
|
|
||||||
# we do not compute it to make it compatbile with cuda graph.
|
|
||||||
total_num_tokens = None
|
|
||||||
else:
|
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
|
||||||
else:
|
|
||||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
|
||||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
|
||||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
|
||||||
positions = torch.tensor(
|
|
||||||
np.concatenate(
|
|
||||||
[
|
|
||||||
np.arange(
|
|
||||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
|
||||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
|
||||||
)
|
|
||||||
for i in range(batch_size)
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
),
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
extend_seq_lens = seq_lens - prefix_lens
|
|
||||||
extend_start_loc = torch.zeros_like(seq_lens)
|
|
||||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
|
||||||
extend_no_prefix = torch.all(prefix_lens == 0)
|
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
|
||||||
|
|
||||||
ret = cls(
|
|
||||||
forward_mode=forward_mode,
|
|
||||||
batch_size=batch_size,
|
|
||||||
total_num_tokens=total_num_tokens,
|
|
||||||
req_pool_indices=req_pool_indices,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
positions=positions,
|
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
|
||||||
out_cache_loc=out_cache_loc,
|
|
||||||
extend_seq_lens=extend_seq_lens,
|
|
||||||
extend_start_loc=extend_start_loc,
|
|
||||||
extend_no_prefix=extend_no_prefix,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
top_logprobs_nums=top_logprobs_nums,
|
|
||||||
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
|
|
||||||
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
|
|
||||||
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_runner.server_args.disable_flashinfer:
|
|
||||||
(
|
|
||||||
ret.triton_max_seq_len,
|
|
||||||
ret.triton_max_extend_len,
|
|
||||||
ret.triton_start_loc,
|
|
||||||
ret.triton_prefix_lens,
|
|
||||||
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def init_flashinfer_args(
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged=False,
|
|
||||||
):
|
|
||||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
|
||||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
|
||||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
|
||||||
head_dim = model_runner.model_config.head_dim
|
|
||||||
batch_size = len(req_pool_indices)
|
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
|
||||||
|
|
||||||
if flashinfer_use_ragged:
|
|
||||||
paged_kernel_lens = prefix_lens
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = seq_lens
|
|
||||||
|
|
||||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
|
||||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
|
||||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
|
||||||
kv_indices = torch.cat(
|
|
||||||
[
|
|
||||||
model_runner.req_to_token_pool.req_to_token[
|
|
||||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
|
||||||
]
|
|
||||||
for i in range(batch_size)
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
).contiguous()
|
|
||||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
flashinfer_decode_wrapper.end_forward()
|
|
||||||
flashinfer_decode_wrapper.begin_forward(
|
|
||||||
kv_indptr,
|
|
||||||
kv_indices,
|
|
||||||
kv_last_page_len,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# extend part
|
|
||||||
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
|
||||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
|
||||||
|
|
||||||
if flashinfer_use_ragged:
|
|
||||||
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
|
||||||
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
|
|
||||||
qo_indptr,
|
|
||||||
qo_indptr,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cached part
|
|
||||||
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
|
|
||||||
model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
|
|
||||||
qo_indptr,
|
|
||||||
kv_indptr,
|
|
||||||
kv_indices,
|
|
||||||
kv_last_page_len,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
max_seq_len = int(torch.max(seq_lens))
|
|
||||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
|
||||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
max_extend_len = None
|
|
||||||
else:
|
|
||||||
extend_seq_lens = seq_lens - prefix_lens
|
|
||||||
max_extend_len = int(torch.max(extend_seq_lens))
|
|
||||||
|
|
||||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_sampling_from_probs_torch(
|
def top_k_top_p_sampling_from_probs_torch(
|
||||||
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
|
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
|
|||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
BaseFinishReason,
|
BaseFinishReason,
|
||||||
Batch,
|
|
||||||
ForwardMode,
|
|
||||||
Req,
|
Req,
|
||||||
|
ScheduleBatch,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -172,7 +172,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
# Init running status
|
# Init running status
|
||||||
self.waiting_queue: List[Req] = []
|
self.waiting_queue: List[Req] = []
|
||||||
self.running_batch: Batch = None
|
self.running_batch: ScheduleBatch = None
|
||||||
self.out_pyobjs = []
|
self.out_pyobjs = []
|
||||||
self.decode_forward_ct = 0
|
self.decode_forward_ct = 0
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
@@ -353,7 +353,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def get_new_prefill_batch(self) -> Optional[Batch]:
|
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
|
||||||
# TODO(lsyin): organize this function
|
# TODO(lsyin): organize this function
|
||||||
running_bs = (
|
running_bs = (
|
||||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||||
@@ -526,7 +526,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Return the new batch
|
# Return the new batch
|
||||||
new_batch = Batch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool,
|
||||||
@@ -535,7 +535,7 @@ class ModelTpServer:
|
|||||||
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
def forward_prefill_batch(self, batch: Batch):
|
def forward_prefill_batch(self, batch: ScheduleBatch):
|
||||||
# Build batch tensors
|
# Build batch tensors
|
||||||
batch.prepare_for_extend(
|
batch.prepare_for_extend(
|
||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
@@ -624,7 +624,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||||
|
|
||||||
def cache_filled_batch(self, batch: Batch):
|
def cache_filled_batch(self, batch: ScheduleBatch):
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||||
@@ -641,7 +641,7 @@ class ModelTpServer:
|
|||||||
# inflight request would get a new req idx
|
# inflight request would get a new req idx
|
||||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
|
self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
|
||||||
|
|
||||||
def forward_decode_batch(self, batch: Batch):
|
def forward_decode_batch(self, batch: ScheduleBatch):
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem():
|
if not batch.check_decode_mem():
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
@@ -700,7 +700,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
def handle_finished_requests(self, batch: Batch):
|
def handle_finished_requests(self, batch: ScheduleBatch):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
output_vids = []
|
output_vids = []
|
||||||
decoded_texts = []
|
decoded_texts = []
|
||||||
@@ -800,7 +800,7 @@ class ModelTpServer:
|
|||||||
else:
|
else:
|
||||||
batch.reqs = []
|
batch.reqs = []
|
||||||
|
|
||||||
def filter_out_inflight(self, batch: Batch):
|
def filter_out_inflight(self, batch: ScheduleBatch):
|
||||||
# TODO(lsyin): reduce the overhead, make a special version for this
|
# TODO(lsyin): reduce the overhead, make a special version for this
|
||||||
if self.current_inflight_req is None:
|
if self.current_inflight_req is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
|
|||||||
LogitsMetadata,
|
LogitsMetadata,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
Batch,
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
InputMetadata,
|
InputMetadata,
|
||||||
init_flashinfer_args,
|
init_flashinfer_args,
|
||||||
@@ -202,7 +202,7 @@ class CudaGraphRunner:
|
|||||||
self.graph_memory_pool = graph.pool()
|
self.graph_memory_pool = graph.pool()
|
||||||
return graph, None, out, flashinfer_decode_wrapper
|
return graph, None, out, flashinfer_decode_wrapper
|
||||||
|
|
||||||
def replay(self, batch: Batch):
|
def replay(self, batch: ScheduleBatch):
|
||||||
assert batch.out_cache_loc is not None
|
assert batch.out_cache_loc is not None
|
||||||
raw_bs = len(batch.reqs)
|
raw_bs = len(batch.reqs)
|
||||||
|
|
||||||
|
|||||||
256
python/sglang/srt/model_executor/forward_batch_info.py
Normal file
256
python/sglang/srt/model_executor/forward_batch_info.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum, auto
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardMode(IntEnum):
|
||||||
|
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||||
|
PREFILL = auto()
|
||||||
|
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
||||||
|
EXTEND = auto()
|
||||||
|
# Decode one token.
|
||||||
|
DECODE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InputMetadata:
|
||||||
|
"""Store all inforamtion of a forward pass."""
|
||||||
|
|
||||||
|
forward_mode: ForwardMode
|
||||||
|
batch_size: int
|
||||||
|
total_num_tokens: int
|
||||||
|
req_pool_indices: torch.Tensor
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
req_to_token_pool: ReqToTokenPool
|
||||||
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
|
|
||||||
|
# For extend
|
||||||
|
extend_seq_lens: torch.Tensor
|
||||||
|
extend_start_loc: torch.Tensor
|
||||||
|
extend_no_prefix: bool
|
||||||
|
|
||||||
|
# Output location of the KV cache
|
||||||
|
out_cache_loc: torch.Tensor = None
|
||||||
|
|
||||||
|
# Output options
|
||||||
|
return_logprob: bool = False
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
|
# Trition attention backend
|
||||||
|
triton_max_seq_len: int = 0
|
||||||
|
triton_max_extend_len: int = 0
|
||||||
|
triton_start_loc: torch.Tensor = None
|
||||||
|
triton_prefix_lens: torch.Tensor = None
|
||||||
|
|
||||||
|
# FlashInfer attention backend
|
||||||
|
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||||
|
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||||
|
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||||
|
flashinfer_use_ragged: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
model_runner,
|
||||||
|
forward_mode,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
position_ids_offsets,
|
||||||
|
out_cache_loc,
|
||||||
|
top_logprobs_nums=None,
|
||||||
|
return_logprob=False,
|
||||||
|
skip_flashinfer_init=False,
|
||||||
|
):
|
||||||
|
flashinfer_use_ragged = False
|
||||||
|
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
||||||
|
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
|
||||||
|
flashinfer_use_ragged = True
|
||||||
|
init_flashinfer_args(
|
||||||
|
forward_mode,
|
||||||
|
model_runner,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
model_runner.flashinfer_decode_wrapper,
|
||||||
|
flashinfer_use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = len(req_pool_indices)
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
||||||
|
extend_seq_lens = extend_start_loc = extend_no_prefix = None
|
||||||
|
if not model_runner.server_args.disable_flashinfer:
|
||||||
|
# This variable is not needed in this case,
|
||||||
|
# we do not compute it to make it compatbile with cuda graph.
|
||||||
|
total_num_tokens = None
|
||||||
|
else:
|
||||||
|
total_num_tokens = int(torch.sum(seq_lens))
|
||||||
|
else:
|
||||||
|
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||||
|
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||||
|
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||||
|
positions = torch.tensor(
|
||||||
|
np.concatenate(
|
||||||
|
[
|
||||||
|
np.arange(
|
||||||
|
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
|
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
|
)
|
||||||
|
for i in range(batch_size)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
),
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
extend_seq_lens = seq_lens - prefix_lens
|
||||||
|
extend_start_loc = torch.zeros_like(seq_lens)
|
||||||
|
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||||
|
extend_no_prefix = torch.all(prefix_lens == 0)
|
||||||
|
total_num_tokens = int(torch.sum(seq_lens))
|
||||||
|
|
||||||
|
ret = cls(
|
||||||
|
forward_mode=forward_mode,
|
||||||
|
batch_size=batch_size,
|
||||||
|
total_num_tokens=total_num_tokens,
|
||||||
|
req_pool_indices=req_pool_indices,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
positions=positions,
|
||||||
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
extend_seq_lens=extend_seq_lens,
|
||||||
|
extend_start_loc=extend_start_loc,
|
||||||
|
extend_no_prefix=extend_no_prefix,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
top_logprobs_nums=top_logprobs_nums,
|
||||||
|
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
|
||||||
|
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
|
||||||
|
flashinfer_use_ragged=flashinfer_use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_runner.server_args.disable_flashinfer:
|
||||||
|
(
|
||||||
|
ret.triton_max_seq_len,
|
||||||
|
ret.triton_max_extend_len,
|
||||||
|
ret.triton_start_loc,
|
||||||
|
ret.triton_prefix_lens,
|
||||||
|
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def init_flashinfer_args(
|
||||||
|
forward_mode,
|
||||||
|
model_runner,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
flashinfer_decode_wrapper,
|
||||||
|
flashinfer_use_ragged=False,
|
||||||
|
):
|
||||||
|
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||||
|
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
|
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||||
|
head_dim = model_runner.model_config.head_dim
|
||||||
|
batch_size = len(req_pool_indices)
|
||||||
|
total_num_tokens = int(torch.sum(seq_lens))
|
||||||
|
|
||||||
|
if flashinfer_use_ragged:
|
||||||
|
paged_kernel_lens = prefix_lens
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
|
||||||
|
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||||
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
|
kv_indices = torch.cat(
|
||||||
|
[
|
||||||
|
model_runner.req_to_token_pool.req_to_token[
|
||||||
|
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||||
|
]
|
||||||
|
for i in range(batch_size)
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).contiguous()
|
||||||
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
flashinfer_decode_wrapper.end_forward()
|
||||||
|
flashinfer_decode_wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# extend part
|
||||||
|
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
|
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
|
|
||||||
|
if flashinfer_use_ragged:
|
||||||
|
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||||
|
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cached part
|
||||||
|
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
|
||||||
|
model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||||
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
max_seq_len = int(torch.max(seq_lens))
|
||||||
|
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
max_extend_len = None
|
||||||
|
else:
|
||||||
|
extend_seq_lens = seq_lens - prefix_lens
|
||||||
|
max_extend_len = int(torch.max(extend_seq_lens))
|
||||||
|
|
||||||
|
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
||||||
@@ -41,18 +41,14 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
Batch,
|
|
||||||
ForwardMode,
|
|
||||||
InputMetadata,
|
|
||||||
global_server_args_dict,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_config import AttentionArch
|
from sglang.srt.model_config import AttentionArch
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
@@ -350,7 +346,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_decode(self, batch: Batch):
|
def forward_decode(self, batch: ScheduleBatch):
|
||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
@@ -370,7 +366,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(self, batch: Batch):
|
def forward_extend(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
@@ -387,7 +383,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(self, batch: Batch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
forward_mode=ForwardMode.EXTEND,
|
forward_mode=ForwardMode.EXTEND,
|
||||||
@@ -408,7 +404,7 @@ class ModelRunner:
|
|||||||
batch.image_offsets,
|
batch.image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, batch: Batch, forward_mode: ForwardMode):
|
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
||||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||||
return self.forward_extend_multi_modal(batch)
|
return self.forward_extend_multi_modal(batch)
|
||||||
elif forward_mode == ForwardMode.DECODE:
|
elif forward_mode == ForwardMode.DECODE:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
LoraConfig = None
|
LoraConfig = None
|
||||||
|
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class DbrxRouter(nn.Module):
|
class DbrxRouter(nn.Module):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class DeepseekMLP(nn.Module):
|
class DeepseekMLP(nn.Module):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MLP(nn.Module):
|
class DeepseekV2MLP(nn.Module):
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class GemmaMLP(nn.Module):
|
class GemmaMLP(nn.Module):
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class GemmaRMSNorm(CustomOp):
|
class GemmaRMSNorm(CustomOp):
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class GPTBigCodeAttention(nn.Module):
|
class GPTBigCodeAttention(nn.Module):
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
|
|||||||
from sglang.srt.layers.fused_moe import fused_moe
|
from sglang.srt.layers.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
use_fused = True
|
use_fused = True
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class InternLM2MLP(nn.Module):
|
class InternLM2MLP(nn.Module):
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
from sglang.srt.models.llama2 import LlamaModel
|
from sglang.srt.models.llama2 import LlamaModel
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ForwardMode
|
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
unpad_image_shape,
|
unpad_image_shape,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
from sglang.srt.models.mistral import MistralForCausalLM
|
from sglang.srt.models.mistral import MistralForCausalLM
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
|||||||
@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ForwardMode
|
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
unpad_image_shape,
|
unpad_image_shape,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMMLP(nn.Module):
|
class MiniCPMMLP(nn.Module):
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class MixtralMoE(nn.Module):
|
class MixtralMoE(nn.Module):
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class MixtralMLP(nn.Module):
|
class MixtralMLP(nn.Module):
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class QWenMLP(nn.Module):
|
class QWenMLP(nn.Module):
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeMLP(nn.Module):
|
class Qwen2MoeMLP(nn.Module):
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class StablelmMLP(nn.Module):
|
class StablelmMLP(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user