This repository contains our PyTorch implementation of ARM-Net: Adaptive Relation Modeling Network for Structured Data. We also provide the implementation of relevant baseline models for structured (tabular) data learning.
- Frappe - App Recommendation
- MovieLens - Movie Recommendation
- Avazu - Click-Through Rate Prediction
- Criteo - Display Advertising Challenge
- Diabetes130 - Diabetes Readmission Prediction
- Main results are summarized below.
- ARM-Net achieves the overall best performance.
- More results and technical details can be found in the paper.
- Note that these results are reported with a fixed embedding size of 10 for a fair comparison; a higher AUC can be obtained by increasing the embedding size. For example, with an embedding size of 100, ARM-Net (single head, not ensembled with a DNN) obtains 0.9817 AUC on Frappe with only 10 exponential neurons:
CUDA_VISIBLE_DEVICES=0 python train.py --model armnet_1h --nemb 100 --h 10 --alpha 1.7 --lr 0.001 --exp_name frappe_armnet_1h_nemb_100
The AUC and model size of this single-head ARM-Net at different embedding sizes are listed below.
Embedding Size | 10 | 20 | 30 | 40 | 50 | 60 | 70 | 80 | 90 | 100 | 110 | 120 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
AUC | 0.9777 | 0.9779 | 0.9801 | 0.9803 | 0.9798 | 0.9807 | 0.9808 | 0.9810 | 0.9810 | 0.9817 | 0.9811 | 0.9805 |
Model Size | 177K | 262K | 348K | 434K | 520K | 606K | 692K | 779K | 866K | 953K | 1.04M | 1.13M |
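To reproduce a sweep like the one in the table, the training command above can be generated once per embedding size. The following is a minimal sketch, not part of the repository: it reuses only the flags shown in the command above, and it assumes that the other hyperparameters (`--h`, `--alpha`, `--lr`) stay fixed across the sweep and that experiment names follow the `frappe_armnet_1h_nemb_<size>` pattern of the single example. The helper name `sweep_commands` is hypothetical.

```python
import shlex


def sweep_commands(embedding_sizes, gpu=0):
    """Build one train.py command line per embedding size (hypothetical helper)."""
    commands = []
    for nemb in embedding_sizes:
        args = [
            "python", "train.py",
            "--model", "armnet_1h",
            "--nemb", str(nemb),
            # Assumption: these hyperparameters are held fixed across the sweep.
            "--h", "10",
            "--alpha", "1.7",
            "--lr", "0.001",
            # Assumption: naming pattern extrapolated from the single example above.
            "--exp_name", f"frappe_armnet_1h_nemb_{nemb}",
        ]
        commands.append(f"CUDA_VISIBLE_DEVICES={gpu} " + shlex.join(args))
    return commands


# Embedding sizes 10, 20, ..., 120, matching the columns of the table.
commands = sweep_commands(range(10, 130, 10))
for command in commands:
    print(command)
```

The script only prints the commands, so they can be inspected first and then run one at a time or piped to a shell.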
Model | Code | Reference |
---|---|---|
Logistic Regression | LR, lr.py | - |
Factorization Machine | FM, fm.py | [ICDM-10] FM |
Attentional Factorization Machine | AFM, afm.py | [IJCAI-17] AFM |
Higher-Order Factorization Machines | HOFM, hofm.py | [NeurIPS-16] HOFM |
Deep Neural Network | DNN, dnn.py | - |
Graph Convolutional Networks | GCN, gcn.py | [ICLR-17] GCN |
Graph Attention Networks | GAT, gat.py | [ICLR-18] GAT |
Wide&Deep | Wide&Deep, wd.py | [RecSys-16] Wide&Deep |
Product Neural Network | IPNN/KPNN, pnn.py | [ICDM-16] PNN |
Neural Factorization Machine | NFM, nfm.py | [SIGIR-17] NFM |
DeepFM | DeepFM, dfm.py | [IJCAI-17] DeepFM |
Deep & Cross Network | DCN/DCN+, dcn.py | [KDD-17] DCN |
Gated Linear Unit | SA_GLU, sa_glu.py | [ICML-17] GLU |
xDeepFM | CIN/xDeepFM, xdfm.py | [KDD-18] xDeepFM |
Context-Aware Self-Attention Network | GC_ARM, gc_arm.py | [AAAI-19] GC-ARM |
AFN | AFN/AFN+, afn.py | [AAAI-20] AFN |
ARM-Net | ARM-Net/ARM-Net+, armnet.py | [SIGMOD-21] ARM-Net |
ARM-Net-1h (one-head, recommended) | ARM-Net/ARM-Net+, armnet_1h.py | [SIGMOD-21] ARM-Net-1h |
ARM-Net for Small to Medium Tabular Datasets (121 UCI datasets)
We also implement and evaluate prior approaches, recent models, and our ARM-Net on 121 UCI datasets. These datasets are real-world multi-class classification tasks whose features are all converted into numerical features, following common practice. Models and utilities for evaluating models on the 121 UCI datasets are included in this branch.
- Main results are summarized below.
- ARM-Net achieves the overall best performance.
- More results and technical details can be found here.
Model | Rank(Best_Cnt) | abalone | acute-inflammation | acute-nephritis | adult | annealing | arrhythmia | audiology-std | balance-scale | balloons | bank | blood | breast-cancer | breast-cancer-wisc | breast-cancer-wisc-diag | breast-cancer-wisc-prog | breast-tissue | car | cardiotocography-10clases | cardiotocography-3clases | chess-krvk | chess-krvkp | congressional-voting | conn-bench-sonar-mines-rocks | conn-bench-vowel-deterding | connect-4 | contrac | credit-approval | cylinder-bands | dermatology | echocardiogram | ecoli | energy-y1 | energy-y2 | fertility | flags | glass |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
n_samples | - | 4177 | 120 | 120 | 48842 | 898 | 452 | 196 | 625 | 16 | 4521 | 748 | 286 | 699 | 569 | 198 | 106 | 1728 | 2126 | 2126 | 28056 | 3196 | 435 | 208 | 990 | 67557 | 1473 | 690 | 512 | 366 | 131 | 336 | 768 | 768 | 100 | 194 | 214 |
n_features | - | 9 | 7 | 7 | 15 | 32 | 263 | 60 | 5 | 5 | 17 | 5 | 10 | 10 | 31 | 34 | 10 | 7 | 22 | 22 | 7 | 37 | 17 | 61 | 12 | 43 | 10 | 16 | 36 | 35 | 11 | 8 | 9 | 9 | 10 | 29 | 10 |
LR | 6-th (0/36) | 0.6293/0.0080 | 0.9833/0.0211 | 0.9533/0.0552 | 0.8423/0.0008 | 0.1280/0.0172 | 0.5442/0.0184 | 0.7040/0.0480 | 0.8718/0.0310 | 0.7250/0.0935 | 0.8904/0.0023 | 0.7610/0.0043 | 0.6923/0.0171 | 0.9490/0.0090 | 0.9641/0.0103 | 0.6626/0.0656 | 0.5283/0.1371 | 0.8032/0.0052 | 0.7595/0.0118 | 0.8798/0.0120 | 0.2743/0.0009 | 0.9438/0.0035 | 0.5705/0.0328 | 0.7385/0.0186 | 0.7121/0.0088 | 0.7547/0.0004 | 0.4829/0.0383 | 0.8557/0.0119 | 0.6305/0.0647 | 0.9399/0.0313 | 0.7600/0.0605 | 0.7988/0.0510 | 0.8391/0.0123 | 0.8448/0.0297 | 0.5800/0.1066 | 0.4206/0.0365 | 0.5290/0.0281 |
FM | 5-th (3/36) | 0.6329/0.0067 | 0.9767/0.0389 | 0.8700/0.0945 | 0.8443/0.0005 | 0.1960/0.1493 | 0.5283/0.0211 | 0.4880/0.0588 | 0.9224/0.0087 | 0.5750/0.1275 | 0.8882/0.0028 | 0.7647/0.0000 | 0.6909/0.0604 | 0.9599/0.0048 | 0.9697/0.0048 | 0.6626/0.0849 | 0.5094/0.0818 | 0.8882/0.0097 | 0.7616/0.0161 | 0.8903/0.0172 | 0.3127/0.0035 | 0.9796/0.0038 | 0.5705/0.0306 | 0.9502/0.0087 | 0.9502/0.0087 | 0.8264/0.0005 | 0.4524/0.0140 | 0.8638/0.0093 | 0.7016/0.0250 | 0.9202/0.0350 | 0.7846/0.0600 | 0.7595/0.0680 | 0.8823/0.0086 | 0.8604/0.0283 | 0.7720/0.0688 | 0.3423/0.0200 | 0.5907/0.0361 |
DNN | 4-th (6/36) | 0.6560/0.0051 | 0.9900/0.0200 | 0.9500/0.0316 | 0.8519/0.0015 | 0.4420/0.2346 | 0.6442/0.0114 | 0.6880/0.0466 | 0.8987/0.0048 | 0.5500/0.2318 | 0.8900/0.0035 | 0.7583/0.0050 | 0.7147/0.0082 | 0.9633/0.0033 | 0.9648/0.0107 | 0.7091/0.0475 | 0.5849/0.0396 | 0.9442/0.0034 | 0.7797/0.0121 | 0.9178/0.0031 | 0.6842/0.0147 | 0.9775/0.0032 | 0.5834/0.0147 | 0.7481/0.0377 | 0.9745/0.0063 | 0.8501/0.0023 | 0.5084/0.0158 | 0.8417/0.0187 | 0.7359/0.0386 | 0.9639/0.0101 | 0.7846/0.0337 | 0.8524/0.0166 | 0.8688/0.0107 | 0.8865/0.0094 | 0.8320/0.0722 | 0.4969/0.0272 | 0.5850/0.0316 |
SNN | 3rd (6/36) | 0.6457/0.0043 | 0.9567/0.0389 | 0.9000/0.0548 | 0.8489/0.0009 | 0.2280/0.2671 | 0.5841/0.0410 | 0.7200/0.0253 | 0.9058/0.0240 | 0.7250/0.1225 | 0.8885/0.0019 | 0.8885/0.0019 | 0.7105/0.0105 | 0.9656/0.0041 | 0.9690/0.0112 | 0.6727/0.0903 | 0.6000/0.0690 | 0.9632/0.0066 | 0.8008/0.0125 | 0.9029/0.0086 | 0.6796/0.0141 | 0.9726/0.0061 | 0.5779/0.0209 | 0.7135/0.0300 | 0.9693/0.0100 | 0.8491/0.0013 | 0.5106/0.0098 | 0.8719/0.0121 | 0.7000/0.0163 | 0.9388/0.0269 | 0.7877/0.0439 | 0.8179/0.035 | 0.8714/0.0142 | 0.8854/0.0154 | 0.7600/0.1180 | 0.4804/0.0231 | 0.5738/0.0602 |
Perceiver-IO | 2nd (6/36) | 0.6381/0.0143 | 1.0000/0.0000 | 0.9367/0.0531 | 0.8521/0.0011 | 0.7600/0.0000 | 0.5602/0.0053 | 0.0080/0.0160 | 0.8821/0.0166 | 0.7750/0.0500 | 0.8850/0.0000 | 0.7620/0.0000 | 0.7063/0.0088 | 0.9352/0.0313 | 0.9556/0.0142 | 0.7596/0.0118 | 0.3208/0.0597 | 0.9326/0.0120 | 0.5325/0.0861 | 0.7817/0.0035 | 0.6834/0.0151 | 0.8106/0.0895 | 0.6129/0.0000 | 0.5635/0.0817 | 0.6732/0.0521 | 0.7538/0.0000 | 0.4457/0.0122 | 0.7745/0.1075 | 0.6133/0.0078 | 0.4295/0.0754 | 0.7662/0.0834 | 0.6440/0.0239 | 0.8417/0.0295 | 0.8807/0.0325 | 0.8560/0.0480 | 0.3010/0.0247 | 0.4093/0.0415 |
ARM-Net | 1st (15/36) | 0.6603/0.0034 | 0.9767/0.0389 | 0.9600/0.0800 | 0.8562/0.0011 | 0.1500/0.1131 | 0.6487/0.0214 | 0.5520/0.0299 | 0.9135/0.0070 | 0.7500/0.0791 | 0.8922/0.0012 | 0.8922/0.0012 | 0.7203/0.0193 | 0.9530/0.0118 | 0.9521/0.0186 | 0.6828/0.0485 | 0.5170/0.0638 | 0.9463/0.0086 | 0.7868/0.0054 | 0.9146/0.0051 | 0.6982/0.0109 | 0.9826/0.0040 | 0.5760/0.0193 | 0.7712/0.0335 | 0.9675/0.0115 | 0.8672/0.0028 | 0.5228/0.0119 | 0.8620/0.0187 | 0.7133/0.0305 | 0.9497/0.0181 | 0.8338/0.0406 | 0.8214/0.0279 | 0.8844/0.0048 | 0.8750/0.0304 | 0.8240/0.0528 | 0.4330/0.0526 | 0.6150/0.0232 |
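As an illustration of how a `Best_Cnt` figure could be derived from such a table, the sketch below counts, for each model, the datasets on which it has the highest score. It rests on two assumptions that the table itself does not state: each cell is read as `score/spread` (for example mean accuracy and standard deviation), and `Best_Cnt` is taken to be the number of datasets on which a model has the top first value, with ties credited to every tied model. Only the first four dataset columns are included to keep the example short, and the function name `best_counts` is hypothetical.

```python
DATASETS = ["abalone", "acute-inflammation", "acute-nephritis", "adult"]

# First four result columns copied from the table above.
ROWS = {
    "LR": ["0.6293/0.0080", "0.9833/0.0211", "0.9533/0.0552", "0.8423/0.0008"],
    "FM": ["0.6329/0.0067", "0.9767/0.0389", "0.8700/0.0945", "0.8443/0.0005"],
    "DNN": ["0.6560/0.0051", "0.9900/0.0200", "0.9500/0.0316", "0.8519/0.0015"],
    "SNN": ["0.6457/0.0043", "0.9567/0.0389", "0.9000/0.0548", "0.8489/0.0009"],
    "Perceiver-IO": ["0.6381/0.0143", "1.0000/0.0000", "0.9367/0.0531", "0.8521/0.0011"],
    "ARM-Net": ["0.6603/0.0034", "0.9767/0.0389", "0.9600/0.0800", "0.8562/0.0011"],
}


def best_counts(rows, n_datasets):
    """Count, per model, the datasets on which it reaches the top score."""
    counts = {model: 0 for model in rows}
    for col in range(n_datasets):
        # Assumption: the number before "/" is the score being compared.
        scores = {model: float(cells[col].split("/")[0]) for model, cells in rows.items()}
        top = max(scores.values())
        for model, score in scores.items():
            if score == top:  # ties are credited to every tied model
                counts[model] += 1
    return counts


counts = best_counts(ROWS, len(DATASETS))
for model, count in counts.items():
    print(f"{model}: best on {count}/{len(DATASETS)} datasets")
```

On this four-column excerpt the count is not expected to match the full-table figures in the `Rank(Best_Cnt)` column, which cover all 36 listed datasets.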
ARM-Net for Log-Based Anomaly Detection
ARM-Net can also be readily adapted to log-based anomaly detection, which aims to discover abnormal system behaviors (a binary classification task) by analyzing the log sequences that a system routinely generates at runtime.
Each log is a message in an unstructured format (raw text) that can be parsed into structured data with a number of key information fields, e.g., date, pid, level, and event ID. Models and utilities for end-to-end log-based anomaly detection can be found in this branch.
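The parsing step described above can be sketched as follows. This is a minimal illustration rather than the parser used in that branch: the log layout (`<date> <time> <pid> <level> <message>`), the sample lines, and the digit-masking rule used to derive an event ID are all assumptions made for the example, and `parse_log_line` is a hypothetical helper.

```python
import re

# Assumed log layout: "<date> <time> <pid> <level> <free-text message>".
LOG_PATTERN = re.compile(
    r"^(?P<date>\d{4}-\d{2}-\d{2}) (?P<time>\d{2}:\d{2}:\d{2}) "
    r"(?P<pid>\d+) (?P<level>[A-Z]+) (?P<message>.*)$"
)


def parse_log_line(line, event_ids):
    """Turn one raw log line into a dict of structured fields.

    `event_ids` maps message templates to integer event IDs and is
    extended in place whenever a new template is seen.
    """
    match = LOG_PATTERN.match(line.strip())
    if match is None:
        return None  # the line does not follow the assumed layout
    fields = match.groupdict()
    # Crude templating (an assumption): mask every token that contains a digit,
    # so messages that differ only in their parameters share one template.
    template = re.sub(r"\S*\d\S*", "<*>", fields.pop("message"))
    fields["event_id"] = event_ids.setdefault(template, len(event_ids))
    fields["pid"] = int(fields["pid"])
    return fields


event_ids = {}
raw_logs = [
    "2021-06-01 12:00:00 1234 INFO Received block blk_1 of size 67108864",
    "2021-06-01 12:00:01 1234 INFO Received block blk_2 of size 1024",
    "2021-06-01 12:00:02 1240 ERROR Connection reset by peer",
]
records = [parse_log_line(line, event_ids) for line in raw_logs]
for record in records:
    print(record)
```

Each resulting record is a row of categorical and numerical fields, which is the kind of structured input a tabular model consumes; a sequence of such rows would then be labeled normal or anomalous.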
If you use our code in your research, please cite:
S. Cai, K. Zheng, G. Chen, H.V. Jagadish, B.C. Ooi, M. Zhang. ARM-Net: Adaptive Relation Modeling Network for Structured Data. ACM International Conference on Management of Data (SIGMOD), 2021
To ask questions or report issues, you can drop us an email.