Unify forward mode (#1360)
This commit is contained in:
@@ -60,7 +60,6 @@ import torch.distributed as dist
|
|||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
|
|||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
|
sample_output, logits_output = model_runner.forward(batch)
|
||||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits, batch
|
return next_token_ids, logits_output.next_token_logits, batch
|
||||||
|
|
||||||
|
|
||||||
def decode(input_token_ids, batch, model_runner):
|
def decode(input_token_ids, batch, model_runner):
|
||||||
batch.prepare_for_decode(input_token_ids)
|
batch.prepare_for_decode(input_token_ids)
|
||||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
|
sample_output, logits_output = model_runner.forward(batch)
|
||||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits
|
return next_token_ids, logits_output.next_token_logits
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode.is_decode():
|
||||||
output_top_logprobs = []
|
output_top_logprobs = []
|
||||||
max_k = max(logits_metadata.top_logprobs_nums)
|
max_k = max(logits_metadata.top_logprobs_nums)
|
||||||
ret = all_logprobs.topk(max_k, dim=1)
|
ret = all_logprobs.topk(max_k, dim=1)
|
||||||
@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
assert isinstance(logits_metadata, LogitsMetadata)
|
assert isinstance(logits_metadata, LogitsMetadata)
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode.is_decode():
|
||||||
last_index = None
|
last_index = None
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode.is_decode():
|
||||||
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
|
|
||||||
# Get the logprob of top-k tokens
|
# Get the logprob of top-k tokens
|
||||||
|
|||||||
@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
|
|||||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode.is_extend():
|
||||||
return self.extend_forward(q, k, v, input_metadata)
|
return self.extend_forward(q, k, v, input_metadata)
|
||||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
elif input_metadata.forward_mode.is_decode():
|
||||||
return self.decode_forward(q, k, v, input_metadata)
|
return self.decode_forward(q, k, v, input_metadata)
|
||||||
|
|
||||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
|
|||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -334,6 +335,8 @@ class ScheduleBatch:
|
|||||||
token_to_kv_pool: BaseTokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
tree_cache: BasePrefixCache
|
tree_cache: BasePrefixCache
|
||||||
|
|
||||||
|
forward_mode: ForwardMode = None
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
input_ids: torch.Tensor = None
|
input_ids: torch.Tensor = None
|
||||||
req_pool_indices: torch.Tensor = None
|
req_pool_indices: torch.Tensor = None
|
||||||
@@ -397,6 +400,8 @@ class ScheduleBatch:
|
|||||||
return out_cache_loc
|
return out_cache_loc
|
||||||
|
|
||||||
def prepare_for_extend(self, vocab_size: int):
|
def prepare_for_extend(self, vocab_size: int):
|
||||||
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
|
||||||
bs = self.batch_size()
|
bs = self.batch_size()
|
||||||
reqs = self.reqs
|
reqs = self.reqs
|
||||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
@@ -626,6 +631,8 @@ class ScheduleBatch:
|
|||||||
return jump_forward_reqs
|
return jump_forward_reqs
|
||||||
|
|
||||||
def prepare_for_decode(self, input_ids=None):
|
def prepare_for_decode(self, input_ids=None):
|
||||||
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
input_ids = [
|
input_ids = [
|
||||||
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -521,9 +520,7 @@ class ModelTpServer:
|
|||||||
if self.model_runner.is_generation:
|
if self.model_runner.is_generation:
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
sample_output, logits_output = self.model_runner.forward(
|
sample_output, logits_output = self.model_runner.forward(batch)
|
||||||
batch, ForwardMode.EXTEND
|
|
||||||
)
|
|
||||||
next_token_ids = batch.check_sample_results(sample_output)
|
next_token_ids = batch.check_sample_results(sample_output)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
@@ -588,7 +585,7 @@ class ModelTpServer:
|
|||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
else:
|
else:
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
logits_output = self.model_runner.forward(batch)
|
||||||
embeddings = logits_output.embeddings.tolist()
|
embeddings = logits_output.embeddings.tolist()
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
@@ -699,9 +696,7 @@ class ModelTpServer:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
sample_output, logits_output = self.model_runner.forward(
|
sample_output, logits_output = self.model_runner.forward(batch)
|
||||||
batch, ForwardMode.DECODE
|
|
||||||
)
|
|
||||||
next_token_ids = batch.check_sample_results(sample_output)
|
next_token_ids = batch.check_sample_results(sample_output)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
|
|||||||
@@ -25,10 +25,9 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
|
|||||||
# Decode one token.
|
# Decode one token.
|
||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
|
|
||||||
|
def is_prefill(self):
|
||||||
|
return self == ForwardMode.PREFILL
|
||||||
|
|
||||||
|
def is_extend(self):
|
||||||
|
return self == ForwardMode.EXTEND
|
||||||
|
|
||||||
|
def is_decode(self):
|
||||||
|
return self == ForwardMode.DECODE
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputMetadata:
|
class InputMetadata:
|
||||||
@@ -102,7 +110,7 @@ class InputMetadata:
|
|||||||
def compute_positions(self, batch: ScheduleBatch):
|
def compute_positions(self, batch: ScheduleBatch):
|
||||||
position_ids_offsets = batch.position_ids_offsets
|
position_ids_offsets = batch.position_ids_offsets
|
||||||
|
|
||||||
if self.forward_mode == ForwardMode.DECODE:
|
if self.forward_mode.is_decode():
|
||||||
if True:
|
if True:
|
||||||
self.positions = self.seq_lens - 1
|
self.positions = self.seq_lens - 1
|
||||||
else:
|
else:
|
||||||
@@ -141,7 +149,7 @@ class InputMetadata:
|
|||||||
self.positions = self.positions.to(torch.int64)
|
self.positions = self.positions.to(torch.int64)
|
||||||
|
|
||||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||||
if self.forward_mode == ForwardMode.DECODE:
|
if self.forward_mode.is_decode():
|
||||||
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
||||||
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
|
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
|
||||||
else:
|
else:
|
||||||
@@ -173,10 +181,9 @@ class InputMetadata:
|
|||||||
cls,
|
cls,
|
||||||
model_runner: "ModelRunner",
|
model_runner: "ModelRunner",
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
forward_mode: ForwardMode,
|
|
||||||
):
|
):
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
batch_size=batch.batch_size(),
|
batch_size=batch.batch_size(),
|
||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
@@ -194,13 +201,11 @@ class InputMetadata:
|
|||||||
|
|
||||||
ret.compute_extend_infos(batch)
|
ret.compute_extend_infos(batch)
|
||||||
|
|
||||||
if (
|
fm = batch.forward_mode
|
||||||
forward_mode != ForwardMode.DECODE
|
if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
|
||||||
or model_runner.server_args.disable_flashinfer
|
|
||||||
):
|
|
||||||
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
|
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
|
||||||
|
|
||||||
if forward_mode != ForwardMode.DECODE:
|
if not fm.is_decode():
|
||||||
ret.init_multimuldal_info(batch)
|
ret.init_multimuldal_info(batch)
|
||||||
|
|
||||||
if model_runner.server_args.disable_flashinfer:
|
if model_runner.server_args.disable_flashinfer:
|
||||||
@@ -209,7 +214,7 @@ class InputMetadata:
|
|||||||
flashinfer_use_ragged = False
|
flashinfer_use_ragged = False
|
||||||
if not model_runner.server_args.disable_flashinfer:
|
if not model_runner.server_args.disable_flashinfer:
|
||||||
if (
|
if (
|
||||||
forward_mode != ForwardMode.DECODE
|
not fm.is_decode()
|
||||||
and int(torch.sum(ret.seq_lens)) > 4096
|
and int(torch.sum(ret.seq_lens)) > 4096
|
||||||
and model_runner.sliding_window_size is None
|
and model_runner.sliding_window_size is None
|
||||||
):
|
):
|
||||||
@@ -226,7 +231,7 @@ class InputMetadata:
|
|||||||
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
||||||
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
||||||
|
|
||||||
if self.forward_mode == ForwardMode.DECODE:
|
if self.forward_mode.is_decode():
|
||||||
self.triton_max_extend_len = None
|
self.triton_max_extend_len = None
|
||||||
else:
|
else:
|
||||||
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||||
@@ -239,7 +244,7 @@ class InputMetadata:
|
|||||||
prefix_lens_cpu,
|
prefix_lens_cpu,
|
||||||
flashinfer_use_ragged,
|
flashinfer_use_ragged,
|
||||||
):
|
):
|
||||||
if self.forward_mode == ForwardMode.DECODE:
|
if self.forward_mode.is_decode():
|
||||||
prefix_lens = None
|
prefix_lens = None
|
||||||
else:
|
else:
|
||||||
prefix_lens = self.extend_prefix_lens
|
prefix_lens = self.extend_prefix_lens
|
||||||
@@ -339,7 +344,7 @@ def update_flashinfer_indices(
|
|||||||
|
|
||||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode.is_decode():
|
||||||
# CUDA graph uses different flashinfer_decode_wrapper
|
# CUDA graph uses different flashinfer_decode_wrapper
|
||||||
if flashinfer_decode_wrapper is None:
|
if flashinfer_decode_wrapper is None:
|
||||||
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||||
@@ -388,7 +393,7 @@ def update_flashinfer_indices(
|
|||||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode.is_decode():
|
||||||
paged_kernel_lens = torch.minimum(
|
paged_kernel_lens = torch.minimum(
|
||||||
seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
|
seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
|
||||||
)
|
)
|
||||||
@@ -418,7 +423,7 @@ def update_flashinfer_indices(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode.is_decode():
|
||||||
# CUDA graph uses different flashinfer_decode_wrapper
|
# CUDA graph uses different flashinfer_decode_wrapper
|
||||||
if flashinfer_decode_wrapper is None:
|
if flashinfer_decode_wrapper is None:
|
||||||
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
|
||||||
|
|||||||
@@ -530,11 +530,7 @@ class ModelRunner:
|
|||||||
):
|
):
|
||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
self,
|
|
||||||
batch,
|
|
||||||
ForwardMode.DECODE,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
@@ -542,11 +538,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend(self, batch: ScheduleBatch):
|
def forward_extend(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
self,
|
|
||||||
batch,
|
|
||||||
forward_mode=ForwardMode.EXTEND,
|
|
||||||
)
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids, input_metadata.positions, input_metadata
|
batch.input_ids, input_metadata.positions, input_metadata
|
||||||
@@ -562,11 +554,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||||
input_metadata = InputMetadata.from_schedule_batch(
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
self,
|
|
||||||
batch,
|
|
||||||
forward_mode=ForwardMode.EXTEND,
|
|
||||||
)
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
input_metadata.positions,
|
input_metadata.positions,
|
||||||
@@ -577,16 +565,18 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: ScheduleBatch, forward_mode: ForwardMode
|
self, batch: ScheduleBatch
|
||||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
assert batch.forward_mode is not None
|
||||||
|
|
||||||
|
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||||
return self.forward_extend_multi_modal(batch)
|
return self.forward_extend_multi_modal(batch)
|
||||||
elif forward_mode == ForwardMode.DECODE:
|
elif batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(batch)
|
return self.forward_decode(batch)
|
||||||
elif forward_mode == ForwardMode.EXTEND:
|
elif batch.forward_mode.is_extend():
|
||||||
return self.forward_extend(batch)
|
return self.forward_extend(batch)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
image_sizes: Optional[List[List[int]]] = None,
|
image_sizes: Optional[List[List[int]]] = None,
|
||||||
image_offsets: Optional[List[int]] = None,
|
image_offsets: Optional[List[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode.is_extend():
|
||||||
bs = input_metadata.batch_size
|
bs = input_metadata.batch_size
|
||||||
# Got List[List[str]] extend it to List[str]
|
# Got List[List[str]] extend it to List[str]
|
||||||
# The length of the List should be equal to batch size
|
# The length of the List should be equal to batch size
|
||||||
@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
return self.language_model(
|
return self.language_model(
|
||||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||||
)
|
)
|
||||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
elif input_metadata.forward_mode.is_decode():
|
||||||
return self.language_model(input_ids, positions, input_metadata)
|
return self.language_model(input_ids, positions, input_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
image_sizes: Optional[List[List[int]]] = None,
|
image_sizes: Optional[List[List[int]]] = None,
|
||||||
image_offsets: Optional[List[int]] = None,
|
image_offsets: Optional[List[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode.is_extend():
|
||||||
bs = input_metadata.batch_size
|
bs = input_metadata.batch_size
|
||||||
|
|
||||||
# Embed text inputs
|
# Embed text inputs
|
||||||
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
return self.language_model(
|
return self.language_model(
|
||||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||||
)
|
)
|
||||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
elif input_metadata.forward_mode.is_decode():
|
||||||
return self.language_model(input_ids, positions, input_metadata)
|
return self.language_model(input_ids, positions, input_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
Reference in New Issue
Block a user