Gradient-based Entire Tree Optimization For Oblique Regression Tree

maoqiangqiang/GET


Initial version, all rights reserved. Updates coming soon.

Requirements

  • PyTorch 2.0.1
  • Python 3.9.6
    • scikit-learn 1.0.2
    • numpy 1.21.6
    • pandas 1.3.5
    • h5py 3.8.0

Script Description

  • The src folder contains the scripts for Gradient-based Entire Tree Optimization (GET) of oblique regression trees with constant predictions.

    • The ancestorTF_File subfolder contains the deterministic tree path routing (in short, the tree path used to assign samples) in HDF5 (.h5) format. These files can be generated with the treePathCalculation function in treeFunc.py.
    • treeFunc.py contains utility functions.
    • dataset.py loads the datasets.
    • warmStart.py generates the warm-start initialization based on the CART method.
    • GET_IterAlp.py contains the main functions of the GET method.
    • subtreePolish.py implements the subtree polish strategy.
  • The test folder contains the scripts for running these algorithms.

    • test_GET.py tests the GET method without the subtree polish strategy.
    • test_GETwithSubtreePolish.py tests the GET method with the subtree polish strategy.
  • The data folder contains the datasets. These publicly available datasets are from the UCI Machine Learning Repository and OpenML. Each dataset is shuffled, split into training, validation, and testing sets, and saved in *.csv format.
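To illustrate what a deterministic tree path routing can look like, here is a minimal sketch. It is not the repository's treePathCalculation function: the name leaf_ancestor_paths, the heap-style node numbering (root = 1, children of node i are 2i and 2i + 1), and the returned dictionary layout are all assumptions made for this example. For each leaf of a complete binary tree, it lists the ancestors from the root down, with a flag saying whether the path goes left at that ancestor.

```python
def leaf_ancestor_paths(depth):
    """Map each leaf of a complete binary tree to its root-to-leaf path.

    Assumed layout (not taken from the repository): heap indexing with
    root = 1, so the leaves of a depth-`depth` tree are 2**depth .. 2**(depth+1) - 1.
    Each path is a list of (ancestor_index, goes_left) pairs, root first.
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")

    first_leaf = 2 ** depth
    paths = {}
    for leaf in range(first_leaf, 2 * first_leaf):
        path = []
        node = leaf
        while node > 1:
            parent = node // 2
            # Under heap indexing, an even index is the left child of its parent.
            path.append((parent, node % 2 == 0))
            node = parent
        path.reverse()  # collected leaf-to-root, so flip to root-to-leaf order
        paths[leaf] = path
    return paths


if __name__ == "__main__":
    for leaf, path in leaf_ancestor_paths(2).items():
        print(leaf, path)
```

Because the tree structure is fixed for a given depth, such a table only needs to be computed once and can then be stored and reused, which is consistent with keeping the routing in precomputed files.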
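The shuffle-and-split preparation described for the data folder can be sketched as follows. This is an illustrative example, not the script used to produce the files in this repository: the function name shuffle_split, the 50/25/25 split fractions, and the seed are assumptions, and only the Python standard library is used.

```python
import random


def shuffle_split(rows, train_frac=0.5, valid_frac=0.25, seed=0):
    """Shuffle rows and split them into training, validation, and testing sets.

    The fractions and seed are placeholders chosen for this sketch; the
    remaining rows after the training and validation sets go to the testing set.
    """
    if train_frac <= 0 or valid_frac < 0 or train_frac + valid_frac >= 1:
        raise ValueError("fractions must leave a non-empty share for the testing set")

    shuffled = list(rows)                  # copy so the caller's data is untouched
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible shuffle

    n_train = int(len(shuffled) * train_frac)
    n_valid = int(len(shuffled) * valid_frac)

    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]
    return train, valid, test


if __name__ == "__main__":
    train, valid, test = shuffle_split(range(100))
    print(len(train), len(valid), len(test))
```

Each of the three lists could then be written to its own *.csv file, for example with csv.writer from the standard library.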

Examples

The examples can be run as follows (the paths use Windows separators; on Linux or macOS, replace \ with /):

# test the GET method 
python .\test\test_GET.py 3 3 1 1 2 3000 "cuda" 10
# test the GET method with subtree polish strategy
python .\test\test_GETwithSubtreePolish.py 3 3 1 1 2 3000 "cuda" 10

License

This repository is published under the terms of the GNU General Public License v3.0.
