Organize code (rename, movement) (#953)
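For orientation, the rename and module move in this commit reduce to the following before/after import pattern for downstream code. The snippet is an illustrative summary drawn from the hunks below, not part of the diff itself.

# Before: the scheduling container and the forward-pass metadata lived together.
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, InputMetadata

# After: Batch becomes ScheduleBatch, while ForwardMode, InputMetadata, and the
# FlashInfer/Triton init helpers move to the new forward_batch_info module.
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata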
@@ -50,8 +50,9 @@ import torch
import torch.distributed as dist

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -188,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):


def extend(reqs, model_runner):
    batch = Batch.init_new(
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,

@@ -25,7 +25,7 @@ from vllm.distributed import (
    tensor_model_parallel_all_gather,
)

from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


@dataclasses.dataclass

@@ -22,11 +22,8 @@ from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.model_executor.model_runner import (
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict


class RadixAttention(nn.Module):

@@ -18,7 +18,6 @@ limitations under the License.
import logging
import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List, Union

import numpy as np
@@ -46,15 +45,6 @@ global_server_args_dict = {
logger = logging.getLogger(__name__)


class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error
@@ -284,7 +274,7 @@ class Req:


@dataclass
class Batch:
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
@@ -673,7 +663,7 @@ class Batch:
        if self_val is not None:  # logit_bias can be None
            setattr(self, item, self_val[new_indices])

    def merge(self, other: "Batch"):
    def merge(self, other: "ScheduleBatch"):
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
@@ -770,229 +760,6 @@ class Batch:
        return batch_next_token_ids


@dataclass
class InputMetadata:
    """Store all information of a forward pass."""

    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool

    # For extend
    extend_seq_lens: torch.Tensor
    extend_start_loc: torch.Tensor
    extend_no_prefix: bool

    # Output location of the KV cache
    out_cache_loc: torch.Tensor = None

    # Output options
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Triton attention backend
    triton_max_seq_len: int = 0
    triton_max_extend_len: int = 0
    triton_start_loc: torch.Tensor = None
    triton_prefix_lens: torch.Tensor = None

    # FlashInfer attention backend
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
    flashinfer_use_ragged: bool = False

    @classmethod
    def create(
        cls,
        model_runner,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        top_logprobs_nums=None,
        return_logprob=False,
        skip_flashinfer_init=False,
    ):
        flashinfer_use_ragged = False
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
                flashinfer_use_ragged = True
            init_flashinfer_args(
                forward_mode,
                model_runner,
                req_pool_indices,
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
                flashinfer_use_ragged,
            )

        batch_size = len(req_pool_indices)

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            extend_seq_lens = extend_start_loc = extend_no_prefix = None
            if not model_runner.server_args.disable_flashinfer:
                # This variable is not needed in this case,
                # we do not compute it to make it compatible with cuda graph.
                total_num_tokens = None
            else:
                total_num_tokens = int(torch.sum(seq_lens))
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            extend_seq_lens = seq_lens - prefix_lens
            extend_start_loc = torch.zeros_like(seq_lens)
            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
            extend_no_prefix = torch.all(prefix_lens == 0)
            total_num_tokens = int(torch.sum(seq_lens))

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            extend_seq_lens=extend_seq_lens,
            extend_start_loc=extend_start_loc,
            extend_no_prefix=extend_no_prefix,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
            flashinfer_use_ragged=flashinfer_use_ragged,
        )

        if model_runner.server_args.disable_flashinfer:
            (
                ret.triton_max_seq_len,
                ret.triton_max_extend_len,
                ret.triton_start_loc,
                ret.triton_prefix_lens,
            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

        return ret


def init_flashinfer_args(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
    flashinfer_use_ragged=False,
):
    """Init auxiliary variables for FlashInfer attention backend."""
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim
    batch_size = len(req_pool_indices)
    total_num_tokens = int(torch.sum(seq_lens))

    if flashinfer_use_ragged:
        paged_kernel_lens = prefix_lens
    else:
        paged_kernel_lens = seq_lens

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
    kv_indices = torch.cat(
        [
            model_runner.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
            ]
            for i in range(batch_size)
        ],
        dim=0,
    ).contiguous()
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

    if forward_mode == ForwardMode.DECODE:
        flashinfer_decode_wrapper.end_forward()
        flashinfer_decode_wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )
    else:
        # extend part
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        if flashinfer_use_ragged:
            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                qo_indptr,
                qo_indptr,
                num_qo_heads,
                num_kv_heads,
                head_dim,
            )

        # cached part
        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )


def init_triton_args(forward_mode, seq_lens, prefix_lens):
    """Init auxiliary variables for Triton attention backend."""
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    if forward_mode == ForwardMode.DECODE:
        max_extend_len = None
    else:
        extend_seq_lens = seq_lens - prefix_lens
        max_extend_len = int(torch.max(extend_seq_lens))

    return max_seq_len, max_extend_len, start_loc, prefix_lens


def top_k_top_p_sampling_from_probs_torch(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
):

@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    Batch,
    ForwardMode,
    Req,
    ScheduleBatch,
)
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: Batch = None
        self.running_batch: ScheduleBatch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
@@ -353,7 +353,7 @@ class ModelTpServer:
            )
            self.waiting_queue.append(req)

    def get_new_prefill_batch(self) -> Optional[Batch]:
    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
        # TODO(lsyin): organize this function
        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -526,7 +526,7 @@ class ModelTpServer:
        )

        # Return the new batch
        new_batch = Batch.init_new(
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
@@ -535,7 +535,7 @@ class ModelTpServer:
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
        return new_batch

    def forward_prefill_batch(self, batch: Batch):
    def forward_prefill_batch(self, batch: ScheduleBatch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
@@ -624,7 +624,7 @@ class ModelTpServer:
                )
                req.output_top_logprobs.append(output.output_top_logprobs[i])

    def cache_filled_batch(self, batch: Batch):
    def cache_filled_batch(self, batch: ScheduleBatch):
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
@@ -641,7 +641,7 @@ class ModelTpServer:
            # inflight request would get a new req idx
            self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))

    def forward_decode_batch(self, batch: Batch):
    def forward_decode_batch(self, batch: ScheduleBatch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
@@ -700,7 +700,7 @@ class ModelTpServer:

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_vids = []
        decoded_texts = []
@@ -800,7 +800,7 @@ class ModelTpServer:
        else:
            batch.reqs = []

    def filter_out_inflight(self, batch: Batch):
    def filter_out_inflight(self, batch: ScheduleBatch):
        # TODO(lsyin): reduce the overhead, make a special version for this
        if self.current_inflight_req is None:
            return

@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
    LogitsMetadata,
    LogitsProcessor,
)
from sglang.srt.managers.schedule_batch import (
    Batch,
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
    ForwardMode,
    InputMetadata,
    init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
        self.graph_memory_pool = graph.pool()
        return graph, None, out, flashinfer_decode_wrapper

    def replay(self, batch: Batch):
    def replay(self, batch: ScheduleBatch):
        assert batch.out_cache_loc is not None
        raw_bs = len(batch.reqs)


python/sglang/srt/model_executor/forward_batch_info.py (new file, 256 lines)
@@ -0,0 +1,256 @@
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||
PREFILL = auto()
|
||||
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
||||
EXTEND = auto()
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
"""Store all inforamtion of a forward pass."""
|
||||
|
||||
forward_mode: ForwardMode
|
||||
batch_size: int
|
||||
total_num_tokens: int
|
||||
req_pool_indices: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor
|
||||
extend_start_loc: torch.Tensor
|
||||
extend_no_prefix: bool
|
||||
|
||||
# Output location of the KV cache
|
||||
out_cache_loc: torch.Tensor = None
|
||||
|
||||
# Output options
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# Trition attention backend
|
||||
triton_max_seq_len: int = 0
|
||||
triton_max_extend_len: int = 0
|
||||
triton_start_loc: torch.Tensor = None
|
||||
triton_prefix_lens: torch.Tensor = None
|
||||
|
||||
# FlashInfer attention backend
|
||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||
flashinfer_use_ragged: bool = False
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_runner,
|
||||
forward_mode,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
top_logprobs_nums=None,
|
||||
return_logprob=False,
|
||||
skip_flashinfer_init=False,
|
||||
):
|
||||
flashinfer_use_ragged = False
|
||||
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
||||
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
|
||||
flashinfer_use_ragged = True
|
||||
init_flashinfer_args(
|
||||
forward_mode,
|
||||
model_runner,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
model_runner.flashinfer_decode_wrapper,
|
||||
flashinfer_use_ragged,
|
||||
)
|
||||
|
||||
batch_size = len(req_pool_indices)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
||||
extend_seq_lens = extend_start_loc = extend_no_prefix = None
|
||||
if not model_runner.server_args.disable_flashinfer:
|
||||
# This variable is not needed in this case,
|
||||
# we do not compute it to make it compatbile with cuda graph.
|
||||
total_num_tokens = None
|
||||
else:
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
else:
|
||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||
positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(
|
||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
)
|
||||
for i in range(batch_size)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
)
|
||||
extend_seq_lens = seq_lens - prefix_lens
|
||||
extend_start_loc = torch.zeros_like(seq_lens)
|
||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||
extend_no_prefix = torch.all(prefix_lens == 0)
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
|
||||
ret = cls(
|
||||
forward_mode=forward_mode,
|
||||
batch_size=batch_size,
|
||||
total_num_tokens=total_num_tokens,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
positions=positions,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_start_loc=extend_start_loc,
|
||||
extend_no_prefix=extend_no_prefix,
|
||||
return_logprob=return_logprob,
|
||||
top_logprobs_nums=top_logprobs_nums,
|
||||
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
|
||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
||||
)
|
||||
|
||||
if model_runner.server_args.disable_flashinfer:
|
||||
(
|
||||
ret.triton_max_seq_len,
|
||||
ret.triton_max_extend_len,
|
||||
ret.triton_start_loc,
|
||||
ret.triton_prefix_lens,
|
||||
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def init_flashinfer_args(
|
||||
forward_mode,
|
||||
model_runner,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
flashinfer_decode_wrapper,
|
||||
flashinfer_use_ragged=False,
|
||||
):
|
||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||
head_dim = model_runner.model_config.head_dim
|
||||
batch_size = len(req_pool_indices)
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
|
||||
if flashinfer_use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
else:
|
||||
paged_kernel_lens = seq_lens
|
||||
|
||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
kv_indices = torch.cat(
|
||||
[
|
||||
model_runner.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||
]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
flashinfer_decode_wrapper.end_forward()
|
||||
flashinfer_decode_wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
# extend part
|
||||
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
|
||||
if flashinfer_use_ragged:
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||
qo_indptr,
|
||||
qo_indptr,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# cached part
|
||||
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
)
|
||||
|
||||
|
||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
batch_size = len(seq_lens)
|
||||
max_seq_len = int(torch.max(seq_lens))
|
||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
max_extend_len = None
|
||||
else:
|
||||
extend_seq_lens = seq_lens - prefix_lens
|
||||
max_extend_len = int(torch.max(extend_seq_lens))
|
||||
|
||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
||||
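As a quick illustration of the relocated helpers, the sketch below calls init_triton_args from the new module with hypothetical lengths; it assumes a CUDA device is available and is not part of the commit.

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode, init_triton_args

# Hypothetical extend batch: two requests whose first 3 and 0 tokens are already cached.
seq_lens = torch.tensor([7, 5], dtype=torch.int32, device="cuda")
prefix_lens = torch.tensor([3, 0], dtype=torch.int32, device="cuda")

max_seq_len, max_extend_len, start_loc, prefix_lens = init_triton_args(
    ForwardMode.EXTEND, seq_lens, prefix_lens
)
# max_seq_len == 7, max_extend_len == max(7 - 3, 5 - 0) == 5,
# start_loc == [0, 7]: per-request offsets into the flattened token dimension.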
@@ -41,18 +41,14 @@ from vllm.distributed import (
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import (
    Batch,
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
@@ -350,7 +346,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
    def forward_decode(self, batch: ScheduleBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

@@ -370,7 +366,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
            batch.image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode):
    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata

LoraConfig = None


@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


@torch.compile

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class DbrxRouter(nn.Module):

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class DeepseekMLP(nn.Module):

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class DeepseekV2MLP(nn.Module):

@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class GemmaMLP(nn.Module):

@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class GemmaRMSNorm(CustomOp):

@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class GPTBigCodeAttention(nn.Module):

@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata

use_fused = True


@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class InternLM2MLP(nn.Module):

@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class LlamaMLP(nn.Module):

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel


@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.managers.schedule_batch import ForwardMode
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.managers.schedule_batch import ForwardMode
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM


@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class MiniCPMMLP(nn.Module):

@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class MixtralMoE(nn.Module):

@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class MixtralMLP(nn.Module):

@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class QWenMLP(nn.Module):

@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata

Qwen2Config = None


@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class Qwen2MoeMLP(nn.Module):

@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class StablelmMLP(nn.Module):