Move sampler into CUDA graph (#1201)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-26 07:02:50 -07:00
committed by GitHub
parent 97589a60a2
commit 75ce37f401
28 changed files with 336 additions and 110 deletions

View File

@@ -21,7 +21,7 @@ import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
@@ -514,7 +516,11 @@ class ModelRunner:
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(
@@ -563,7 +569,9 @@ class ModelRunner:
input_metadata.image_offsets,
)
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE: