Organize code (rename, movement) (#953)
@@ -41,18 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
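In this hunk, the multi-line import of Batch, ForwardMode, and InputMetadata from sglang.srt.managers.schedule_batch is replaced by a one-line import of ScheduleBatch (the renamed batch class), while ForwardMode and InputMetadata now come from sglang.srt.model_executor.forward_batch_info. A minimal before/after sketch for code that consumed the old names (import paths taken from the hunk above; nothing beyond them is implied):

# Before #953: the batch type and forward metadata came from schedule_batch.
# from sglang.srt.managers.schedule_batch import Batch, ForwardMode, InputMetadata

# After #953: Batch is renamed to ScheduleBatch; ForwardMode and InputMetadata
# are imported from the new forward_batch_info module.
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata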
@@ -350,7 +346,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
@@ -370,7 +366,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
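Taken together, the remaining hunks only change type annotations: every ModelRunner entry point now accepts a ScheduleBatch instead of a Batch. A hedged sketch of a caller under the new names (run_one_step and runner are hypothetical names, not part of this commit; the (batch, forward_mode) signature of ModelRunner.forward is from the hunk above):

from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode

def run_one_step(runner, batch: ScheduleBatch, decode: bool):
    # Hypothetical wrapper: picks a forward mode and defers to
    # ModelRunner.forward, which dispatches as shown in the last hunk.
    mode = ForwardMode.DECODE if decode else ForwardMode.EXTEND
    return runner.forward(batch, mode)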