[feat] apply deep_gemm compile_mode to skip launch (#9879)
This commit is contained in:
@@ -85,10 +85,10 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li
|
|||||||
&& python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps \
|
&& python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps \
|
||||||
&& python3 -m flashinfer --download-cubin \
|
&& python3 -m flashinfer --download-cubin \
|
||||||
&& if [ "$CUDA_VERSION" = "12.8.1" ]; then \
|
&& if [ "$CUDA_VERSION" = "12.8.1" ]; then \
|
||||||
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.7.post1/sgl_kernel-0.3.7.post1+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \
|
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.8/sgl_kernel-0.3.8+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \
|
||||||
fi \
|
fi \
|
||||||
&& if [ "$CUDA_VERSION" = "12.9.1" ]; then \
|
&& if [ "$CUDA_VERSION" = "12.9.1" ]; then \
|
||||||
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.7.post1/sgl_kernel-0.3.7.post1+cu129-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \
|
python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.8/sgl_kernel-0.3.8+cu129-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Download source files
|
# Download source files
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.3.7.post1",
|
"sgl-kernel==0.3.8",
|
||||||
"torch==2.8.0",
|
"torch==2.8.0",
|
||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
|
|||||||
@@ -681,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"sgl-kernel",
|
"sgl-kernel",
|
||||||
"0.3.7.post1",
|
"0.3.8",
|
||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
|
|||||||
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
|
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
|
||||||
)
|
)
|
||||||
|
|
||||||
|
old_compile_mode = deep_gemm.get_compile_mode()
|
||||||
|
deep_gemm.set_compile_mode(1)
|
||||||
# TODO can use multi thread
|
# TODO can use multi thread
|
||||||
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
|
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
|
||||||
executor.execute(m=m)
|
executor.execute(m=m)
|
||||||
|
deep_gemm.set_compile_mode(old_compile_mode)
|
||||||
|
|
||||||
|
# clean up input buffers
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
del executor
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
class _BaseWarmupExecutor:
|
class _BaseWarmupExecutor:
|
||||||
|
|||||||
Reference in New Issue
Block a user