Adjust InputMetadata and ScheduleBatch (#981)
This commit is contained in:
@@ -307,7 +307,6 @@ class ScheduleBatch:
|
|||||||
input_ids: torch.Tensor = None
|
input_ids: torch.Tensor = None
|
||||||
req_pool_indices: torch.Tensor = None
|
req_pool_indices: torch.Tensor = None
|
||||||
seq_lens: torch.Tensor = None
|
seq_lens: torch.Tensor = None
|
||||||
prefix_lens: torch.Tensor = None
|
|
||||||
position_ids_offsets: torch.Tensor = None
|
position_ids_offsets: torch.Tensor = None
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
extend_num_tokens: int = None
|
extend_num_tokens: int = None
|
||||||
@@ -316,11 +315,6 @@ class ScheduleBatch:
|
|||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
# For multimodal
|
|
||||||
pixel_values: List[torch.Tensor] = None
|
|
||||||
image_sizes: List[List[int]] = None
|
|
||||||
image_offsets: List[int] = None
|
|
||||||
|
|
||||||
# Batched sampling params
|
# Batched sampling params
|
||||||
temperatures: torch.Tensor = None
|
temperatures: torch.Tensor = None
|
||||||
top_ps: torch.Tensor = None
|
top_ps: torch.Tensor = None
|
||||||
@@ -412,59 +406,40 @@ class ScheduleBatch:
|
|||||||
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||||
|
|
||||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||||
device = "cuda"
|
|
||||||
bs = self.batch_size()
|
bs = self.batch_size()
|
||||||
reqs = self.reqs
|
reqs = self.reqs
|
||||||
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
prefix_indices = [r.prefix_indices for r in reqs]
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
|
|
||||||
# Handle prefix
|
|
||||||
extend_lens = []
|
|
||||||
prefix_lens = []
|
|
||||||
seq_lens = []
|
seq_lens = []
|
||||||
|
|
||||||
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
|
||||||
|
|
||||||
for i, req in enumerate(reqs):
|
|
||||||
req.req_pool_idx = req_pool_indices_cpu[i]
|
|
||||||
extend_lens.append(len(input_ids[i]))
|
|
||||||
|
|
||||||
if len(prefix_indices[i]) == 0:
|
|
||||||
prefix_lens.append(0)
|
|
||||||
else:
|
|
||||||
prefix_lens.append(len(prefix_indices[i]))
|
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
: len(prefix_indices[i])
|
|
||||||
] = prefix_indices[i]
|
|
||||||
|
|
||||||
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
|
|
||||||
|
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
|
req_pool_indices_cpu = self.alloc_req_slots(bs)
|
||||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
|
||||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
req.req_pool_idx = req_pool_indices_cpu[i]
|
||||||
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
|
pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
|
||||||
] = out_cache_loc[pt : pt + extend_lens[i]]
|
ext_len = seq_len - pre_len
|
||||||
pt += extend_lens[i]
|
seq_lens.append(seq_len)
|
||||||
|
|
||||||
|
if pre_len > 0:
|
||||||
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
:pre_len
|
||||||
|
] = req.prefix_indices
|
||||||
|
|
||||||
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
|
||||||
|
out_cache_loc[pt : pt + ext_len]
|
||||||
|
)
|
||||||
|
pt += ext_len
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
|
||||||
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
|
||||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
|
||||||
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
|
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
|
||||||
|
|
||||||
self.pixel_values = [r.pixel_values for r in reqs]
|
|
||||||
self.image_sizes = [r.image_size for r in reqs]
|
|
||||||
self.image_offsets = [
|
|
||||||
(r.image_offset - p_len) if r.image_offset is not None else 0
|
|
||||||
for r, p_len in zip(reqs, prefix_lens)
|
|
||||||
]
|
|
||||||
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
@@ -642,7 +617,6 @@ class ScheduleBatch:
|
|||||||
]
|
]
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
self.prefix_lens = None
|
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
bs = self.batch_size()
|
bs = self.batch_size()
|
||||||
@@ -667,7 +641,6 @@ class ScheduleBatch:
|
|||||||
self.seq_lens = self.seq_lens[new_indices]
|
self.seq_lens = self.seq_lens[new_indices]
|
||||||
self.input_ids = None
|
self.input_ids = None
|
||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||||
self.prefix_lens = None
|
|
||||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||||
@@ -692,7 +665,6 @@ class ScheduleBatch:
|
|||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||||
self.prefix_lens = None
|
|
||||||
self.position_ids_offsets = torch.concat(
|
self.position_ids_offsets = torch.concat(
|
||||||
[self.position_ids_offsets, other.position_ids_offsets]
|
[self.position_ids_offsets, other.position_ids_offsets]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
InputMetadata,
|
InputMetadata,
|
||||||
init_flashinfer_args,
|
update_flashinfer_indices,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class CudaGraphRunner:
|
|||||||
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
||||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
||||||
)
|
)
|
||||||
init_flashinfer_args(
|
update_flashinfer_indices(
|
||||||
ForwardMode.DECODE,
|
ForwardMode.DECODE,
|
||||||
self.model_runner,
|
self.model_runner,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -176,19 +176,19 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
def run_once():
|
def run_once():
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata(
|
||||||
self.model_runner,
|
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
batch_size=bs,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
prefix_lens=None,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
position_ids_offsets=position_ids_offsets,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=0,
|
top_logprobs_nums=0,
|
||||||
skip_flashinfer_init=True,
|
positions=(seq_lens - 1).to(torch.int64),
|
||||||
|
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
|
|
||||||
|
|
||||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|
||||||
@@ -222,7 +222,7 @@ class CudaGraphRunner:
|
|||||||
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
||||||
|
|
||||||
# FlashInfer inputs
|
# FlashInfer inputs
|
||||||
init_flashinfer_args(
|
update_flashinfer_indices(
|
||||||
ForwardMode.DECODE,
|
ForwardMode.DECODE,
|
||||||
self.model_runner,
|
self.model_runner,
|
||||||
self.req_pool_indices[:bs],
|
self.req_pool_indices[:bs],
|
||||||
|
|||||||
@@ -16,13 +16,17 @@ limitations under the License.
|
|||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||||
@@ -39,25 +43,33 @@ class InputMetadata:
|
|||||||
|
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
batch_size: int
|
batch_size: int
|
||||||
total_num_tokens: int
|
|
||||||
req_pool_indices: torch.Tensor
|
req_pool_indices: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
positions: torch.Tensor
|
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: BaseTokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
|
|
||||||
# For extend
|
|
||||||
extend_seq_lens: torch.Tensor
|
|
||||||
extend_start_loc: torch.Tensor
|
|
||||||
extend_no_prefix: bool
|
|
||||||
|
|
||||||
# Output location of the KV cache
|
# Output location of the KV cache
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
|
total_num_tokens: int = None
|
||||||
|
|
||||||
|
# Position information
|
||||||
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
|
# For extend
|
||||||
|
extend_seq_lens: torch.Tensor = None
|
||||||
|
extend_start_loc: torch.Tensor = None
|
||||||
|
extend_no_prefix: bool = None
|
||||||
|
|
||||||
# Output options
|
# Output options
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: List[int] = None
|
top_logprobs_nums: List[int] = None
|
||||||
|
|
||||||
|
# For multimodal
|
||||||
|
pixel_values: List[torch.Tensor] = None
|
||||||
|
image_sizes: List[List[int]] = None
|
||||||
|
image_offsets: List[int] = None
|
||||||
|
|
||||||
# Triton attention backend
|
# Triton attention backend
|
||||||
triton_max_seq_len: int = 0
|
triton_max_seq_len: int = 0
|
||||||
triton_max_extend_len: int = 0
|
triton_max_extend_len: int = 0
|
||||||
@@ -70,107 +82,170 @@ class InputMetadata:
|
|||||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||||
flashinfer_use_ragged: bool = False
|
flashinfer_use_ragged: bool = False
|
||||||
|
|
||||||
@classmethod
|
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||||
def create(
|
reqs = batch.reqs
|
||||||
cls,
|
self.pixel_values = [r.pixel_values for r in reqs]
|
||||||
model_runner,
|
self.image_sizes = [r.image_size for r in reqs]
|
||||||
forward_mode,
|
self.image_offsets = [
|
||||||
req_pool_indices,
|
(
|
||||||
seq_lens,
|
(r.image_offset - len(r.prefix_indices))
|
||||||
prefix_lens,
|
if r.image_offset is not None
|
||||||
position_ids_offsets,
|
else 0
|
||||||
out_cache_loc,
|
|
||||||
top_logprobs_nums=None,
|
|
||||||
return_logprob=False,
|
|
||||||
skip_flashinfer_init=False,
|
|
||||||
):
|
|
||||||
flashinfer_use_ragged = False
|
|
||||||
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
|
||||||
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
|
|
||||||
flashinfer_use_ragged = True
|
|
||||||
init_flashinfer_args(
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
model_runner.flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged,
|
|
||||||
)
|
)
|
||||||
|
for r in reqs
|
||||||
|
]
|
||||||
|
|
||||||
batch_size = len(req_pool_indices)
|
def compute_positions(self, batch: ScheduleBatch):
|
||||||
|
position_ids_offsets = batch.position_ids_offsets
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if self.forward_mode == ForwardMode.DECODE:
|
||||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
if True:
|
||||||
extend_seq_lens = extend_start_loc = extend_no_prefix = None
|
self.positions = self.seq_lens - 1
|
||||||
if not model_runner.server_args.disable_flashinfer:
|
|
||||||
# This variable is not needed in this case,
|
|
||||||
# we do not compute it to make it compatible with cuda graph.
|
|
||||||
total_num_tokens = None
|
|
||||||
else:
|
else:
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
# Deprecated
|
||||||
|
self.positions = (self.seq_lens - 1) + position_ids_offsets
|
||||||
else:
|
else:
|
||||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
if True:
|
||||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
self.positions = torch.tensor(
|
||||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
np.concatenate(
|
||||||
positions = torch.tensor(
|
[
|
||||||
np.concatenate(
|
np.arange(len(req.prefix_indices), len(req.input_ids))
|
||||||
[
|
for req in batch.reqs
|
||||||
np.arange(
|
],
|
||||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
axis=0,
|
||||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
),
|
||||||
)
|
device="cuda",
|
||||||
for i in range(batch_size)
|
)
|
||||||
],
|
else:
|
||||||
axis=0,
|
# Deprecated
|
||||||
),
|
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||||
device="cuda",
|
self.positions = torch.tensor(
|
||||||
)
|
np.concatenate(
|
||||||
extend_seq_lens = seq_lens - prefix_lens
|
[
|
||||||
extend_start_loc = torch.zeros_like(seq_lens)
|
np.arange(
|
||||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
len(req.prefix_indices) + position_ids_offsets_cpu[i],
|
||||||
extend_no_prefix = torch.all(prefix_lens == 0)
|
len(req.input_ids) + position_ids_offsets_cpu[i],
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
)
|
||||||
|
for i, req in enumerate(batch.reqs)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
),
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Positions should be in long type
|
||||||
|
self.positions = self.positions.to(torch.int64)
|
||||||
|
|
||||||
|
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||||
|
if self.forward_mode == ForwardMode.DECODE:
|
||||||
|
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
||||||
|
else:
|
||||||
|
prefix_lens_cpu = [
|
||||||
|
len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
|
||||||
|
]
|
||||||
|
self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
|
||||||
|
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||||
|
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||||
|
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
|
||||||
|
|
||||||
|
def init_total_num_tokens(self, batch: ScheduleBatch):
|
||||||
|
self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_schedule_batch(
|
||||||
|
cls,
|
||||||
|
model_runner: "ModelRunner",
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
):
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
batch_size=batch_size,
|
batch_size=batch.batch_size(),
|
||||||
total_num_tokens=total_num_tokens,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
req_pool_indices=req_pool_indices,
|
seq_lens=batch.seq_lens,
|
||||||
seq_lens=seq_lens,
|
|
||||||
positions=positions,
|
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
extend_seq_lens=extend_seq_lens,
|
return_logprob=batch.return_logprob,
|
||||||
extend_start_loc=extend_start_loc,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
extend_no_prefix=extend_no_prefix,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
top_logprobs_nums=top_logprobs_nums,
|
|
||||||
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
|
|
||||||
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
|
|
||||||
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ret.compute_positions(batch)
|
||||||
|
|
||||||
|
ret.compute_extend_infos(batch)
|
||||||
|
|
||||||
|
ret.init_total_num_tokens(batch)
|
||||||
|
|
||||||
|
if forward_mode != ForwardMode.DECODE:
|
||||||
|
ret.init_multimuldal_info(batch)
|
||||||
|
|
||||||
|
prefix_lens = None
|
||||||
|
if forward_mode != ForwardMode.DECODE:
|
||||||
|
prefix_lens = torch.tensor(
|
||||||
|
[len(r.prefix_indices) for r in batch.reqs], device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
if model_runner.server_args.disable_flashinfer:
|
if model_runner.server_args.disable_flashinfer:
|
||||||
(
|
ret.init_triton_args(batch, prefix_lens)
|
||||||
ret.triton_max_seq_len,
|
|
||||||
ret.triton_max_extend_len,
|
flashinfer_use_ragged = False
|
||||||
ret.triton_start_loc,
|
if not model_runner.server_args.disable_flashinfer:
|
||||||
ret.triton_prefix_lens,
|
if (
|
||||||
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
forward_mode != ForwardMode.DECODE
|
||||||
|
and int(torch.sum(ret.seq_lens)) > 4096
|
||||||
|
):
|
||||||
|
flashinfer_use_ragged = True
|
||||||
|
ret.init_flashinfer_handlers(
|
||||||
|
model_runner, prefix_lens, flashinfer_use_ragged
|
||||||
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
|
||||||
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
|
||||||
|
self.triton_prefix_lens = prefix_lens
|
||||||
|
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
||||||
|
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
||||||
|
|
||||||
def init_flashinfer_args(
|
if self.forward_mode == ForwardMode.DECODE:
|
||||||
|
self.triton_max_extend_len = None
|
||||||
|
else:
|
||||||
|
extend_seq_lens = self.seq_lens - prefix_lens
|
||||||
|
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
|
||||||
|
|
||||||
|
def init_flashinfer_handlers(
|
||||||
|
self, model_runner, prefix_lens, flashinfer_use_ragged
|
||||||
|
):
|
||||||
|
update_flashinfer_indices(
|
||||||
|
self.forward_mode,
|
||||||
|
model_runner,
|
||||||
|
self.req_pool_indices,
|
||||||
|
self.seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
flashinfer_use_ragged=flashinfer_use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
self.flashinfer_prefill_wrapper_ragged,
|
||||||
|
self.flashinfer_prefill_wrapper_paged,
|
||||||
|
self.flashinfer_decode_wrapper,
|
||||||
|
self.flashinfer_use_ragged,
|
||||||
|
) = (
|
||||||
|
model_runner.flashinfer_prefill_wrapper_ragged,
|
||||||
|
model_runner.flashinfer_prefill_wrapper_paged,
|
||||||
|
model_runner.flashinfer_decode_wrapper,
|
||||||
|
flashinfer_use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_flashinfer_indices(
|
||||||
forward_mode,
|
forward_mode,
|
||||||
model_runner,
|
model_runner,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=None,
|
||||||
flashinfer_use_ragged=False,
|
flashinfer_use_ragged=False,
|
||||||
):
|
):
|
||||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||||
@@ -178,7 +253,6 @@ def init_flashinfer_args(
|
|||||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||||
head_dim = model_runner.model_config.head_dim
|
head_dim = model_runner.model_config.head_dim
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
total_num_tokens = int(torch.sum(seq_lens))
|
|
||||||
|
|
||||||
if flashinfer_use_ragged:
|
if flashinfer_use_ragged:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
@@ -201,6 +275,10 @@ def init_flashinfer_args(
|
|||||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
|
# CUDA graph uses different flashinfer_decode_wrapper
|
||||||
|
if flashinfer_decode_wrapper is None:
|
||||||
|
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||||
|
|
||||||
flashinfer_decode_wrapper.end_forward()
|
flashinfer_decode_wrapper.end_forward()
|
||||||
flashinfer_decode_wrapper.begin_forward(
|
flashinfer_decode_wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
@@ -238,19 +316,3 @@ def init_flashinfer_args(
|
|||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
|
||||||
batch_size = len(seq_lens)
|
|
||||||
max_seq_len = int(torch.max(seq_lens))
|
|
||||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
|
||||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
|
||||||
max_extend_len = None
|
|
||||||
else:
|
|
||||||
extend_seq_lens = seq_lens - prefix_lens
|
|
||||||
max_extend_len = int(torch.max(extend_seq_lens))
|
|
||||||
|
|
||||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
|
||||||
|
|||||||
@@ -350,33 +350,18 @@ class ModelRunner:
|
|||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self,
|
self, batch, ForwardMode.DECODE
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
req_pool_indices=batch.req_pool_indices,
|
|
||||||
seq_lens=batch.seq_lens,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
|
||||||
out_cache_loc=batch.out_cache_loc,
|
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
|
||||||
return_logprob=batch.return_logprob,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(self, batch: ScheduleBatch):
|
def forward_extend(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self,
|
self, batch, forward_mode=ForwardMode.EXTEND
|
||||||
forward_mode=ForwardMode.EXTEND,
|
|
||||||
req_pool_indices=batch.req_pool_indices,
|
|
||||||
seq_lens=batch.seq_lens,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
|
||||||
out_cache_loc=batch.out_cache_loc,
|
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
|
||||||
return_logprob=batch.return_logprob,
|
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
@@ -384,24 +369,16 @@ class ModelRunner:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.from_schedule_batch(
|
||||||
self,
|
self, batch, forward_mode=ForwardMode.EXTEND
|
||||||
forward_mode=ForwardMode.EXTEND,
|
|
||||||
req_pool_indices=batch.req_pool_indices,
|
|
||||||
seq_lens=batch.seq_lens,
|
|
||||||
prefix_lens=batch.prefix_lens,
|
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
|
||||||
out_cache_loc=batch.out_cache_loc,
|
|
||||||
return_logprob=batch.return_logprob,
|
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
input_metadata.positions,
|
input_metadata.positions,
|
||||||
input_metadata,
|
input_metadata,
|
||||||
batch.pixel_values,
|
input_metadata.pixel_values,
|
||||||
batch.image_sizes,
|
input_metadata.image_sizes,
|
||||||
batch.image_offsets,
|
input_metadata.image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
||||||
|
|||||||
Reference in New Issue
Block a user