[FEAT] Add transformers backend support (#5929)
This commit is contained in:
@@ -455,6 +455,7 @@ class SRTRunner:
|
||||
torch_dtype: torch.dtype,
|
||||
model_type: str,
|
||||
tp_size: int = 1,
|
||||
impl: str = "auto",
|
||||
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
lora_paths: List[str] = None,
|
||||
max_loras_per_batch: int = 4,
|
||||
@@ -475,6 +476,7 @@ class SRTRunner:
|
||||
speculative_num_draft_tokens: Optional[int] = None,
|
||||
disable_overlap_schedule: bool = False,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
torchao_config: Optional[str] = None,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -493,6 +495,8 @@ class SRTRunner:
|
||||
tp_size=tp_size,
|
||||
dtype=get_dtype_str(torch_dtype),
|
||||
port=port,
|
||||
impl=impl,
|
||||
torchao_config=torchao_config,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
trust_remote_code=trust_remote_code,
|
||||
is_embedding=not self.is_generation,
|
||||
|
||||
Reference in New Issue
Block a user