Code clean up: Remove deprecated prefill, move InputMetadata to infer_batch.py (#609)
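After this change, ForwardMode, InputMetadata, and global_server_args_dict are defined once in infer_batch.py, and model_runner.py and radix_attention.py import them from there. The lines below are a minimal sketch of that import layout, based only on the import statements visible in the diff; they are illustrative and not a copy of either file.

# Sketch: the shared request/batch types now live in infer_batch.py.
from sglang.srt.managers.controller.infer_batch import (
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)

# Consumers read server flags through the shared dict instead of a
# module-local copy (see the RadixAttention and ModelRunner hunks below).
use_flashinfer = not global_server_args_dict.get("disable_flashinfer", False)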
@@ -8,6 +8,7 @@ from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.infer_batch import global_server_args_dict
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata

@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
        self.scaling = scaling
        self.layer_id = layer_id

        from sglang.srt.managers.controller.model_runner import global_server_args_dict

        if not global_server_args_dict.get("disable_flashinfer", False):
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer

@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
        k = k.view(-1, self.tp_k_head_num, self.head_dim)
        v = v.view(-1, self.tp_v_head_num, self.head_dim)

        if input_metadata.forward_mode == ForwardMode.PREFILL:
            return self.prefill_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.EXTEND:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.decode_forward(q, k, v, input_metadata)
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Store some global server args
global_server_args_dict = {}


class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()


@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):


class Req:
    """Store all information of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids):
        self.rid = rid
        self.origin_input_text = origin_input_text
@@ -74,7 +82,7 @@ class Req:
        self.output_ids = []  # Each decode stage's output ids
        self.input_ids = None  # input_ids = origin_input_ids + output_ids

        # For incremental decode
        # For incremental decoding
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
@@ -93,9 +101,8 @@ class Req:
        self.sampling_params = None
        self.stream = False

        self.tokenizer = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # Prefix info
@@ -252,6 +259,8 @@ class Req:

@dataclass
class Batch:
    """Store all information of a batch."""

    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx


@dataclass
class InputMetadata:
    """Store all information of a forward pass."""

    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
        if (
            self.forward_mode == ForwardMode.EXTEND
        ):
            paged_kernel_lens = self.prefix_lens
            self.no_prefix = torch.all(self.prefix_lens == 0)
        else:
            paged_kernel_lens = self.seq_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

        if self.forward_mode == ForwardMode.EXTEND:
            # extend part
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

            self.flashinfer_prefill_wrapper_ragged.end_forward()
            self.flashinfer_prefill_wrapper_ragged.begin_forward(
                self.qo_indptr,
                self.qo_indptr.clone(),
                num_qo_heads,
                num_kv_heads,
                head_dim,
            )

            # cached part
            self.flashinfer_prefill_wrapper_paged.end_forward()
            self.flashinfer_prefill_wrapper_paged.begin_forward(
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
        else:
            self.flashinfer_decode_wrapper.end_forward()
            self.flashinfer_decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
                pos_encoding_mode="NONE",
                data_type=self.token_to_kv_pool.kv_data[0].dtype,
            )

    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
        flashinfer_prefill_wrapper_ragged=None,
        flashinfer_prefill_wrapper_paged=None,
        flashinfer_decode_wrapper=None,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if not global_server_args_dict.get("disable_flashinfer", False):
            ret.init_flashinfer_args(
                model_runner.model_config.num_attention_heads // tp_size,
                model_runner.model_config.get_num_kv_heads(tp_size),
                model_runner.model_config.head_dim,
            )

        return ret
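For orientation, the indptr tensors built in init_flashinfer_args above are exclusive prefix sums over per-request lengths: request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]. Below is a small self-contained illustration of that layout; the lengths are made-up values, not taken from the diff.

import torch

# Two requests with 3 and 5 cached KV tokens (illustrative values only).
paged_kernel_lens = torch.tensor([3, 5], dtype=torch.int32)
batch_size = paged_kernel_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
# kv_indptr == [0, 3, 8]: request 0 uses slots [0, 3), request 1 uses [3, 8).

# The page size passed to begin_forward above is 1, so the last page of
# every request holds exactly one token.
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)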
@@ -4,11 +4,9 @@ import importlib
import importlib.resources
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Type
from typing import Optional, Type

import numpy as np
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -29,210 +27,6 @@ from sglang.srt.utils import (

logger = logging.getLogger("srt.model_runner")

# for server args in model endpoints
global_server_args_dict = {}


@dataclass
class InputMetadata:
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            paged_kernel_lens = self.prefix_lens
            self.no_prefix = torch.all(self.prefix_lens == 0)
        else:
            paged_kernel_lens = self.seq_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            # extend part
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

            self.flashinfer_prefill_wrapper_ragged.end_forward()
            self.flashinfer_prefill_wrapper_ragged.begin_forward(
                self.qo_indptr,
                self.qo_indptr.clone(),
                num_qo_heads,
                num_kv_heads,
                head_dim,
            )

            # cached part
            self.flashinfer_prefill_wrapper_paged.end_forward()
            self.flashinfer_prefill_wrapper_paged.begin_forward(
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
        else:
            self.flashinfer_decode_wrapper.end_forward()
            self.flashinfer_decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
                pos_encoding_mode="NONE",
                data_type=self.token_to_kv_pool.kv_data[0].dtype,
            )

    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
        flashinfer_prefill_wrapper_ragged=None,
        flashinfer_prefill_wrapper_paged=None,
        flashinfer_decode_wrapper=None,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if not global_server_args_dict.get("disable_flashinfer", False):
            ret.init_flashinfer_args(
                model_runner.model_config.num_attention_heads // tp_size,
                model_runner.model_config.get_num_kv_heads(tp_size),
                model_runner.model_config.head_dim,
            )

        return ret


class ModelRunner:
    def __init__(
@@ -245,6 +39,7 @@ class ModelRunner:
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
@@ -256,7 +51,6 @@ class ModelRunner:
        monkey_patch_vllm_dummy_weight_loader()

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")

@@ -287,11 +81,8 @@ class ModelRunner:
        )

        # Set some global args
        global global_server_args_dict
        global_server_args_dict = {
            "disable_flashinfer": server_args.disable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }
        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

        # Load the model and create memory pool
        self.load_model()
@@ -425,27 +216,6 @@ class ModelRunner:
        ) = None
        self.flashinfer_decode_wrapper = None

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
@@ -523,8 +293,6 @@ class ModelRunner:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")
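The removed PREFILL branches are covered by EXTEND because an extend with zero prefix length computes every token of the sequence. Below is a short check of that identity using the arithmetic from init_extend_args; the tensor values are illustrative assumptions, not taken from the diff.

import torch

seq_lens = torch.tensor([7, 4])
prefix_lens = torch.zeros_like(seq_lens)   # no cached prefix, i.e. the old PREFILL case

extend_seq_lens = seq_lens - prefix_lens   # equals seq_lens: all tokens are computed
no_prefix = torch.all(prefix_lens == 0)    # True: the ragged (non-paged) prefill path is used

assert torch.equal(extend_seq_lens, seq_lens) and bool(no_prefix)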