Organize code (rename, movement) (#953)

This commit is contained in:
Liangsheng Yin
2024-08-06 20:50:32 -07:00
committed by GitHub
parent ad56e68495
commit 87e8c090e9
29 changed files with 304 additions and 289 deletions

View File

@@ -41,18 +41,14 @@ from vllm.distributed import (
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-Batch,
-ForwardMode,
-InputMetadata,
-global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
@@ -350,7 +346,7 @@ class ModelRunner:
)
@torch.inference_mode()
-def forward_decode(self, batch: Batch):
+def forward_decode(self, batch: ScheduleBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
return self.cuda_graph_runner.replay(batch)
@@ -370,7 +366,7 @@ class ModelRunner:
)
@torch.inference_mode()
-def forward_extend(self, batch: Batch):
+def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
)
@torch.inference_mode()
-def forward_extend_multi_modal(self, batch: Batch):
+def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
batch.image_offsets,
)
-def forward(self, batch: Batch, forward_mode: ForwardMode):
+def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE: