Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
This commit is contained in:
@@ -31,7 +31,6 @@ from sglang.srt.layers.logits_processor import (
|
||||
LogitsProcessor,
|
||||
LogitsProcessorOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||
|
||||
@@ -143,7 +142,6 @@ class CudaGraphRunner:
|
||||
self.seq_lens = torch.full(
|
||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
|
||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
|
||||
# Capture
|
||||
@@ -189,7 +187,6 @@ class CudaGraphRunner:
|
||||
input_ids = self.input_ids[:bs]
|
||||
req_pool_indices = self.req_pool_indices[:bs]
|
||||
seq_lens = self.seq_lens[:bs]
|
||||
position_ids_offsets = self.position_ids_offsets[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:bs]
|
||||
|
||||
# Attention backend
|
||||
@@ -202,6 +199,7 @@ class CudaGraphRunner:
|
||||
input_metadata = InputMetadata(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
batch_size=bs,
|
||||
input_ids=input_ids,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
@@ -210,7 +208,7 @@ class CudaGraphRunner:
|
||||
out_cache_loc=out_cache_loc,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
)
|
||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||
|
||||
@@ -235,24 +233,22 @@ class CudaGraphRunner:
|
||||
self.graph_memory_pool = graph.pool()
|
||||
return graph, out
|
||||
|
||||
def replay(self, batch: ScheduleBatch):
|
||||
assert batch.out_cache_loc is not None
|
||||
raw_bs = len(batch.reqs)
|
||||
def replay(self, input_metadata: InputMetadata):
|
||||
assert input_metadata.out_cache_loc is not None
|
||||
raw_bs = input_metadata.batch_size
|
||||
|
||||
# Pad
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||
self.position_ids_offsets.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:raw_bs] = batch.input_ids
|
||||
self.req_pool_indices[:raw_bs] = batch.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = batch.seq_lens
|
||||
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
|
||||
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
||||
self.input_ids[:raw_bs] = input_metadata.input_ids
|
||||
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = input_metadata.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -275,15 +271,15 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
# Extract logprobs
|
||||
if batch.return_logprob:
|
||||
if input_metadata.return_logprob:
|
||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
logits_metadata = LogitsMetadata(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
)
|
||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
logits_output.next_token_logprobs, logits_metadata
|
||||
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
"""Meta data for a forward pass."""
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -27,7 +27,6 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
@@ -37,7 +36,7 @@ class ForwardMode(IntEnum):
|
||||
EXTEND = auto()
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
# Contains both PREFILL and EXTEND.
|
||||
# Contains both EXTEND and DECODE.
|
||||
MIXED = auto()
|
||||
|
||||
def is_prefill(self):
|
||||
@@ -57,15 +56,17 @@ class ForwardMode(IntEnum):
|
||||
class InputMetadata:
|
||||
"""Store all information of a forward pass."""
|
||||
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
# The batch size
|
||||
batch_size: int
|
||||
# The input ids
|
||||
input_ids: torch.Tensor
|
||||
# The indices of requests in the req_to_token_pool
|
||||
req_pool_indices: torch.Tensor
|
||||
# The sequence length
|
||||
seq_lens: torch.Tensor
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
attn_backend: AttentionBackend
|
||||
|
||||
# Output location of the KV cache
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# Position information
|
||||
@@ -75,7 +76,6 @@ class InputMetadata:
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_prefix_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
extend_no_prefix: bool = None
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
@@ -86,82 +86,51 @@ class InputMetadata:
|
||||
# For multimodal
|
||||
image_inputs: List[ImageInputs] = None
|
||||
|
||||
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||
self.image_inputs = [r.image_inputs for r in batch.reqs]
|
||||
# For LoRA
|
||||
lora_paths: List[str] = None
|
||||
|
||||
def compute_positions(self, batch: ScheduleBatch):
|
||||
if self.forward_mode.is_decode():
|
||||
if True:
|
||||
self.positions = self.seq_lens - 1
|
||||
else:
|
||||
# Deprecated
|
||||
self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
|
||||
else:
|
||||
if True:
|
||||
self.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
|
||||
for i, req in enumerate(batch.reqs)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
)
|
||||
else:
|
||||
# Deprecated
|
||||
position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
|
||||
self.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(
|
||||
batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
len(req.fill_ids) + position_ids_offsets_cpu[i],
|
||||
)
|
||||
for i, req in enumerate(batch.reqs)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Positions should be in long type
|
||||
self.positions = self.positions.to(torch.int64)
|
||||
|
||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||
self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
|
||||
self.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||
self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||
# Attention backend
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
cls,
|
||||
model_runner: "ModelRunner",
|
||||
batch: ScheduleBatch,
|
||||
):
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=batch.batch_size(),
|
||||
input_ids=batch.input_ids,
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
attn_backend=model_runner.attn_backend,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
lora_paths=[req.lora_path for req in batch.reqs],
|
||||
)
|
||||
|
||||
ret.compute_positions(batch)
|
||||
if ret.forward_mode.is_decode():
|
||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
ret.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
|
||||
for i, req in enumerate(batch.reqs)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
).to(torch.int64)
|
||||
|
||||
if not batch.forward_mode.is_decode():
|
||||
ret.init_multimuldal_info(batch)
|
||||
ret.compute_extend_infos(batch)
|
||||
|
||||
model_runner.attn_backend.init_forward_metadata(batch, ret)
|
||||
ret.image_inputs = [r.image_inputs for r in batch.reqs]
|
||||
ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
|
||||
ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
|
||||
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
|
||||
ret.extend_seq_lens_cpu = batch.extend_lens_cpu
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
|
||||
|
||||
return ret
|
||||
|
||||
@@ -466,46 +466,47 @@ class ModelRunner:
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def forward_decode(self, batch: ScheduleBatch):
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch)
|
||||
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
def forward_decode(self, input_metadata: InputMetadata):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||
input_metadata.batch_size
|
||||
):
|
||||
return self.cuda_graph_runner.replay(input_metadata)
|
||||
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
)
|
||||
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
|
||||
|
||||
def forward_extend(self, input_metadata: InputMetadata):
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
)
|
||||
else:
|
||||
# Only embedding models have get_embedding parameter
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
input_metadata.input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
|
||||
assert batch.forward_mode is not None
|
||||
def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
|
||||
# Attach attention information
|
||||
input_metadata.req_to_token_pool = self.req_to_token_pool
|
||||
input_metadata.token_to_kv_pool = self.token_to_kv_pool
|
||||
input_metadata.attn_backend = self.attn_backend
|
||||
input_metadata.attn_backend.init_forward_metadata(input_metadata)
|
||||
|
||||
if batch.forward_mode.is_decode():
|
||||
return self.forward_decode(batch)
|
||||
elif batch.forward_mode.is_extend():
|
||||
return self.forward_extend(batch)
|
||||
# Attach lora information
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(input_metadata)
|
||||
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
return self.forward_decode(input_metadata)
|
||||
elif input_metadata.forward_mode.is_extend():
|
||||
return self.forward_extend(input_metadata)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||
raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
|
||||
Reference in New Issue
Block a user