[feat] apply deep_gemm compile_mode to skip launch (#9879)
This commit is contained in:
@@ -58,7 +58,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.7.post1",
|
||||
"sgl-kernel==0.3.8",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
|
||||
@@ -681,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.3.7.post1",
|
||||
"0.3.8",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
|
||||
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
|
||||
)
|
||||
|
||||
old_compile_mode = deep_gemm.get_compile_mode()
|
||||
deep_gemm.set_compile_mode(1)
|
||||
# TODO can use multi thread
|
||||
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
|
||||
executor.execute(m=m)
|
||||
deep_gemm.set_compile_mode(old_compile_mode)
|
||||
|
||||
# clean up input buffers
|
||||
torch.cuda.current_stream().synchronize()
|
||||
del executor
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class _BaseWarmupExecutor:
|
||||
|
||||
Reference in New Issue
Block a user