Unify forward mode (#1360)
This commit is contained in:
@@ -60,7 +60,6 @@ import torch.distributed as dist
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
sample_output, logits_output = model_runner.forward(batch)
|
||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
|
||||
sample_output, logits_output = model_runner.forward(batch)
|
||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode.is_decode():
|
||||
output_top_logprobs = []
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
ret = all_logprobs.topk(max_k, dim=1)
|
||||
@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
|
||||
assert isinstance(logits_metadata, LogitsMetadata)
|
||||
|
||||
# Get the last hidden states and last logits for the next token prediction
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode.is_decode():
|
||||
last_index = None
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
else:
|
||||
# When logprob is requested, compute the logits for all tokens.
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode.is_decode():
|
||||
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||
|
||||
# Get the logprob of top-k tokens
|
||||
|
||||
@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
return self.extend_forward(q, k, v, input_metadata)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.decode_forward(q, k, v, input_metadata)
|
||||
|
||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||
|
||||
@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -334,6 +335,8 @@ class ScheduleBatch:
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
tree_cache: BasePrefixCache
|
||||
|
||||
forward_mode: ForwardMode = None
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
@@ -397,6 +400,8 @@ class ScheduleBatch:
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
bs = self.batch_size()
|
||||
reqs = self.reqs
|
||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
@@ -626,6 +631,8 @@ class ScheduleBatch:
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self, input_ids=None):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
||||
|
||||
@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
@@ -521,9 +520,7 @@ class ModelTpServer:
|
||||
if self.model_runner.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
sample_output, logits_output = self.model_runner.forward(
|
||||
batch, ForwardMode.EXTEND
|
||||
)
|
||||
sample_output, logits_output = self.model_runner.forward(batch)
|
||||
next_token_ids = batch.check_sample_results(sample_output)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
@@ -588,7 +585,7 @@ class ModelTpServer:
|
||||
pt += req.extend_input_len
|
||||
else:
|
||||
assert batch.extend_num_tokens != 0
|
||||
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
logits_output = self.model_runner.forward(batch)
|
||||
embeddings = logits_output.embeddings.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
@@ -699,9 +696,7 @@ class ModelTpServer:
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward and sample the next tokens
|
||||
sample_output, logits_output = self.model_runner.forward(
|
||||
batch, ForwardMode.DECODE
|
||||
)
|
||||
sample_output, logits_output = self.model_runner.forward(batch)
|
||||
next_token_ids = batch.check_sample_results(sample_output)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
|
||||
@@ -25,10 +25,9 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
|
||||
def is_prefill(self):
|
||||
return self == ForwardMode.PREFILL
|
||||
|
||||
def is_extend(self):
|
||||
return self == ForwardMode.EXTEND
|
||||
|
||||
def is_decode(self):
|
||||
return self == ForwardMode.DECODE
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
@@ -102,7 +110,7 @@ class InputMetadata:
|
||||
def compute_positions(self, batch: ScheduleBatch):
|
||||
position_ids_offsets = batch.position_ids_offsets
|
||||
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
if self.forward_mode.is_decode():
|
||||
if True:
|
||||
self.positions = self.seq_lens - 1
|
||||
else:
|
||||
@@ -141,7 +149,7 @@ class InputMetadata:
|
||||
self.positions = self.positions.to(torch.int64)
|
||||
|
||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
if self.forward_mode.is_decode():
|
||||
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
||||
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
|
||||
else:
|
||||
@@ -173,10 +181,9 @@ class InputMetadata:
|
||||
cls,
|
||||
model_runner: "ModelRunner",
|
||||
batch: ScheduleBatch,
|
||||
forward_mode: ForwardMode,
|
||||
):
|
||||
ret = cls(
|
||||
forward_mode=forward_mode,
|
||||
forward_mode=batch.forward_mode,
|
||||
sampling_info=batch.sampling_info,
|
||||
batch_size=batch.batch_size(),
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
@@ -194,13 +201,11 @@ class InputMetadata:
|
||||
|
||||
ret.compute_extend_infos(batch)
|
||||
|
||||
if (
|
||||
forward_mode != ForwardMode.DECODE
|
||||
or model_runner.server_args.disable_flashinfer
|
||||
):
|
||||
fm = batch.forward_mode
|
||||
if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
|
||||
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
|
||||
|
||||
if forward_mode != ForwardMode.DECODE:
|
||||
if not fm.is_decode():
|
||||
ret.init_multimuldal_info(batch)
|
||||
|
||||
if model_runner.server_args.disable_flashinfer:
|
||||
@@ -209,7 +214,7 @@ class InputMetadata:
|
||||
flashinfer_use_ragged = False
|
||||
if not model_runner.server_args.disable_flashinfer:
|
||||
if (
|
||||
forward_mode != ForwardMode.DECODE
|
||||
not fm.is_decode()
|
||||
and int(torch.sum(ret.seq_lens)) > 4096
|
||||
and model_runner.sliding_window_size is None
|
||||
):
|
||||
@@ -226,7 +231,7 @@ class InputMetadata:
|
||||
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
||||
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
||||
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
if self.forward_mode.is_decode():
|
||||
self.triton_max_extend_len = None
|
||||
else:
|
||||
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
@@ -239,7 +244,7 @@ class InputMetadata:
|
||||
prefix_lens_cpu,
|
||||
flashinfer_use_ragged,
|
||||
):
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
if self.forward_mode.is_decode():
|
||||
prefix_lens = None
|
||||
else:
|
||||
prefix_lens = self.extend_prefix_lens
|
||||
@@ -339,7 +344,7 @@ def update_flashinfer_indices(
|
||||
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
if forward_mode.is_decode():
|
||||
# CUDA graph uses different flashinfer_decode_wrapper
|
||||
if flashinfer_decode_wrapper is None:
|
||||
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||
@@ -388,7 +393,7 @@ def update_flashinfer_indices(
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
if forward_mode.is_decode():
|
||||
paged_kernel_lens = torch.minimum(
|
||||
seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
|
||||
)
|
||||
@@ -418,7 +423,7 @@ def update_flashinfer_indices(
|
||||
kv_indices,
|
||||
)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
if forward_mode.is_decode():
|
||||
# CUDA graph uses different flashinfer_decode_wrapper
|
||||
if flashinfer_decode_wrapper is None:
|
||||
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||
|
||||
@@ -530,11 +530,7 @@ class ModelRunner:
|
||||
):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
ForwardMode.DECODE,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -542,11 +538,7 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -562,11 +554,7 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
)
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
input_metadata.positions,
|
||||
@@ -577,16 +565,18 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, batch: ScheduleBatch, forward_mode: ForwardMode
|
||||
self, batch: ScheduleBatch
|
||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||
assert batch.forward_mode is not None
|
||||
|
||||
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||
return self.forward_extend_multi_modal(batch)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
elif batch.forward_mode.is_decode():
|
||||
return self.forward_decode(batch)
|
||||
elif forward_mode == ForwardMode.EXTEND:
|
||||
elif batch.forward_mode.is_extend():
|
||||
return self.forward_extend(batch)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
image_sizes: Optional[List[List[int]]] = None,
|
||||
image_offsets: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
# Got List[List[str]] extend it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
image_sizes: Optional[List[List[int]]] = None,
|
||||
image_offsets: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
if input_metadata.forward_mode.is_extend():
|
||||
bs = input_metadata.batch_size
|
||||
|
||||
# Embed text inputs
|
||||
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
elif input_metadata.forward_mode.is_decode():
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user