-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Entrypoint and format fix in logging and checkpointing * Docker (#35) * Entrypoint and format fix in logging and checkpointing (#34) * Docker image and entrypoint w/ GPU support
- Loading branch information
Showing
3 changed files
with
72 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,28 @@ | ||
FROM pytorch/pytorch:latest | ||
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime | ||
|
||
WORKDIR /usr/src/app | ||
|
||
# Allow the container to use the GPU | ||
ENV NVIDIA_VISIBLE_DEVICES all | ||
|
||
# Copy everything from the current directory to /usr/src/app in the container | ||
COPY . . | ||
|
||
# Install requirements | ||
RUN pip install --upgrade pip | ||
RUN pip install -r requirements.txt | ||
|
||
# Install the package | ||
RUN pip install --editable . | ||
|
||
ENTRYPOINT [ "main.py" ] | ||
# Allow the app to use tensorboard from runs directory | ||
RUN chmod -R 777 runs | ||
|
||
# Expose the TensorBoard port | ||
EXPOSE 6006 | ||
|
||
# Make the entrypoint script executable | ||
RUN chmod +x /usr/src/app/entrypoint.sh | ||
|
||
# Set the entrypoint script to be executed | ||
ENTRYPOINT ["/usr/src/app/entrypoint.sh"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/bin/bash | ||
|
||
# Function to handle the SIGINT signal (Ctrl-C) | ||
cleanup() { | ||
echo "SIGINT caught, stopping processes..." | ||
|
||
# Kill the TensorBoard process | ||
kill $TENSORBOARD_PID | ||
|
||
# Kill the training process | ||
kill $TRAIN_PID | ||
|
||
# Optionally, wait for the processes to stop | ||
wait $TENSORBOARD_PID | ||
wait $TRAIN_PID | ||
|
||
echo "Processes stopped. Exiting." | ||
exit 0 | ||
} | ||
|
||
# Set trap to catch SIGINT and call the cleanup function | ||
trap cleanup SIGINT | ||
|
||
# Check the first argument to determine the mode | ||
if [ "$1" = "train" ]; then | ||
# Remove the first argument ('train') | ||
shift | ||
|
||
# Run training command with all remaining arguments in the background | ||
python tools/train.py "$@" & | ||
TRAIN_PID=$! | ||
|
||
# Start TensorBoard in the background | ||
tensorboard --logdir=runs --host=0.0.0.0 & | ||
TENSORBOARD_PID=$! | ||
|
||
# Wait for the training process to complete | ||
wait $TRAIN_PID | ||
|
||
elif [ "$1" = "infer" ]; then | ||
# Run inference command | ||
# Assuming you might also want to pass arguments to infer.py in the future | ||
shift | ||
python tools/infer.py "$@" | ||
else | ||
echo "Invalid argument. Please use 'train' or 'infer'." | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,7 @@ torch | |
scikit-learn | ||
pytest | ||
pytest-cov | ||
matplotlib | ||
matplotlib | ||
randomname | ||
pyyaml | ||
tensorboard |