Spatio-Temporal Classification of Basketball Actions using 3D-CNN Models trained on the SpaceJam Dataset
LeBron shooting over Deandre Jordan
Utilizing the SpaceJam Basketball Action Dataset repo, I aim to create a model that takes a video of a basketball game and classifies the action of each player tracked with a bounding box. The program has two essential parts: the R(2+1)D model (any 3D CNN architecture could be used) and the player tracking. The network was trained with the PyTorch deep learning framework on an NVIDIA RTX 3060 Ti GPU.
This is a demo video from the SpaceJam Repo.
A baseline R(2+1)D CNN from torchvision.models, pretrained on the Kinetics-400 dataset, is fine-tuned on the SpaceJam dataset. Any 3D CNN architecture could be used, but for this project the R(2+1)D offered the best balance between number of parameters and overall model performance. The R(2+1)D paper (A Closer Look at Spatiotemporal Convolutions for Action Recognition) also showed that factorizing 3D convolutional filters into separate spatial and temporal components, combined with residual learning, yields significant gains in accuracy. The training is done in train.py.
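A minimal sketch of that setup, assuming the torchvision `r2plus1d_18` constructor and a simple head swap (the actual fine-tuning code lives in train.py):

```python
import torch
import torch.nn as nn
from torchvision.models.video import r2plus1d_18

# Load the R(2+1)D-18 backbone with Kinetics-400 pretrained weights and
# replace the classification head for the 10 SpaceJam action classes.
model = r2plus1d_18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)

# The model expects clips shaped (batch, channels, frames, height, width).
dummy_clip = torch.randn(1, 3, 16, 112, 112)
logits = model(dummy_clip)  # -> shape (1, 10)
```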
As mentioned above, the SpaceJam Basketball Action Dataset was used to train the R(2+1)D CNN for video/action classification of basketball actions. The repo contains two datasets of single-player basketball actions: clips (.mp4 files) and joints (.npy files). Together, the two annotated datasets contain about 32,560 examples. Custom dataloaders for the basketball dataset are defined in dataset.py.
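The real loader lives in dataset.py; below is a hypothetical sketch of what such a clip dataset might look like, assuming clips are decoded with OpenCV and returned in the (C, T, H, W) layout the model expects:

```python
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class BasketballClips(Dataset):
    """Hypothetical sketch of a clip dataset; the real loader is in dataset.py."""

    def __init__(self, clip_paths, labels, num_frames=16):
        self.clip_paths = clip_paths  # list of .mp4 file paths
        self.labels = labels          # integer class ids (0-9)
        self.num_frames = num_frames

    def __len__(self):
        return len(self.clip_paths)

    def __getitem__(self, idx):
        cap = cv2.VideoCapture(self.clip_paths[idx])
        frames = []
        while len(frames) < self.num_frames:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
        # Pad short clips by repeating the last frame.
        while len(frames) < self.num_frames:
            frames.append(frames[-1])
        clip = torch.from_numpy(np.stack(frames)).float() / 255.0
        # (T, H, W, C) -> (C, T, H, W), the layout r2plus1d_18 expects.
        return clip.permute(3, 0, 1, 2), self.labels[idx]
```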
After reading the thesis Classificazione di Azioni Cestistiche mediante Tecniche di Deep Learning (Classification of Basketball Actions using Deep Learning Techniques, written by Simone Francia), it was determined that the under-represented classes with fewer than 2,000 examples should be augmented: Dribble, Ball in Hand, Pass, Block, Pick and Shoot. Augmentations were applied by running the script augment_videos.py and saved to a given output directory. Translation and rotation were the only augmentations applied. After augmentation the dataset contains 49,901 examples.
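A sketch of how translation and rotation could be applied per frame with OpenCV (the parameter values below are assumptions; the actual ranges are set in augment_videos.py). Note that the same transform must be applied to every frame of a clip so the clip stays temporally consistent:

```python
import cv2
import numpy as np

def augment_frame(frame, angle=10, shift=(12, 8)):
    """Apply a small rotation and translation, the two augmentations used.
    The angle and shift values here are illustrative assumptions."""
    h, w = frame.shape[:2]
    # Rotate around the frame centre.
    rot = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    rotated = cv2.warpAffine(frame, rot, (w, h))
    # Translate by (dx, dy) pixels.
    trans = np.float32([[1, 0, shift[0]], [0, 1, shift[1]]])
    return cv2.warpAffine(rotated, trans, (w, h))
```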
The training is done in train.py. The model was trained for 25 epochs with a batch size of 8, using the classic 70/20/10 split: 70% of the data was used to train, 20% was used to validate the model, and the remaining 10% was held out for inference to test the final model. A learning rate of 0.0001 was found to work better than 0.001.
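A sketch of that training setup, assuming a standard PyTorch loop with `random_split` and an Adam optimizer (the optimizer choice is an assumption; the actual loop is in train.py, and `dataset`/`model` are as defined above):

```python
import torch
from torch.utils.data import random_split, DataLoader

# 70/20/10 split as described above.
n = len(dataset)
n_train, n_val = int(0.7 * n), int(0.2 * n)
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n - n_train - n_val]
)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # 0.0001 beat 0.001
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(25):
    for clips, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(clips), labels)
        loss.backward()
        optimizer.step()
```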
Both history and checkpointing are done after every epoch with checkpoints.py in the utils directory.
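A minimal sketch of what such a per-epoch checkpoint might store (the actual fields saved by utils/checkpoints.py may differ):

```python
import torch

def save_checkpoint(model, optimizer, epoch, history, path):
    # Save everything needed to resume training or reload the best epoch.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "history": history,  # e.g. per-epoch loss/accuracy/F1
        },
        path,
    )
```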
The final model is an R(2+1)D CNN trained on the dataset including the augmented examples. For evaluation on the test set, the checkpoint from epoch 19 was used, as it was the best-performing model in terms of validation F1-score and accuracy. The model performs significantly better than the 73% reported in the thesis Classificazione di Azioni Cestistiche mediante Tecniche di Deep Learning, achieving 85% for both validation accuracy and test accuracy. The confusion matrix was produced with inference.py. Further analysis of predictions and errors is done in the error_analysis.ipynb notebook.
- 0: Block, 1: Pass, 2: Run, 3: Dribble, 4: Shoot, 5: Ball in Hand, 6: Defence, 7: Pick, 8: No Action, 9: Walk
| State | Shooting | Dribble | Pass | Defence | Pick | Run | Walk | Block | No Action |
|---|---|---|---|---|---|---|---|---|---|
| True | | | | | | | | | |
| False | | | | | | | | | |
All player tracking is done in main.py. Players are tracked by manually selecting an ROI for OpenCV's TrackerCSRT_create() tracker. In theory an unlimited number of players can be tracked, but each additional tracker significantly increases compute time. In the example above only two players, LeBron James (offence) and Deandre Jordan (defence), were tracked. A simple example of player tracking is available in Basketball-Player-Tracker.
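A minimal sketch of CSRT-based tracking for a single player, assuming a hypothetical input video (main.py handles multiple trackers plus the classification overlay):

```python
import cv2

cap = cv2.VideoCapture("game_clip.mp4")  # hypothetical input video
ok, frame = cap.read()

# Manually draw the ROI around a player; press ENTER to confirm.
box = cv2.selectROI("Select player", frame, fromCenter=False)
tracker = cv2.TrackerCSRT_create()  # cv2.legacy.TrackerCSRT_create() in some builds
tracker.init(frame, box)

while True:
    ok, frame = cap.read()
    if not ok:
        break
    found, box = tracker.update(frame)
    if found:
        x, y, w, h = map(int, box)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.imshow("Tracking", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

cap.release()
cv2.destroyAllWindows()
```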
After extracting the bounding boxes from TrackerCSRT_create(), a cropped 16-frame clip is used to classify the action. The 16-frame clip window is determined by vid_stride (set to 8 in the example video above), which is configured in the cropWindows() function in main.py. Within the cropped window's time span, the predicted action is displayed on top of the bounding box of the tracked player.
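A sketch of that crop-and-classify step; the function name and signature below are assumptions for illustration, while the repo's actual version is cropWindows() in main.py:

```python
import cv2
import numpy as np
import torch

# Label order matches the class mapping listed above.
CLASSES = ["block", "pass", "run", "dribble", "shoot",
           "ball in hand", "defence", "pick", "no action", "walk"]

def classify_window(model, frames, box, size=(176, 128)):
    """Crop the tracked player out of 16 buffered frames and classify the clip."""
    x, y, w, h = [int(v) for v in box]
    # cv2.resize takes (width, height); resize each crop to the model input size.
    crops = [cv2.resize(f[y:y + h, x:x + w], size) for f in frames]
    clip = torch.from_numpy(np.stack(crops)).float() / 255.0
    clip = clip.permute(3, 0, 1, 2).unsqueeze(0)  # (1, C, T, H, W)
    with torch.no_grad():
        pred = model(clip).argmax(dim=1).item()
    return CLASSES[pred]
```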
- Keep augmented examples out of the validation set and use them only for training.
- Utilize better player-tracking methods.
- Restrict the bounding-box size to 176x128 (or a similar aspect ratio) so that resizing the image is not needed.
- Fully automate player tracking, potentially using YOLO or another object-detection model.
- Experiment with hyperparameters such as learning rate, batch size, number of frozen layers, etc.
- Try other 3D-CNN architectures or sequential models such as Conv-LSTMs.
- Improve the model to over 90% accuracy and F1-score.
Note:
- The model does not perform well on off-ball actions. Oftentimes, the defender is classified as dribbling when they are not. This might be due to the similarity between the dribbling stance and the defensive stance: in both, players generally bend their torsos forward to lower their centre of gravity.
Major thanks to Simone Francia for the basketball action dataset and paper on action classification with 3D-CNNs.