Warmup cublas (#566)
This commit is contained in:
@@ -270,6 +270,7 @@ class ModelRunner:
|
||||
# Load the model and create memory pool
|
||||
self.load_model()
|
||||
self.init_memory_pool(total_gpu_memory)
|
||||
self.init_cublas()
|
||||
self.init_flash_infer()
|
||||
|
||||
def load_model(self):
|
||||
@@ -346,6 +347,15 @@ class ModelRunner:
|
||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def init_cublas(self):
|
||||
"""We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
a = torch.ones((16, 16), dtype=dtype, device=device)
|
||||
b = torch.ones((16, 16), dtype=dtype, device=device)
|
||||
c = a @ b
|
||||
return c
|
||||
|
||||
def init_flash_infer(self):
|
||||
if global_server_args_dict.get("enable_flashinfer", False):
|
||||
from flashinfer import (
|
||||
|
||||
@@ -410,7 +410,7 @@ class ModelTpServer:
|
||||
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
||||
)
|
||||
logger.info(
|
||||
f"[gpu_id={self.gpu_id}] Prefil batch. "
|
||||
f"[gpu_id={self.gpu_id}] Prefill batch. "
|
||||
f"#new-seq: {len(can_run_list)}, "
|
||||
f"#new-token: {new_batch_input_tokens}, "
|
||||
f"#cached-token: {hit_tokens}, "
|
||||
|
||||
Reference in New Issue
Block a user