diff --git a/Dockerfile b/Dockerfile index 57bd924..a97141f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,28 @@ -FROM pytorch/pytorch:latest +FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime WORKDIR /usr/src/app +# Allow the container to use the GPU +ENV NVIDIA_VISIBLE_DEVICES all + +# Copy everything from the current directory to /usr/src/app in the container COPY . . +# Install requirements +RUN pip install --upgrade pip +RUN pip install -r requirements.txt + +# Install the package RUN pip install --editable . -ENTRYPOINT [ "main.py" ] +# Allow the app to use tensorboard from runs directory +RUN chmod -R 777 runs + +# Expose the TensorBoard port +EXPOSE 6006 + +# Make the entrypoint script executable +RUN chmod +x /usr/src/app/entrypoint.sh + +# Set the entrypoint script to be executed +ENTRYPOINT ["/usr/src/app/entrypoint.sh"] diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 0000000..8dfa878 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Function to handle the SIGINT signal (Ctrl-C) +cleanup() { + echo "SIGINT caught, stopping processes..." + + # Kill the TensorBoard process + kill $TENSORBOARD_PID + + # Kill the training process + kill $TRAIN_PID + + # Optionally, wait for the processes to stop + wait $TENSORBOARD_PID + wait $TRAIN_PID + + echo "Processes stopped. Exiting." + exit 0 +} + +# Set trap to catch SIGINT and call the cleanup function +trap cleanup SIGINT + +# Check the first argument to determine the mode +if [ "$1" = "train" ]; then + # Remove the first argument ('train') + shift + + # Run training command with all remaining arguments in the background + python tools/train.py "$@" & + TRAIN_PID=$! + + # Start TensorBoard in the background + tensorboard --logdir=runs --host=0.0.0.0 & + TENSORBOARD_PID=$! + + # Wait for the training process to complete + wait $TRAIN_PID + +elif [ "$1" = "infer" ]; then + # Run inference command + # Assuming you might also want to pass arguments to infer.py in the future + shift + python tools/infer.py "$@" +else + echo "Invalid argument. Please use 'train' or 'infer'." +fi diff --git a/requirements.txt b/requirements.txt index ce02cc4..dc3ca56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,7 @@ torch scikit-learn pytest pytest-cov -matplotlib \ No newline at end of file +matplotlib +randomname +pyyaml +tensorboard \ No newline at end of file