Skip to content

CPS-research-group/dtsemnet

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

26 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DTSemNet portfolio_view

Welcome to the official implementation of the DTSemNet architecture, as proposed in the paper: “Vanilla Gradient Descent for Oblique Decision Trees,” ECAI-2024. [Paper] [Website]

High Level

Overview

DTSemNet is a novel invertible encoding of Oblique Decision Trees (ODT) as a Neural Network (NN) facilitating the training of ODT using vanilla gradient descent. This repository contains all the necessary files and scripts to replicate the experiments and results presented in the paper. We have reused some parts of code from CRO-DT, DGT and ICCT.

Repository Structure

  • src/dtsemnet.py: Core implementation of the DTSemNet model.
  • src/net_train.py: Script for training on small classification datasets.
  • src/net_train2.py: Script for training on large classification datasets, with GPU support.
  • src/reg_train_linear.py: Script for training DTSemNet on regression tasks.
  • /results: Directory where training logs are saved. Previous logs used in our paper are available in /results/ecai-reported.
  • /results/combined_mean.py: In the case of multiple splits, the reported mean ± standard deviation are calculated by combining the results from each split.

Datasets

Classification (Small)

Included in the `datasets' directory: [breast_cancer,car,banknote,balance,acute−1,acute−2,transfusion,climate,sonar,optical,drybean,avila,wine−red,wine−white]

Due to size constraints, the following datasets are not included. It needs to be downloaded to `datasets' directory.

Classification (Large)

Dataset Name Download Link
mnist PyTorch Auto-Download
letter Download
connect Download
segment Download
satimages Download
pendigits Download
protein Download
sensit Download

Regression

Dataset Name Download Link
abalone Download
ailerons Download
cpu_active Download
pdb_bind Download
year Download
ctslice Download
ms Download

Installation

Please install conda environment using environment.yml. In case, there is an error in PyTorch installation, please install the same version manually.

conda env create -f environment.yml
conda activate dtsemnet

Also, install the packages in `src' directory:

python -m pip install -e .

Execute Code for Classification Dataset with Small DTs

python -m src.net_train --model dtsemnet --dataset all --depth 4 -s 1 --output_prefix dtsement --verbose True
- replace `dtsemnet` with `dgt` for DGT evaluation
- For time computation restrict number of cores to 8, use `taskset -cpa 0-7`
- s: Number of simulations for averaging (100 for small DTs and 10 for large DTs)
- depth: Height of the DT
- dataset: Use appropriate dataset name. `all' to evaluate all datasets
- output_prefix: helps in managing the name of the log file

Execute Code for Classification Dataset with Large DTs

python -m src.net_train2 --model dtsemnet --dataset mnist -s 1 --output_prefix dtsement --verbose True -g
- Height is included in the configuration file, there is no need to specify
- model: `dtsemnet' or `dgt'
- dataset: use appropriate name of the dataset
- add `-g` for GPU operation

DTSemNet-Regression

Use the following terminal command for regression datasets ["abalone", "ailerons", "cpu_active", "pdb_bind", "year", "ctslice", "ms"].

python -m src.reg_train_linear --model dtregnet --dataset ailerons -s 1 --output_prefix ailerons --verbose True -g
- Use `--model dtregnet` for for DTSemNet-regression

RL Experiments

  • Please switch to branch rl_d (discrete action: PPO) or rl_c (continuous action: SAC) for the RL experiments.
  • The branches are standalone, so you have to follow respective README files and do the environment installation for RL experiment.

If you find DTSemNet useful in your research, please cite our work:

Subrat Panda, Blaise Genest, Arvind Easwaran, and Ponnuthurai Suganthan, "Vanilla Gradient Descent for Oblique Decision Trees," European Conference on Artificial Intelligence (ECAI), 2024. Link to paper