# syntax=docker/dockerfile:1

# Base: official PyTorch 2.6.0 devel image (CUDA 12.4, cuDNN 9).
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel

WORKDIR /workspace

# NOTE(review): intentionally left disabled. Presumably opts into head-dim
# padding for PyTorch's scaled-dot-product attention — confirm intent before
# enabling, since it would change runtime behavior of main.py/test.sh.
# ENV PT_SDPA_ENABLE_HEAD_DIM_PADDING=1

# Python dependencies, all pinned so rebuilds are reproducible (previously only
# diffusers was pinned; transformers/sentencepiece floated to whatever was
# newest at build time). Override with e.g. --build-arg TRANSFORMERS_VERSION=...
# NOTE(review): transformers/sentencepiece defaults chosen as releases
# contemporary with diffusers 0.34.0 — confirm against the tested environment.
# --no-cache-dir keeps pip's download cache out of the image layer.
ARG DIFFUSERS_VERSION=0.34.0
ARG SENTENCEPIECE_VERSION=0.2.0
ARG TRANSFORMERS_VERSION=4.53.2
RUN pip install --no-cache-dir \
        "diffusers==${DIFFUSERS_VERSION}" \
        "sentencepiece==${SENTENCEPIECE_VERSION}" \
        "transformers==${TRANSFORMERS_VERSION}"

# Application files come last so editing them does not invalidate the
# (slow) dependency layer above.
COPY main.py test.sh dataset.json /workspace/