hotfix: revert sampler CUDA Graph (#1242)

Yineng Zhang
2024-08-28 21:16:47 +10:00
committed by GitHub
parent 184ae1c683
commit f25f4dfde5
33 changed files with 119 additions and 348 deletions


@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
@@ -44,8 +44,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -517,11 +515,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -570,9 +564,7 @@
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
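
For context: before this revert, CUDA graph replay in forward_decode was additionally gated on not batch.sampling_info.has_bias(), and forward was annotated as returning Tuple[SampleOutput, LogitsProcessorOutput]; the hotfix drops both, so replay depends only on a runner existing and the batch size being capturable. Below is a minimal sketch of that gating pattern using hypothetical stand-ins (Batch, SamplingInfo, CudaGraphRunner) rather than sglang's actual classes.

    # Illustrative sketch of CUDA-graph-gated decode dispatch.
    # All names here are hypothetical stand-ins, not sglang's implementation.
    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class SamplingInfo:
        logit_bias: Optional[List[float]] = None

        def has_bias(self) -> bool:
            # Mirrors the has_bias() check the reverted code consulted.
            return self.logit_bias is not None

    @dataclass
    class Batch:
        reqs: List[object] = field(default_factory=list)
        sampling_info: SamplingInfo = field(default_factory=SamplingInfo)

    class CudaGraphRunner:
        """Replays pre-captured decode graphs for supported batch sizes."""

        def __init__(self, captured_batch_sizes: List[int]):
            # Capture itself is omitted; only the replay gate is sketched.
            self.captured = set(captured_batch_sizes)

        def can_run(self, batch_size: int) -> bool:
            # A graph can only be replayed at a batch size it was captured for.
            return batch_size in self.captured

        def replay(self, batch: Batch) -> str:
            return f"replayed captured graph for bs={len(batch.reqs)}"

    def forward_decode(runner: Optional[CudaGraphRunner], batch: Batch) -> str:
        # Post-revert gating: replay whenever a graph exists for this batch size.
        # The reverted code additionally required
        # `not batch.sampling_info.has_bias()` before allowing replay.
        if runner and runner.can_run(len(batch.reqs)):
            return runner.replay(batch)
        return "eager decode fallback"

    if __name__ == "__main__":
        runner = CudaGraphRunner(captured_batch_sizes=[1, 2, 4, 8])
        print(forward_decode(runner, Batch(reqs=[object()] * 4)))  # graph replay
        print(forward_decode(runner, Batch(reqs=[object()] * 3)))  # eager fallback

The dropped has_bias() gate suggests the sampler-in-graph path introduced by the original change could not honor per-request logit bias during replay; with sampling moved back out of the captured graph by this revert, that extra check is no longer needed.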