Rename InputMetadata -> ForwardBatch (#1543)
This commit is contained in:
@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
|
||||
LogitsProcessor,
|
||||
LogitsProcessorOutput,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -196,7 +196,7 @@ class CudaGraphRunner:
|
||||
|
||||
# Run and capture
|
||||
def run_once():
|
||||
input_metadata = InputMetadata(
|
||||
forward_batch = ForwardBatch(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
batch_size=bs,
|
||||
input_ids=input_ids,
|
||||
@@ -210,7 +210,7 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
)
|
||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||
|
||||
for _ in range(2):
|
||||
torch.cuda.synchronize()
|
||||
@@ -233,9 +233,9 @@ class CudaGraphRunner:
|
||||
self.graph_memory_pool = graph.pool()
|
||||
return graph, out
|
||||
|
||||
def replay(self, input_metadata: InputMetadata):
|
||||
assert input_metadata.out_cache_loc is not None
|
||||
raw_bs = input_metadata.batch_size
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
assert forward_batch.out_cache_loc is not None
|
||||
raw_bs = forward_batch.batch_size
|
||||
|
||||
# Pad
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
@@ -245,10 +245,10 @@ class CudaGraphRunner:
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
self.input_ids[:raw_bs] = input_metadata.input_ids
|
||||
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = input_metadata.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
|
||||
self.input_ids[:raw_bs] = forward_batch.input_ids
|
||||
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
|
||||
self.seq_lens[:raw_bs] = forward_batch.seq_lens
|
||||
self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -271,15 +271,15 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
# Extract logprobs
|
||||
if input_metadata.return_logprob:
|
||||
if forward_batch.return_logprob:
|
||||
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
logits_metadata = LogitsMetadata(
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
)
|
||||
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||
logits_output.next_token_logprobs, logits_metadata
|
||||
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
"""Meta data for a forward pass."""
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Set
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
"""Store all inforamtion of a forward pass."""
|
||||
class ForwardBatch:
|
||||
"""Store all inputs of a forward pass."""
|
||||
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
|
||||
@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
@@ -466,47 +466,47 @@ class ModelRunner:
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def forward_decode(self, input_metadata: InputMetadata):
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||
input_metadata.batch_size
|
||||
forward_batch.batch_size
|
||||
):
|
||||
return self.cuda_graph_runner.replay(input_metadata)
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward_extend(self, input_metadata: InputMetadata):
|
||||
def forward_extend(self, forward_batch: ForwardBatch):
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids, input_metadata.positions, input_metadata
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
else:
|
||||
# Only embedding models have get_embedding parameter
|
||||
return self.model.forward(
|
||||
input_metadata.input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
|
||||
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
||||
# Attach attention information
|
||||
input_metadata.req_to_token_pool = self.req_to_token_pool
|
||||
input_metadata.token_to_kv_pool = self.token_to_kv_pool
|
||||
input_metadata.attn_backend = self.attn_backend
|
||||
input_metadata.attn_backend.init_forward_metadata(input_metadata)
|
||||
forward_batch.req_to_token_pool = self.req_to_token_pool
|
||||
forward_batch.token_to_kv_pool = self.token_to_kv_pool
|
||||
forward_batch.attn_backend = self.attn_backend
|
||||
forward_batch.attn_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
# Attach lora information
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(input_metadata)
|
||||
self.lora_manager.prepare_lora_batch(forward_batch)
|
||||
|
||||
if input_metadata.forward_mode.is_decode():
|
||||
return self.forward_decode(input_metadata)
|
||||
elif input_metadata.forward_mode.is_extend():
|
||||
return self.forward_extend(input_metadata)
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(forward_batch)
|
||||
elif forward_batch.forward_mode.is_extend():
|
||||
return self.forward_extend(forward_batch)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
|
||||
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
|
||||
Reference in New Issue
Block a user