Sampler cudagraph (#1253)

This commit is contained in:
Liangsheng Yin
2024-08-28 18:58:52 -07:00
committed by GitHub
parent 8153168c96
commit 381dd57bd6
29 changed files with 342 additions and 116 deletions

View File

@@ -21,7 +21,7 @@ import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
@@ -524,7 +526,11 @@ class ModelRunner:
@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and not batch.sampling_info.has_bias()
):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(
@@ -573,7 +579,9 @@ class ModelRunner:
input_metadata.image_offsets,
)
def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE: