init
This commit is contained in:
107
Dockerfile.rocm
Normal file
107
Dockerfile.rocm
Normal file
@@ -0,0 +1,107 @@
|
||||
# default base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
RUN echo "Base image is $BASE_IMAGE"
|
||||
|
||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||
|
||||
ARG FA_BRANCH="ae7928c"
|
||||
RUN echo "FA_BRANCH is $FA_BRANCH"
|
||||
|
||||
# whether to build flash-attention
|
||||
# if 0, will not build flash attention
|
||||
# this is useful for gfx target where flash-attention is not supported
|
||||
# In that case, we need to use the python reference attention implementation in vllm
|
||||
ARG BUILD_FA="1"
|
||||
|
||||
# whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
nvidia-cuda-toolkit \
|
||||
tmux \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/vllm-workspace
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
# Install ROCm flash-attention
|
||||
RUN if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
|
||||
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
|
||||
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..; \
|
||||
fi
|
||||
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually removed it so that later steps of numpy upgrade can continue
|
||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||
|
||||
# build triton
|
||||
RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& pip uninstall -y triton \
|
||||
&& git clone https://github.com/ROCm/triton.git \
|
||||
&& cd triton/python \
|
||||
&& pip3 install . \
|
||||
&& cd ../..; \
|
||||
fi
|
||||
|
||||
WORKDIR /vllm-workspace
|
||||
COPY . .
|
||||
|
||||
RUN python3 -m pip install --upgrade pip numba
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -U -r requirements-rocm.txt \
|
||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
|
||||
&& python3 setup.py install \
|
||||
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
|
||||
&& cd ..
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
Reference in New Issue
Block a user