Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
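In short: callers now build an InputMetadata from the ScheduleBatch and hand that to ModelRunner, which attaches the memory pools and attention backend itself; sampling still takes the batch, since sampling state stays on ScheduleBatch. A minimal sketch of the new calling convention, condensed from the hunks below:

    input_metadata = batch.get_input_metadata()
    logits_output = model_runner.forward(input_metadata)
    next_token_ids = model_runner.sample(logits_output, batch)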
@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits

@@ -15,7 +15,7 @@ import torch.nn as nn

 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import is_hip

@@ -37,9 +37,7 @@ class AttentionBackend(ABC):
     """The base class of attention backends"""

     @abstractmethod
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()

@@ -133,12 +131,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         if input_metadata.forward_mode.is_decode():
             prefix_lens = None
             use_ragged = False
+            extend_no_prefix = False
             total_num_tokens = None
         else:
             prefix_lens = input_metadata.extend_prefix_lens
@@ -152,6 +149,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = True

             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()

         update_flashinfer_indices(
             input_metadata.forward_mode,
@@ -162,7 +160,12 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged=use_ragged,
         )

-        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
+        self.forward_metadata = (
+            use_ragged,
+            extend_no_prefix,
+            total_num_tokens,
+            self.decode_wrapper,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indptr = torch.zeros(
@@ -228,7 +231,7 @@ class FlashInferAttnBackend(AttentionBackend):

         self.cuda_graph_metadata[bs] = decode_wrapper

-        self.forward_metadata = (False, None, decode_wrapper)
+        self.forward_metadata = (False, False, None, decode_wrapper)

     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
@@ -254,7 +257,9 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefill_wrapper_paged = self.prefill_wrapper_paged[1]

-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if not use_ragged:
             if k is not None:
@@ -280,7 +285,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if input_metadata.extend_no_prefix:
+            if extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -300,7 +305,9 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if isinstance(decode_wrapper, list):
             if layer.sliding_window_size != -1:
@@ -351,9 +358,7 @@ class TritonAttnBackend(AttentionBackend):

         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init auxiliary variables for triton attention backend."""

         if input_metadata.forward_mode.is_decode():
@@ -371,7 +376,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            prefix_lens = input_metadata.extend_prefix_lens
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
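Worth noting in the attention-backend hunks above: extend_no_prefix is now computed once per pass inside init_forward_metadata (from input_metadata.extend_prefix_lens) and carried in the forward_metadata tuple, which grows from three to four fields; the CUDA-graph capture path seeds it with a placeholder False. Every consumer unpacks it the same way:

    use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
        self.forward_metadata
    )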
@@ -18,13 +18,12 @@ limitations under the License.


 import re
-from dataclasses import dataclass

 import torch

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later
@@ -208,9 +207,9 @@ class LoRAManager:
                 if lora_weight_name:
                     self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, batch, extend_seq_lens=None):
+    def prepare_lora_batch(self, input_metadata: InputMetadata):
         # load active loras into lora memory pool
-        cur_uids = set([req.lora_path for req in batch.reqs])
+        cur_uids = set(input_metadata.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
@@ -230,11 +229,15 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = len(batch.reqs)
-        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+        bs = input_metadata.batch_size
+        seg_lens = (
+            input_metadata.extend_seq_lens
+            if input_metadata.forward_mode.is_extend()
+            else torch.ones(bs)
+        )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, req in enumerate(batch.reqs):
-            weight_indices[i] = self.buffer_id[req.lora_path]
+        for i, lora_path in enumerate(input_metadata.lora_paths):
+            weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
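prepare_lora_batch no longer takes the batch plus a separate extend_seq_lens argument: the LoRA paths, batch size, forward mode, and extend lengths all ride on InputMetadata (lora_paths is filled in from req.lora_path inside InputMetadata.from_schedule_batch, per the hunks further down). The new call site, as it appears in ModelRunner.forward later in this diff:

    if self.server_args.lora_paths is not None:
        self.lora_manager.prepare_lora_batch(input_metadata)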
@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -511,6 +511,9 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

+    def get_input_metadata(self):
+        return InputMetadata.from_schedule_batch(self)
+
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
@@ -575,8 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
+                input_metadata = batch.get_input_metadata()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    batch
+                    input_metadata, batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -640,7 +641,8 @@ class Scheduler:
                 )
         else:
             assert batch.extend_num_tokens != 0
-            embeddings = self.tp_worker.forward_batch_embedding(batch)
+            input_metadata = batch.get_input_metadata()
+            embeddings = self.tp_worker.forward_batch_embedding(input_metadata)

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -769,7 +771,10 @@ class Scheduler:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
+        input_metadata = batch.get_input_metadata()
+        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+            input_metadata, batch
+        )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
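All three scheduler call sites (extend generation, embedding, decode) now follow the same two-step pattern: derive the metadata from the batch, then pass it to the worker while keeping the batch around for sampling and bookkeeping. For example, the decode path becomes:

    input_metadata = batch.get_input_metadata()
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
        input_metadata, batch
    )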
@@ -21,6 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -105,13 +106,13 @@ class ModelTpWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_generation(self, input_metadata: InputMetadata, batch):
+        logits_output = self.model_runner.forward(input_metadata)
         next_token_ids = self.model_runner.sample(logits_output, batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_embedding(self, input_metadata: InputMetadata):
+        logits_output = self.model_runner.forward(input_metadata)
         embeddings = logits_output.embeddings.tolist()
         return embeddings

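forward_batch_generation keeps the batch argument alongside the metadata, presumably because model_runner.sample still reads sampling state from the ScheduleBatch; forward_batch_embedding needs only the metadata. A hypothetical caller sketch (worker stands for a ModelTpWorker instance):

    logits_output, next_token_ids = worker.forward_batch_generation(input_metadata, batch)
    embeddings = worker.forward_batch_embedding(input_metadata)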
@@ -31,7 +31,6 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -143,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

         # Capture
@@ -189,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]

         # Attention backend
@@ -202,6 +199,7 @@ class CudaGraphRunner:
         input_metadata = InputMetadata(
             forward_mode=ForwardMode.DECODE,
             batch_size=bs,
+            input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -210,7 +208,7 @@ class CudaGraphRunner:
             out_cache_loc=out_cache_loc,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
+            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
         )
         return forward(input_ids, input_metadata.positions, input_metadata)

@@ -235,24 +233,22 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, batch: ScheduleBatch):
-        assert batch.out_cache_loc is not None
-        raw_bs = len(batch.reqs)
+    def replay(self, input_metadata: InputMetadata):
+        assert input_metadata.out_cache_loc is not None
+        raw_bs = input_metadata.batch_size

         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs] = batch.input_ids
-        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
-        self.seq_lens[:raw_bs] = batch.seq_lens
-        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = input_metadata.input_ids
+        self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
+        self.seq_lens[:raw_bs] = input_metadata.seq_lens
+        self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -275,15 +271,15 @@ class CudaGraphRunner:
         )

         # Extract logprobs
-        if batch.return_logprob:
+        if input_metadata.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=batch.top_logprobs_nums,
+                    top_logprobs_nums=input_metadata.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
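CudaGraphRunner.replay now reads everything it pads and copies from the InputMetadata, and the position_ids_offsets buffer is gone entirely; capture computes positions as torch.clamp(seq_lens - 1, min=0), which plausibly keeps padded zero-length slots from producing negative positions. The decode fast path in ModelRunner (later in this diff) reduces to:

    if self.cuda_graph_runner and self.cuda_graph_runner.can_run(input_metadata.batch_size):
        return self.cuda_graph_runner.replay(input_metadata)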
@@ -18,7 +18,7 @@ limitations under the License.
 """Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Set

 import numpy as np
 import torch
@@ -27,7 +27,6 @@ if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-    from sglang.srt.model_executor.model_runner import ModelRunner


 class ForwardMode(IntEnum):
@@ -37,7 +36,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()

     def is_prefill(self):
@@ -57,15 +56,17 @@ class ForwardMode(IntEnum):
 class InputMetadata:
     """Store all inforamtion of a forward pass."""

+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

     # Position information
@@ -75,7 +76,6 @@ class InputMetadata:
     extend_seq_lens: torch.Tensor = None
     extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None

     # For logprob
     return_logprob: bool = False
@@ -86,82 +86,51 @@ class InputMetadata:
     # For multimodal
     image_inputs: List[ImageInputs] = None

-    def init_multimuldal_info(self, batch: ScheduleBatch):
-        self.image_inputs = [r.image_inputs for r in batch.reqs]
-
-    def compute_positions(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    # For LoRA
+    lora_paths: List[str] = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None

     @classmethod
     def from_schedule_batch(
         cls,
-        model_runner: "ModelRunner",
         batch: ScheduleBatch,
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=batch.batch_size(),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=[req.lora_path for req in batch.reqs],
         )

-        ret.compute_positions(batch)
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            ).to(torch.int64)

-        if not batch.forward_mode.is_decode():
-            ret.init_multimuldal_info(batch)
-            ret.compute_extend_infos(batch)
-
-        model_runner.attn_backend.init_forward_metadata(batch, ret)
+            ret.image_inputs = [r.image_inputs for r in batch.reqs]
+            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
+            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu

         return ret
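After this hunk InputMetadata is a plain data container: req_to_token_pool, token_to_kv_pool, and attn_backend become optional fields that ModelRunner.forward fills in (next section) instead of constructor arguments, from_schedule_batch drops its model_runner parameter, and the deprecated position_ids_offsets branches disappear, leaving positions computed inline. A sketch of what construction now produces, assuming a prepared ScheduleBatch:

    input_metadata = InputMetadata.from_schedule_batch(batch)
    # At this point the pool/backend fields are still None;
    # ModelRunner.forward() attaches them before dispatching.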
@@ -466,46 +466,47 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

-    def forward_decode(self, batch: ScheduleBatch):
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch)
-
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
-            return self.cuda_graph_runner.replay(batch)
-
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+    def forward_decode(self, input_metadata: InputMetadata):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            input_metadata.batch_size
+        ):
+            return self.cuda_graph_runner.replay(input_metadata)

         return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
+            input_metadata.input_ids, input_metadata.positions, input_metadata
         )

-    def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
-
+    def forward_extend(self, input_metadata: InputMetadata):
         if self.is_generation:
             return self.model.forward(
-                batch.input_ids, input_metadata.positions, input_metadata
+                input_metadata.input_ids, input_metadata.positions, input_metadata
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                batch.input_ids,
+                input_metadata.input_ids,
                 input_metadata.positions,
                 input_metadata,
                 get_embedding=True,
             )

-    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
-        assert batch.forward_mode is not None
+    def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
+        # Attach attention information
+        input_metadata.req_to_token_pool = self.req_to_token_pool
+        input_metadata.token_to_kv_pool = self.token_to_kv_pool
+        input_metadata.attn_backend = self.attn_backend
+        input_metadata.attn_backend.init_forward_metadata(input_metadata)
+
+        # Attach lora information
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(input_metadata)

-        if batch.forward_mode.is_decode():
-            return self.forward_decode(batch)
-        elif batch.forward_mode.is_extend():
-            return self.forward_extend(batch)
+        if input_metadata.forward_mode.is_decode():
+            return self.forward_decode(input_metadata)
+        elif input_metadata.forward_mode.is_extend():
+            return self.forward_extend(input_metadata)
         else:
-            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+            raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")

     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
@@ -71,10 +71,10 @@ class ModelOutput:
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type="generation",
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
@@ -244,15 +244,15 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -15,7 +15,6 @@ limitations under the License.

 import multiprocessing as mp
 import unittest
-import uuid

 import torch

@@ -85,9 +84,9 @@ class TestLoRA(unittest.TestCase):

         with SRTRunner(
             base_path,
-            tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
+            tp_size=tp_size,
             lora_paths=all_lora_paths,
             max_loras_per_batch=3,
             disable_cuda_graph=True,
@@ -7,6 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
+        # "models/test_lora.py",
        "models/test_reward_models.py",
        "sampling/penaltylib",
        "test_chunked_prefill.py",