Feat/support encoder model (like bert) (#4887)
This commit is contained in:
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
|
||||
def get_dtype_str(torch_dtype):
    """Return the canonical string name for a supported torch dtype.

    Args:
        torch_dtype: a ``torch.dtype`` object (e.g. ``torch.float16``).

    Returns:
        The dtype's string name: ``"float16"``, ``"bfloat16"``, or
        ``"float32"``.

    Raises:
        NotImplementedError: for any dtype not listed above.
    """
    # Identity comparison is correct here: torch dtypes are singletons.
    if torch_dtype is torch.float16:
        return "float16"
    # bfloat16 added for parity with half/single precision support.
    if torch_dtype is torch.bfloat16:
        return "bfloat16"
    if torch_dtype is torch.float32:
        return "float32"
    # Unsupported dtype — fail loudly rather than returning a wrong name.
    raise NotImplementedError()
|
||||
|
||||
@@ -447,6 +449,7 @@ class SRTRunner:
|
||||
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||
lora_paths: List[str] = None,
|
||||
max_loras_per_batch: int = 4,
|
||||
attention_backend: Optional[str] = None,
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
@@ -487,6 +490,7 @@ class SRTRunner:
|
||||
lora_paths=lora_paths,
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
lora_backend=lora_backend,
|
||||
attention_backend=attention_backend,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
|
||||
Reference in New Issue
Block a user