Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

This commit is contained in:
Lianmin Zheng
2024-09-29 20:28:45 -07:00
committed by GitHub
parent 55b974f96f
commit 3f0fe08d37
12 changed files with 142 additions and 157 deletions

View File

@@ -71,10 +71,10 @@ class ModelOutput:
class HFRunner:
def __init__(
self,
-model_path,
-torch_dtype,
-model_type="generation",
-output_str_only=False,
+model_path: str,
+torch_dtype: torch.dtype,
+model_type: str = "generation",
+output_str_only: bool = False,
):
self.model_type = model_type
self.output_str_only = output_str_only
@@ -244,15 +244,15 @@ class HFRunner:
class SRTRunner:
def __init__(
self,
-model_path,
-torch_dtype,
-model_type,
-tp_size=1,
-port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-lora_paths=None,
-max_loras_per_batch=4,
-disable_cuda_graph=False,
-disable_radix_cache=False,
+model_path: str,
+torch_dtype: torch.dtype,
+model_type: str,
+tp_size: int = 1,
+port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+lora_paths: List[str] = None,
+max_loras_per_batch: int = 4,
+disable_cuda_graph: bool = False,
+disable_radix_cache: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"