Code clean up: Remove deprecated prefill; move InputMetadata to infer_batch.py (#609)
This commit is contained in:
@@ -8,6 +8,7 @@ from torch import nn
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
|
from sglang.srt.managers.controller.infer_batch import global_server_args_dict
|
||||||
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
|
|||||||
self.scaling = scaling
|
self.scaling = scaling
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
|
||||||
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
|
||||||
|
|
||||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
self.prefill_forward = self.prefill_forward_flashinfer
|
self.prefill_forward = self.prefill_forward_flashinfer
|
||||||
self.extend_forward = self.prefill_forward_flashinfer
|
self.extend_forward = self.prefill_forward_flashinfer
|
||||||
@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
|
|||||||
k = k.view(-1, self.tp_k_head_num, self.head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.head_dim)
|
||||||
v = v.view(-1, self.tp_v_head_num, self.head_dim)
|
v = v.view(-1, self.tp_v_head_num, self.head_dim)
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.PREFILL:
|
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||||
return self.prefill_forward(q, k, v, input_metadata)
|
|
||||||
elif input_metadata.forward_mode == ForwardMode.EXTEND:
|
|
||||||
return self.extend_forward(q, k, v, input_metadata)
|
return self.extend_forward(q, k, v, input_metadata)
|
||||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
return self.decode_forward(q, k, v, input_metadata)
|
return self.decode_forward(q, k, v, input_metadata)
|
||||||
|
|||||||
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
|||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
|
# Store some global server args
|
||||||
|
global_server_args_dict = {}
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
|
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||||
PREFILL = auto()
|
PREFILL = auto()
|
||||||
|
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
|
# Decode one token.
|
||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):
|
|||||||
|
|
||||||
|
|
||||||
class Req:
|
class Req:
|
||||||
|
"""Store all information of a request."""
|
||||||
|
|
||||||
def __init__(self, rid, origin_input_text, origin_input_ids):
|
def __init__(self, rid, origin_input_text, origin_input_ids):
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
self.origin_input_text = origin_input_text
|
self.origin_input_text = origin_input_text
|
||||||
@@ -74,7 +82,7 @@ class Req:
|
|||||||
self.output_ids = [] # Each decode stage's output ids
|
self.output_ids = [] # Each decode stage's output ids
|
||||||
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
||||||
|
|
||||||
# For incremental decode
|
# For incremental decoding
|
||||||
self.decoded_text = ""
|
self.decoded_text = ""
|
||||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||||
self.read_offset = None
|
self.read_offset = None
|
||||||
@@ -93,9 +101,8 @@ class Req:
|
|||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
self.tokenizer = None
|
|
||||||
|
|
||||||
# Check finish
|
# Check finish
|
||||||
|
self.tokenizer = None
|
||||||
self.finished_reason = None
|
self.finished_reason = None
|
||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
@@ -252,6 +259,8 @@ class Req:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Batch:
|
class Batch:
|
||||||
|
"""Store all information of a batch."""
|
||||||
|
|
||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: TokenToKVPool
|
token_to_kv_pool: TokenToKVPool
|
||||||
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
|
|||||||
] = 0.0
|
] = 0.0
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||||
return probs_sort, probs_idx
|
return probs_sort, probs_idx
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InputMetadata:
|
||||||
|
"""Store all information of a forward pass."""
|
||||||
|
|
||||||
|
forward_mode: ForwardMode
|
||||||
|
batch_size: int
|
||||||
|
total_num_tokens: int
|
||||||
|
max_seq_len: int
|
||||||
|
req_pool_indices: torch.Tensor
|
||||||
|
start_loc: torch.Tensor
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
prefix_lens: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
req_to_token_pool: ReqToTokenPool
|
||||||
|
token_to_kv_pool: TokenToKVPool
|
||||||
|
|
||||||
|
# for extend
|
||||||
|
extend_seq_lens: torch.Tensor = None
|
||||||
|
extend_start_loc: torch.Tensor = None
|
||||||
|
max_extend_len: int = 0
|
||||||
|
|
||||||
|
out_cache_loc: torch.Tensor = None
|
||||||
|
out_cache_cont_start: torch.Tensor = None
|
||||||
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
|
other_kv_index: torch.Tensor = None
|
||||||
|
return_logprob: bool = False
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
|
# for flashinfer
|
||||||
|
qo_indptr: torch.Tensor = None
|
||||||
|
kv_indptr: torch.Tensor = None
|
||||||
|
kv_indices: torch.Tensor = None
|
||||||
|
kv_last_page_len: torch.Tensor = None
|
||||||
|
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||||
|
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||||
|
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||||
|
|
||||||
|
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
||||||
|
if (
|
||||||
|
self.forward_mode == ForwardMode.EXTEND
|
||||||
|
):
|
||||||
|
paged_kernel_lens = self.prefix_lens
|
||||||
|
self.no_prefix = torch.all(self.prefix_lens == 0)
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = self.seq_lens
|
||||||
|
|
||||||
|
self.kv_indptr = torch.zeros(
|
||||||
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
self.kv_last_page_len = torch.ones(
|
||||||
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||||
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
|
self.kv_indices = torch.cat(
|
||||||
|
[
|
||||||
|
self.req_to_token_pool.req_to_token[
|
||||||
|
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||||
|
]
|
||||||
|
for i in range(self.batch_size)
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
if self.forward_mode == ForwardMode.EXTEND:
|
||||||
|
# extend part
|
||||||
|
self.qo_indptr = torch.zeros(
|
||||||
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||||
|
|
||||||
|
self.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||||
|
self.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||||
|
self.qo_indptr,
|
||||||
|
self.qo_indptr.clone(),
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cached part
|
||||||
|
self.flashinfer_prefill_wrapper_paged.end_forward()
|
||||||
|
self.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||||
|
self.qo_indptr,
|
||||||
|
self.kv_indptr,
|
||||||
|
self.kv_indices,
|
||||||
|
self.kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.flashinfer_decode_wrapper.end_forward()
|
||||||
|
self.flashinfer_decode_wrapper.begin_forward(
|
||||||
|
self.kv_indptr,
|
||||||
|
self.kv_indices,
|
||||||
|
self.kv_last_page_len,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
pos_encoding_mode="NONE",
|
||||||
|
data_type=self.token_to_kv_pool.kv_data[0].dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_extend_args(self):
|
||||||
|
self.extend_seq_lens = self.seq_lens - self.prefix_lens
|
||||||
|
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||||
|
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||||
|
self.max_extend_len = int(torch.max(self.extend_seq_lens))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
model_runner,
|
||||||
|
tp_size,
|
||||||
|
forward_mode,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
position_ids_offsets,
|
||||||
|
out_cache_loc,
|
||||||
|
out_cache_cont_start=None,
|
||||||
|
out_cache_cont_end=None,
|
||||||
|
top_logprobs_nums=None,
|
||||||
|
return_logprob=False,
|
||||||
|
flashinfer_prefill_wrapper_ragged=None,
|
||||||
|
flashinfer_prefill_wrapper_paged=None,
|
||||||
|
flashinfer_decode_wrapper=None,
|
||||||
|
):
|
||||||
|
batch_size = len(req_pool_indices)
|
||||||
|
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
||||||
|
total_num_tokens = int(torch.sum(seq_lens))
|
||||||
|
max_seq_len = int(torch.max(seq_lens))
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
||||||
|
other_kv_index = model_runner.req_to_token_pool.req_to_token[
|
||||||
|
req_pool_indices[0], seq_lens[0] - 1
|
||||||
|
].item()
|
||||||
|
else:
|
||||||
|
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||||
|
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||||
|
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||||
|
positions = torch.tensor(
|
||||||
|
np.concatenate(
|
||||||
|
[
|
||||||
|
np.arange(
|
||||||
|
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
|
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||||
|
)
|
||||||
|
for i in range(batch_size)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
),
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
other_kv_index = None
|
||||||
|
|
||||||
|
ret = cls(
|
||||||
|
forward_mode=forward_mode,
|
||||||
|
batch_size=batch_size,
|
||||||
|
total_num_tokens=total_num_tokens,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
req_pool_indices=req_pool_indices,
|
||||||
|
start_loc=start_loc,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
prefix_lens=prefix_lens,
|
||||||
|
positions=positions,
|
||||||
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
out_cache_cont_start=out_cache_cont_start,
|
||||||
|
out_cache_cont_end=out_cache_cont_end,
|
||||||
|
other_kv_index=other_kv_index,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
top_logprobs_nums=top_logprobs_nums,
|
||||||
|
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
|
||||||
|
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
|
ret.init_extend_args()
|
||||||
|
|
||||||
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
|
ret.init_flashinfer_args(
|
||||||
|
model_runner.model_config.num_attention_heads // tp_size,
|
||||||
|
model_runner.model_config.get_num_kv_heads(tp_size),
|
||||||
|
model_runner.model_config.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|||||||
@@ -4,11 +4,9 @@ import importlib
|
|||||||
import importlib.resources
|
import importlib.resources
|
||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
|
|||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -29,210 +27,6 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
logger = logging.getLogger("srt.model_runner")
|
logger = logging.getLogger("srt.model_runner")
|
||||||
|
|
||||||
# for server args in model endpoints
|
|
||||||
global_server_args_dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InputMetadata:
|
|
||||||
forward_mode: ForwardMode
|
|
||||||
batch_size: int
|
|
||||||
total_num_tokens: int
|
|
||||||
max_seq_len: int
|
|
||||||
req_pool_indices: torch.Tensor
|
|
||||||
start_loc: torch.Tensor
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
prefix_lens: torch.Tensor
|
|
||||||
positions: torch.Tensor
|
|
||||||
req_to_token_pool: ReqToTokenPool
|
|
||||||
token_to_kv_pool: TokenToKVPool
|
|
||||||
|
|
||||||
# for extend
|
|
||||||
extend_seq_lens: torch.Tensor = None
|
|
||||||
extend_start_loc: torch.Tensor = None
|
|
||||||
max_extend_len: int = 0
|
|
||||||
|
|
||||||
out_cache_loc: torch.Tensor = None
|
|
||||||
out_cache_cont_start: torch.Tensor = None
|
|
||||||
out_cache_cont_end: torch.Tensor = None
|
|
||||||
|
|
||||||
other_kv_index: torch.Tensor = None
|
|
||||||
return_logprob: bool = False
|
|
||||||
top_logprobs_nums: List[int] = None
|
|
||||||
|
|
||||||
# for flashinfer
|
|
||||||
qo_indptr: torch.Tensor = None
|
|
||||||
kv_indptr: torch.Tensor = None
|
|
||||||
kv_indices: torch.Tensor = None
|
|
||||||
kv_last_page_len: torch.Tensor = None
|
|
||||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
|
||||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
|
||||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
|
||||||
|
|
||||||
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
|
||||||
if (
|
|
||||||
self.forward_mode == ForwardMode.PREFILL
|
|
||||||
or self.forward_mode == ForwardMode.EXTEND
|
|
||||||
):
|
|
||||||
paged_kernel_lens = self.prefix_lens
|
|
||||||
self.no_prefix = torch.all(self.prefix_lens == 0)
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = self.seq_lens
|
|
||||||
|
|
||||||
self.kv_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
self.kv_last_page_len = torch.ones(
|
|
||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
|
||||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
|
||||||
self.kv_indices = torch.cat(
|
|
||||||
[
|
|
||||||
self.req_to_token_pool.req_to_token[
|
|
||||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
|
||||||
]
|
|
||||||
for i in range(self.batch_size)
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
).contiguous()
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.forward_mode == ForwardMode.PREFILL
|
|
||||||
or self.forward_mode == ForwardMode.EXTEND
|
|
||||||
):
|
|
||||||
# extend part
|
|
||||||
self.qo_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
|
||||||
|
|
||||||
self.flashinfer_prefill_wrapper_ragged.end_forward()
|
|
||||||
self.flashinfer_prefill_wrapper_ragged.begin_forward(
|
|
||||||
self.qo_indptr,
|
|
||||||
self.qo_indptr.clone(),
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cached part
|
|
||||||
self.flashinfer_prefill_wrapper_paged.end_forward()
|
|
||||||
self.flashinfer_prefill_wrapper_paged.begin_forward(
|
|
||||||
self.qo_indptr,
|
|
||||||
self.kv_indptr,
|
|
||||||
self.kv_indices,
|
|
||||||
self.kv_last_page_len,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.flashinfer_decode_wrapper.end_forward()
|
|
||||||
self.flashinfer_decode_wrapper.begin_forward(
|
|
||||||
self.kv_indptr,
|
|
||||||
self.kv_indices,
|
|
||||||
self.kv_last_page_len,
|
|
||||||
num_qo_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
1,
|
|
||||||
pos_encoding_mode="NONE",
|
|
||||||
data_type=self.token_to_kv_pool.kv_data[0].dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_extend_args(self):
|
|
||||||
self.extend_seq_lens = self.seq_lens - self.prefix_lens
|
|
||||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
|
||||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
|
||||||
self.max_extend_len = int(torch.max(self.extend_seq_lens))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
model_runner,
|
|
||||||
tp_size,
|
|
||||||
forward_mode,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
position_ids_offsets,
|
|
||||||
out_cache_loc,
|
|
||||||
out_cache_cont_start=None,
|
|
||||||
out_cache_cont_end=None,
|
|
||||||
top_logprobs_nums=None,
|
|
||||||
return_logprob=False,
|
|
||||||
flashinfer_prefill_wrapper_ragged=None,
|
|
||||||
flashinfer_prefill_wrapper_paged=None,
|
|
||||||
flashinfer_decode_wrapper=None,
|
|
||||||
):
|
|
||||||
batch_size = len(req_pool_indices)
|
|
||||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
|
||||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
|
||||||
max_seq_len = int(torch.max(seq_lens))
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
|
||||||
other_kv_index = model_runner.req_to_token_pool.req_to_token[
|
|
||||||
req_pool_indices[0], seq_lens[0] - 1
|
|
||||||
].item()
|
|
||||||
else:
|
|
||||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
|
||||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
|
||||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
|
||||||
positions = torch.tensor(
|
|
||||||
np.concatenate(
|
|
||||||
[
|
|
||||||
np.arange(
|
|
||||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
|
||||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
|
||||||
)
|
|
||||||
for i in range(batch_size)
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
),
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
other_kv_index = None
|
|
||||||
|
|
||||||
ret = cls(
|
|
||||||
forward_mode=forward_mode,
|
|
||||||
batch_size=batch_size,
|
|
||||||
total_num_tokens=total_num_tokens,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
req_pool_indices=req_pool_indices,
|
|
||||||
start_loc=start_loc,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
prefix_lens=prefix_lens,
|
|
||||||
positions=positions,
|
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
|
||||||
out_cache_loc=out_cache_loc,
|
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
|
||||||
other_kv_index=other_kv_index,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
top_logprobs_nums=top_logprobs_nums,
|
|
||||||
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
|
|
||||||
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
|
|
||||||
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.EXTEND:
|
|
||||||
ret.init_extend_args()
|
|
||||||
|
|
||||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
|
||||||
ret.init_flashinfer_args(
|
|
||||||
model_runner.model_config.num_attention_heads // tp_size,
|
|
||||||
model_runner.model_config.get_num_kv_heads(tp_size),
|
|
||||||
model_runner.model_config.head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -245,6 +39,7 @@ class ModelRunner:
|
|||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
):
|
):
|
||||||
|
# Parse args
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.mem_fraction_static = mem_fraction_static
|
self.mem_fraction_static = mem_fraction_static
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
@@ -256,7 +51,6 @@ class ModelRunner:
|
|||||||
monkey_patch_vllm_dummy_weight_loader()
|
monkey_patch_vllm_dummy_weight_loader()
|
||||||
|
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
|
|
||||||
torch.cuda.set_device(self.gpu_id)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
|
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
|
||||||
|
|
||||||
@@ -287,11 +81,8 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set some global args
|
# Set some global args
|
||||||
global global_server_args_dict
|
global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
|
||||||
global_server_args_dict = {
|
global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
|
||||||
"disable_flashinfer": server_args.disable_flashinfer,
|
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Load the model and create memory pool
|
# Load the model and create memory pool
|
||||||
self.load_model()
|
self.load_model()
|
||||||
@@ -425,27 +216,6 @@ class ModelRunner:
|
|||||||
) = None
|
) = None
|
||||||
self.flashinfer_decode_wrapper = None
|
self.flashinfer_decode_wrapper = None
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_prefill(self, batch: Batch):
|
|
||||||
input_metadata = InputMetadata.create(
|
|
||||||
self,
|
|
||||||
forward_mode=ForwardMode.PREFILL,
|
|
||||||
tp_size=self.tp_size,
|
|
||||||
req_pool_indices=batch.req_pool_indices,
|
|
||||||
seq_lens=batch.seq_lens,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
|
||||||
out_cache_loc=batch.out_cache_loc,
|
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
|
||||||
return_logprob=batch.return_logprob,
|
|
||||||
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
|
||||||
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
|
||||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
|
||||||
)
|
|
||||||
return self.model.forward(
|
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(self, batch: Batch):
|
def forward_extend(self, batch: Batch):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
@@ -523,8 +293,6 @@ class ModelRunner:
|
|||||||
return self.forward_decode(batch)
|
return self.forward_decode(batch)
|
||||||
elif forward_mode == ForwardMode.EXTEND:
|
elif forward_mode == ForwardMode.EXTEND:
|
||||||
return self.forward_extend(batch)
|
return self.forward_extend(batch)
|
||||||
elif forward_mode == ForwardMode.PREFILL:
|
|
||||||
return self.forward_prefill(batch)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user