diff --git a/docker/Dockerfile b/docker/Dockerfile index aac130d8a..2801b0737 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -102,7 +102,18 @@ RUN cd /sgl-workspace/nvshmem && \ # Install DeepEP RUN cd /sgl-workspace/DeepEP && \ - TORCH_CUDA_ARCH_LIST='9.0;10.0' NVSHMEM_DIR=${NVSHMEM_DIR} pip install . + case "$CUDA_VERSION" in \ + 12.6.1) \ + CHOSEN_TORCH_CUDA_ARCH_LIST='9.0' \ + ;; \ + 12.8.1|12.9.1) \ + CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0' \ + ;; \ + *) \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 \ + ;; \ + esac && \ + NVSHMEM_DIR=${NVSHMEM_DIR} TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install . # Python tools RUN python3 -m pip install --no-cache-dir \