Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
This commit is contained in:
@@ -71,10 +71,10 @@ class ModelOutput:
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type="generation",
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
@@ -244,15 +244,15 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
Reference in New Issue
Block a user