Unify forward mode (#1360)
This commit is contained in:
@@ -60,7 +60,6 @@ import torch.distributed as dist
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
sample_output, logits_output = model_runner.forward(batch)
|
||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
|
||||
sample_output, logits_output = model_runner.forward(batch)
|
||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user