Move sampler into CUDA graph (#1201)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import importlib.resources
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
@@ -514,7 +516,11 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_decode(self, batch: ScheduleBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||
if (
|
||||
self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||
and not batch.sampling_info.has_bias()
|
||||
):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.from_schedule_batch(
|
||||
@@ -563,7 +569,9 @@ class ModelRunner:
|
||||
input_metadata.image_offsets,
|
||||
)
|
||||
|
||||
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
|
||||
def forward(
|
||||
self, batch: ScheduleBatch, forward_mode: ForwardMode
|
||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||
return self.forward_extend_multi_modal(batch)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
|
||||
Reference in New Issue
Block a user