This repository is an example of how to build a BERT-like model for text classification and deploy it in a container that can be served behind (API Gateway + Lambda + SageMaker endpoint) or on other cloud vendors (e.g., Azure, GCP). It currently targets binary text classification.
We used the tips from the Roblox post "How We Scaled Bert To Serve 1+ Billion Daily Requests on CPUs" to run inference on CPUs, as it's cheaper than serving on GPUs.
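One of the core CPU optimizations from that post can be sketched as post-training dynamic quantization. This is a hedged example, not this repository's exact code: it assumes a PyTorch model, and the tiny network below is only a stand-in for a fine-tuned BERT classifier.

```python
# Sketch of post-training dynamic quantization for CPU inference.
# The 2-layer net is a placeholder for a fine-tuned BERT classifier
# (768-dim pooled output -> 2 classes for binary classification).
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 2))
model.eval()

# Convert the weights of all nn.Linear modules to int8; activations
# are quantized dynamically at inference time.
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

with torch.no_grad():
    logits = quantized(torch.randn(1, 768))
print(logits.shape)  # torch.Size([1, 2])
```

Quantized `Linear` layers shrink the model and speed up CPU matrix multiplies, usually with a small accuracy cost that should be measured on a validation set.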
bert-textclassification/
├─ ops/
│ ├─ config/
│ ├─ *.py (to be run locally)
├─ local_test/
│ ├─ test_dir/
│ │ ├─ input/
│ │ │ ├─ config/
│ │ │ ├─ data/
│ │ │ │ ├─ train/
│ │ │ │ ├─ test/
│ │ │ │ ├─ validation/
│ │ ├─ model/
│ │ ├─ output/
│ ├─ *.sh (to be run locally or trigger local or cloud executions)
├─ src/
│ ├─ *.py
├─ tests/
│ ├─ *.py (to be run locally or trigger by git/jenkins/drone or another CI/CD tool)
├─ README.md
Contains DevOps/operations files.
A directory containing scripts and configurations to trigger training and inference jobs locally.
- train-local.sh: Triggers the local training container.
- serve-local.sh: Triggers the local serving container and launches a local Flask API.
- test-dir: The directory mounted into the container, populated with test data laid out to match the directory schema the container expects.
- build_and_push.sh: Builds the container image and pushes it to a registry (Amazon ECR) so SageMaker can pull it.
- sagemaker_training.sh: Triggers a SageMaker training job with the defined parameters.
- sagemaker_hyperparameter.sh: Triggers a SageMaker hyperparameter tuning job to optimize the algorithm's parameters.
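As a rough sketch of what sagemaker_training.sh submits under the hood, the request below builds the parameters for the SageMaker `create_training_job` API via boto3. All ARNs, bucket names, the image URI, instance type, and hyperparameter names are placeholders, not values from this repository.

```python
# Build (but do not submit) a SageMaker create_training_job request.
# Every identifier below is a placeholder for illustration.
def training_job_request(job_name, image_uri, role_arn, bucket):
    """Assemble the keyword arguments for sagemaker.create_training_job."""
    return {
        "TrainingJobName": job_name,
        "AlgorithmSpecification": {
            "TrainingImage": image_uri,       # the container pushed by build_and_push.sh
            "TrainingInputMode": "File",
        },
        "RoleArn": role_arn,
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": f"s3://{bucket}/data/train",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": f"s3://{bucket}/output"},
        "ResourceConfig": {
            "InstanceType": "ml.c5.xlarge",   # CPU instance, per the serving strategy
            "InstanceCount": 1,
            "VolumeSizeInGB": 30,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        "HyperParameters": {"epochs": "3", "batch-size": "32"},
    }

request = training_job_request(
    "bert-textclassification-demo",
    "123456789012.dkr.ecr.us-east-1.amazonaws.com/bert-textclassification:latest",
    "arn:aws:iam::123456789012:role/SageMakerRole",
    "my-bucket",
)
# To actually launch the job (requires AWS credentials):
# import boto3
# boto3.client("sagemaker").create_training_job(**request)
```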
Module containing classes and helper functions. We use the following libraries to create a production-ready inference server container:
- nginx : https://www.nginx.com/
- gunicorn : https://gunicorn.org/
- flask : https://flask.palletsprojects.com/en/
When SageMaker starts a container, it invokes it with an argument of train or serve. We configure this container to receive the operation to execute as that argument. The scripts in the src folder are the following:
- api.py: The API interface with methods.
- train: The main model training script. When building your own algorithm, you will edit it to include your training code.
- serve: The wrapper that starts the inference server. In most cases, you can use this file as-is.
- wsgi.py: The startup shell for individual server workers. This only needs to be changed if you changed where predictor.py is located or if it was renamed.
- predictor.py: This is the file where you can include your business rules before or after the model inference.
- nginx.conf: The configuration of the nginx server.
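The contract behind predictor.py is that SageMaker health-checks the server with GET /ping and sends inference payloads to POST /invocations. A minimal Flask sketch of that contract is below; the scoring logic is a placeholder, not this repository's actual predictor.

```python
# Minimal sketch of the SageMaker serving contract implemented with Flask.
# The prediction logic is a placeholder for tokenization + model forward
# pass + any business rules applied before/after inference.
import json
from flask import Flask, Response, request

app = Flask(__name__)

@app.route("/ping", methods=["GET"])
def ping():
    # Return 200 once the model is loaded and the container is ready.
    return Response(status=200)

@app.route("/invocations", methods=["POST"])
def invocations():
    payload = request.get_json(force=True)
    # Placeholder score: a real predictor would run the BERT classifier here.
    prediction = {"label": 1, "score": 0.5, "text": payload.get("text", "")}
    return Response(json.dumps(prediction), mimetype="application/json")
```

In production, gunicorn runs this WSGI app behind nginx instead of Flask's built-in development server.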
Contains test files for model functions, code, and specific business-rule cases. You can run the local tests with `pytest -v`.
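A hypothetical example in the style of tests/*.py: a unit test for a small text-cleaning helper run before tokenization. Both the helper and the test are illustrative, not code from this repository.

```python
# Illustrative unit test, discoverable by `pytest -v`.
import re

def clean_text(text: str) -> str:
    """Lowercase and collapse whitespace before tokenization."""
    return re.sub(r"\s+", " ", text).strip().lower()

def test_clean_text_normalizes_whitespace():
    assert clean_text("  Hello\n WORLD ") == "hello world"
```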