Skip to content

Commit

Permalink
Docker (#35)
Browse files Browse the repository at this point in the history
* Entrypoint and format fix in logging and checkpointing (#34)

* Docker image and entrypoint w/ GPU support
  • Loading branch information
kaseris committed Dec 1, 2023
1 parent 077b8a9 commit 5ee077c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 3 deletions.
23 changes: 21 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
FROM pytorch/pytorch:latest
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

WORKDIR /usr/src/app

# Allow the container to use the GPU
ENV NVIDIA_VISIBLE_DEVICES all

# Copy everything from the current directory to /usr/src/app in the container
COPY . .

# Install requirements
RUN pip install --upgrade pip
RUN pip install -r requirements.txt

# Install the package
RUN pip install --editable .

ENTRYPOINT [ "main.py" ]
# Allow the app to use tensorboard from runs directory
RUN chmod -R 777 runs

# Expose the TensorBoard port
EXPOSE 6006

# Make the entrypoint script executable
RUN chmod +x /usr/src/app/entrypoint.sh

# Set the entrypoint script to be executed
ENTRYPOINT ["/usr/src/app/entrypoint.sh"]
47 changes: 47 additions & 0 deletions entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash

# Function to handle the SIGINT signal (Ctrl-C)
cleanup() {
echo "SIGINT caught, stopping processes..."

# Kill the TensorBoard process
kill $TENSORBOARD_PID

# Kill the training process
kill $TRAIN_PID

# Optionally, wait for the processes to stop
wait $TENSORBOARD_PID
wait $TRAIN_PID

echo "Processes stopped. Exiting."
exit 0
}

# Set trap to catch SIGINT and call the cleanup function
trap cleanup SIGINT

# Check the first argument to determine the mode
if [ "$1" = "train" ]; then
# Remove the first argument ('train')
shift

# Run training command with all remaining arguments in the background
python tools/train.py "$@" &
TRAIN_PID=$!

# Start TensorBoard in the background
tensorboard --logdir=runs --host=0.0.0.0 &
TENSORBOARD_PID=$!

# Wait for the training process to complete
wait $TRAIN_PID

elif [ "$1" = "infer" ]; then
# Run inference command
# Assuming you might also want to pass arguments to infer.py in the future
shift
python tools/infer.py "$@"
else
echo "Invalid argument. Please use 'train' or 'infer'."
fi
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ torch
scikit-learn
pytest
pytest-cov
matplotlib
matplotlib
randomname
pyyaml
tensorboard

0 comments on commit 5ee077c

Please sign in to comment.