
Dataloader single worker and default batch_size makes R tabnet 4-15x slower than pytorch tabnet #37

szilard opened this issue Mar 3, 2021 · 16 comments


szilard commented Mar 3, 2021

R code:

library(data.table)
library(ROCR)
library(tabnet)
library(Matrix)


d_train <- fread("https://s3.amazonaws.com/benchm-ml--main/train-0.1m.csv", stringsAsFactors=TRUE)
d_test <- fread("https://s3.amazonaws.com/benchm-ml--main/test.csv")

## align cat. values (factors)
d_train_test <- rbind(d_train, d_test)
n1 <- nrow(d_train)
n2 <- nrow(d_test)
d_train <- d_train_test[1:n1,]
d_test <- d_train_test[(n1+1):(n1+n2),]


system.time({
  md <- tabnet_fit(dep_delayed_15min ~ ., d_train, epochs = 10, verbose = TRUE)
})


phat <- predict(md, d_test, type = "prob")$.pred_Y
rocr_pred <- prediction(phat, d_test$dep_delayed_15min)
performance(rocr_pred, "auc")@y.values[[1]]

Python code:

from pytorch_tabnet.tab_model import TabNetClassifier
import torch

import pandas as pd
import numpy as np
from sklearn import preprocessing
from sklearn import metrics


d_train = pd.read_csv("https://s3.amazonaws.com/benchm-ml--main/train-0.1m.csv")
d_test = pd.read_csv("https://s3.amazonaws.com/benchm-ml--main/test.csv")


d_all = pd.concat([d_train,d_test])

vars_cat = ["Month","DayofMonth","DayOfWeek","UniqueCarrier", "Origin", "Dest"]
vars_num = ["DepTime","Distance"]
for col in vars_cat:
  d_all[col] = preprocessing.LabelEncoder().fit_transform(d_all[col])

X_all = d_all[vars_num+vars_cat]
y_all = np.where(d_all["dep_delayed_15min"]=="Y",1,0)

cat_idxs = [ i for i, col in enumerate(X_all.columns) if col in vars_cat]
cat_dims = [ len(np.unique(X_all.iloc[:,i].values)) for i in cat_idxs]

X_train = X_all[0:d_train.shape[0]].to_numpy()
y_train = y_all[0:d_train.shape[0]]
X_test = X_all[d_train.shape[0]:(d_train.shape[0]+d_test.shape[0])].to_numpy()
y_test = y_all[d_train.shape[0]:(d_train.shape[0]+d_test.shape[0])]


md = TabNetClassifier(cat_idxs=cat_idxs,
                       cat_dims=cat_dims,
                       cat_emb_dim=1
)

%%time
md.fit( X_train=X_train, y_train=y_train,
    max_epochs=10, patience=0
)


y_pred = md.predict_proba(X_test)[:,1]
print(metrics.roc_auc_score(y_test, y_pred))

m5.2xlarge (8 cores):

R:

[Epoch 001] Loss: 0.495622
[Epoch 002] Loss: 0.455483
[Epoch 003] Loss: 0.450127
[Epoch 004] Loss: 0.449376
[Epoch 005] Loss: 0.448024
[Epoch 006] Loss: 0.447154
[Epoch 007] Loss: 0.446089
[Epoch 008] Loss: 0.444280
[Epoch 009] Loss: 0.443956
[Epoch 010] Loss: 0.443126
    user   system  elapsed
2927.067    6.196 1502.377
>
>
> phat <- predict(md, d_test, type = "prob")$.pred_Y
> rocr_pred <- prediction(phat, d_test$dep_delayed_15min)
> performance(rocr_pred, "auc")@y.values[[1]]
[1] 0.70621

Python:

No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 0.48224 |  0:00:08s
epoch 1  | loss: 0.45447 |  0:00:16s
epoch 2  | loss: 0.45087 |  0:00:25s
epoch 3  | loss: 0.44885 |  0:00:33s
epoch 4  | loss: 0.44667 |  0:00:42s
epoch 5  | loss: 0.44576 |  0:00:50s
epoch 6  | loss: 0.44538 |  0:00:58s
epoch 7  | loss: 0.44727 |  0:01:07s
epoch 8  | loss: 0.4467  |  0:01:15s
epoch 9  | loss: 0.44514 |  0:01:24s
CPU times: user 5min 14s, sys: 555 ms, total: 5min 15s
Wall time: 1min 26s

In [23]: y_pred = md.predict_proba(X_test)[:,1]

In [24]: print(metrics.roc_auc_score(y_test, y_pred))
0.7031382841315941

Some of the parameters have different defaults in the R and Python libs, but even so the difference in runtime is too large. More details of my experiments here: szilard/GBM-perf#52

dfalbel (Member) commented Mar 3, 2021

We still haven't done any profiling of the tabnet code to detect bottlenecks, so this is expected in general. It's likely that future improvements in the torch package will make it faster and comparable to the Python implementation.

szilard (Author) commented Mar 3, 2021

I thought most of the time would be spent in C code, and that that code would be the same for R and Python.

However, if the R implementation takes 1500 sec and the Python one 100 sec, then at most 100 sec is spent in C and 1400 sec somewhere else. I don't know where; Rprof is not giving me something that lets me figure it out:

> summaryRprof()$by.self
                                                 self.time self.pct total.time
"<Anonymous>"                                        64.82    21.07     286.22
"[[.nn_Module"                                       21.56     7.01      25.20
"call_c_function"                                    17.98     5.84     172.94
"all_arguments_to_torch_type"                        15.42     5.01      54.46
"extract_method"                                     10.96     3.56      21.44
"argument_to_torch_type"                             10.58     3.44      39.86
"[["                                                 10.52     3.42      60.96
"$<-.R7"                                              9.68     3.15      15.08
"methods$initialize"                                  9.36     3.04      59.58
"o"                                                   8.88     2.89     132.10
"Tensor_slice"                                        7.02     2.28      20.62
"$.R7"                                                6.70     2.18      23.06
"cpp_make_function_name"                              6.30     2.05       6.30
"Tensor$new"                                          5.14     1.67      56.60
"FUN"                                                 4.78     1.55      94.48
"to_return_type"                                      3.98     1.29      25.14
"$.nn_Module"                                         3.88     1.26      36.28
"lapply"                                              3.68     1.20      98.68
"$<-"                                                 3.60     1.17      18.68
"tryCatchOne"                                         3.30     1.07      44.48
"do_call"                                             3.00     0.98      77.28

szilard (Author) commented Mar 3, 2021

... and if that's the case, improvements in torch will decrease the 100 sec, but not the rest.

dfalbel (Member) commented Mar 3, 2021

All the Rprof results are showing torch code. Why do you state that improvements in torch will decrease the 100 sec and not the rest?

szilard (Author) commented Mar 3, 2021

My "statement" was conditional on my belief that both R and Python implementations would call the same C code ("if that's the case"). Maybe that's not the case.

dfalbel (Member) commented Mar 3, 2021

I don't know exactly what you mean by the same C code. It's indeed different from xgboost, where there's a single C++ implementation and both the Python and R packages bind to that lib. In tabnet's case we have re-implemented the algorithm in R using torch, which binds to libtorch, so any improvements in torch are likely to have a great impact on downstream libraries.

szilard (Author) commented Mar 3, 2021

OK, thanks for clarifying.

cregouby (Collaborator) commented Mar 7, 2021

Hello @szilard,
Please note that the R implementation uses smaller default batch_size and virtual_batch_size values than pytorch_tabnet.
So I would align with the pytorch_tabnet defaults like this:

system.time({
  md <- tabnet_fit(dep_delayed_15min ~ ., d_train, epochs = 10, batch_size = 1024, virtual_batch_size = 128, verbose = TRUE)
})

which brings a 40%+ time improvement (one epoch, my laptop, CPU).
Note that the optimal value for a 32 GB machine would be much higher, e.g. batch_size = 500e3, virtual_batch_size = 25e3, which brings an additional 30% time improvement here (but should not be used for the Python-to-R comparison).
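
For reference, a sketch of that larger-batch call (values as above; they are machine-dependent and not meant for the Python-to-R comparison):

# Sketch: large-batch fit for a ~32 GB machine
system.time({
  md <- tabnet_fit(dep_delayed_15min ~ ., d_train, epochs = 10,
                   batch_size = 500e3, virtual_batch_size = 25e3,
                   verbose = TRUE)
})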

szilard (Author) commented Mar 8, 2021

Thanks @cregouby for looking into this and for the remarks. Even with the 40% improvement it would still be a 10x slowdown, which to my current understanding (based on the above) happens somewhere in the torch R package. (There was a bit of a misunderstanding before: when @dfalbel referred to the "torch package" I thought he meant the torch library, which I now understand you call "libtorch".) Anyway, the 10x slowdown happens at the R (or maybe C) level in the R package, outside of libtorch, so maybe something can be done about it.

gsgxnet commented Jan 30, 2022

To prepare for taking part in the R TabNet Online Meetup in Brussels, I set up environments for both the R and Python TabNet.

Sadly the performance differences are, as of today (2022-01-30), quite similar to the original post: about a 9-fold time difference. I do wonder why.


I did several runs on my local machine, CPU and GPU (nothing fancy, a 4+ year old Dell laptop).
I am posting here only the GPU timings, for both Python and R.


Python & GPU

Same code for Python as in the initial post, just split into 3 chunks in a Jupyter notebook and run chunkwise.
The environment for PyTorch is more current: libtorch 1.10.1, CUDA 11.3, installed with
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch and
conda install -c conda-forge pytorch-tabnet

Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 0.47993 |  0:00:04s
epoch 1  | loss: 0.45312 |  0:00:07s
epoch 2  | loss: 0.45231 |  0:00:10s
epoch 3  | loss: 0.44889 |  0:00:13s
epoch 4  | loss: 0.4485  |  0:00:16s
epoch 5  | loss: 0.44799 |  0:00:19s
epoch 6  | loss: 0.44674 |  0:00:22s
epoch 7  | loss: 0.44672 |  0:00:25s
epoch 8  | loss: 0.44625 |  0:00:29s
epoch 9  | loss: 0.44756 |  0:00:32s
0.697642484844643
CPU times: user 35.5 s, sys: 628 ms, total: 36.1 s
Wall time: 36.6 s

R & GPU

The timing was done with the batch sizes suggested by @cregouby, for comparability.

For R I first ran with my then-current torch and tabnet libraries, and later with the newest versions from GitHub. That did not make a real difference.

fully refreshed setup for torch and lantern:

trying URL 'https://download.pytorch.org/libtorch/cu111/libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip'
Content type 'application/zip' length 2145005881 bytes (2045.6 MB)
==================================================
downloaded 2045.6 MB

trying URL 'https://storage.googleapis.com/torch-lantern-builds/refs/heads/master/latest/Linux-gpu-111.zip'
Content type 'application/zip' length 1982279 bytes (1.9 MB)
==================================================
downloaded 1.9 MB

The final clean run:

[Epoch 001] Loss: 0.473077                            
[Epoch 002] Loss: 0.455934                            
[Epoch 003] Loss: 0.450762                            
[Epoch 004] Loss: 0.449865                            
[Epoch 005] Loss: 0.448689                            
[Epoch 006] Loss: 0.449870                            
[Epoch 007] Loss: 0.447446                            
[Epoch 008] Loss: 0.446433                            
[Epoch 009] Loss: 0.446497                            
[Epoch 010] Loss: 0.446185                            
       user      system     elapsed
    332.756       1.586     294.378 
[1] 0.7022121

sessionInfo()

R version 4.1.2 (2021-11-01)
Platform: x86_64-suse-linux-gnu (64-bit)
Running under: openSUSE Tumbleweed

Matrix products: default
BLAS:   /usr/lib64/R/lib/libRblas.so
LAPACK: /usr/lib64/R/lib/libRlapack.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=de_DE.UTF-8        LC_COLLATE=de_DE.UTF-8    
 [5] LC_MONETARY=de_DE.UTF-8    LC_MESSAGES=de_DE.UTF-8   
 [7] LC_PAPER=de_DE.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=de_DE.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods  
[7] base     

other attached packages:
[1] Matrix_1.4-0      tabnet_0.3.0.9000 ROCR_1.0-11      
[4] data.table_1.14.2

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.8        pillar_1.6.5      compiler_4.1.2   
 [4] prettyunits_1.1.1 progress_1.2.2    tools_4.1.2      
 [7] digest_0.6.29     bit_4.0.4         evaluate_0.14    
[10] lifecycle_1.0.1   tibble_3.1.6      lattice_0.20-45  
[13] pkgconfig_2.0.3   rlang_1.0.0       cli_3.1.1        
[16] curl_4.3.2        yaml_2.2.2        xfun_0.29        
[19] coro_1.0.2        fastmap_1.1.0     withr_2.4.3      
[22] knitr_1.37        hms_1.1.1         vctrs_0.3.8      
[25] bit64_4.0.5       grid_4.1.2        glue_1.6.1       
[28] R6_2.5.1          processx_3.5.2    fansi_1.0.2      
[31] rmarkdown_2.11    callr_3.7.0       magrittr_2.0.2   
[34] ps_1.6.0          ellipsis_0.3.2    htmltools_0.5.2  
[37] hardhat_0.2.0     torch_0.6.1.9000  utf8_1.2.2       
[40] crayon_1.4.2 

cregouby (Collaborator) commented Jan 30, 2022

Hello @gsgxnet,
Good to see that this issue is under active watch!
My very first feeling about the GPU setup is that dataset preparation and feeding are not multi-threaded. Unfortunately I can only provide measurements from my laptop at the time of tabnet_fit():

R implementation

Only one CPU thread is running for preprocessing and torch::dataloader():
(screenshot: CPU usage with a single busy thread)
and at the same time the GPU memory (mem) and compute (sm) utilization fail to saturate (from nvidia-smi dmon -s puct -d 5):
(screenshot: nvidia-smi dmon output showing low mem and sm utilization)

py implementation

At the same time with Python, the CPU profile shows all 4 CPUs in use:
(screenshot: CPU usage with all 4 CPUs busy)
and the GPU memory and compute profile is smoother, i.e. preprocessing and the dataloader are not the bottleneck here:
(screenshot: nvidia-smi dmon output with smooth mem and sm utilization)

I've tried to add the num_workers option to the dataloader here, but with no success, as the dataset is already a torch_tensor at this point...

# Error: callr subprocess failed: external pointer is not valid
# Run `rlang::last_error()` to see where the error occurred.
# In addition: Warning messages:
# 1: Datasets used with parallel dataloader (num_workers > 0) shouldn't have fields containing tensors as they can't be correctly passed to the worker subprocesses.

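To illustrate, a minimal sketch of the pattern the warning describes, using the generic torch dataset()/dataloader() API (not tabnet's actual internals): a dataset that stores a tensor field cannot be shipped to the worker subprocesses, while one that keeps plain R data and only creates tensors in .getitem can.

library(torch)

# Sketch: a dataset holding a tensor field breaks num_workers > 0, because
# the tensor's external pointer cannot be passed to the worker subprocesses
ds_tensor <- dataset(
  initialize = function(x) self$x <- torch_tensor(x),   # tensor field
  .getitem   = function(i) self$x[i, ],
  .length    = function() self$x$size(1)
)(matrix(rnorm(200), ncol = 2))

# Workaround: keep plain R data and convert per item instead
ds_plain <- dataset(
  initialize = function(x) self$x <- x,                 # plain matrix
  .getitem   = function(i) torch_tensor(self$x[i, ]),
  .length    = function() nrow(self$x)
)(matrix(rnorm(200), ncol = 2))

# the parallel dataloader can serialize this dataset to its workers
dl <- dataloader(ds_plain, batch_size = 16, num_workers = 2)
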
Maybe @dfalbel you could help with that?

gsgxnet commented Jan 30, 2022

Hi @cregouby, thanks for looking further into it.

I am trying to figure out what's going on myself, but have not made any real progress. What I have:


PyTorch run on GPU, CPU core loads:
(screenshot: CPU core loads, PyTorch on GPU)
At all times there is some load on 1 or 2 cores. Added together, these loads come to a bit more than 100%, i.e. the equivalent of one core at full load with the other 7 idle.

Nothing unexpected; Python is not multi-threading here.


R torch on GPU, CPU loads (same horizontal grid lines in 100% steps):

(screenshot: CPU core loads, R torch on GPU)

A similar upper limit of about 100% in total; R is in general single-threaded. Not for the whole run, though: the peak at the end is the evaluation and AUC calculation.


Even just visually comparing the two graphs shows that nearly an order of magnitude (at least 5x) more time is spent running tasks on the CPU in the R setup.
I did not check how high the GPU loads were in the runs; my impression is that in both setups CPU processing is the limiting factor. Right, or am I missing something?

@dfalbel I will try to find out where in R most of the CPU time is spent. In the past data.table and the tidyverse did not harmonize too well; could that still be a reason for slower performance, e.g. in the dataloader?

gsgxnet commented Jan 30, 2022

Using profvis on the tabnet_fit call:

profvis({
  md <- tabnet_fit(dep_delayed_15min ~ ., d_train, epochs = 10, batch_size = 1024, virtual_batch_size = 128, verbose = TRUE)
})

I get:

(screenshot: profvis profile of the tabnet_fit() call)

As far as I know, that .Call is calling into an underlying C or other compiled library, which I assume is libtorch here.
It looks like 1/3 of the total time is spent in that, and 2/3 is spent in R code.
I am not an expert, so I must ask: is that interpretation right?

@cregouby cregouby changed the title R tabnet apparently 15x slower than pytorch tabnet Dataloader single worker makes R tabnet 4-15x slower than pytorch tabnet Jan 31, 2022
@cregouby cregouby changed the title Dataloader single worker makes R tabnet 4-15x slower than pytorch tabnet Dataloader single worker and default batch_size makes R tabnet 4-15x slower than pytorch tabnet Jan 31, 2022
gsgxnet commented Jan 31, 2022

I have a second remark on my profiling of the code.

If that .Call with 87.330 sec is really the time spent in libtorch, mostly on the GPU, that would still be 3 times the time pytorch-tabnet needs for the same training. One reason might be that my tests ran on pytorch 1.10.1 and CUDA 11.3, while R torch is still based on libtorch 1.9.1 and CUDA 11.1. But I doubt that can be the only reason for the huge difference.

And we have another huge difference when we run CPU only.
pytorch-tabnet CPU only:

(screenshot: CPU core loads, pytorch-tabnet CPU only)

Same scaling as before. We see that pytorch in CPU mode is able to use 4 cores in parallel.
The total time is even less than with my GPU.

No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 0.4809  |  0:00:02s
epoch 1  | loss: 0.45383 |  0:00:05s
epoch 2  | loss: 0.45076 |  0:00:08s
epoch 3  | loss: 0.44815 |  0:00:11s
epoch 4  | loss: 0.44709 |  0:00:14s
epoch 5  | loss: 0.44611 |  0:00:17s
epoch 6  | loss: 0.44456 |  0:00:20s
epoch 7  | loss: 0.44452 |  0:00:23s
epoch 8  | loss: 0.44404 |  0:00:26s
epoch 9  | loss: 0.44329 |  0:00:29s
CPU times: user 2min, sys: 548 ms, total: 2min 1s
Wall time: 30.6 s

The 2 min of CPU time = 120 s = 4 cores × 30 s of wall time!

R CPU:
(screenshot: CPU core loads, R torch CPU only)

[Epoch 001] Loss: 0.481743                                                  
[Epoch 002] Loss: 0.455055                                                  
[Epoch 003] Loss: 0.450363                                                  
[Epoch 004] Loss: 0.449259                                                  
[Epoch 005] Loss: 0.446910                                                  
[Epoch 006] Loss: 0.446304                                                  
[Epoch 007] Loss: 0.445140                                                  
[Epoch 008] Loss: 0.444208                                                  
[Epoch 009] Loss: 0.442979                                                  
[Epoch 010] Loss: 0.442043                                                  
       user      system     elapsed
    841.376       1.972     297.636 

So a similar, roughly 7-fold (120 s vs. 840 s) CPU time difference.
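
One more thing worth checking (a sketch; I am assuming the plain torch R API here, nothing tabnet-specific): whether libtorch's intra-op thread pool is configured the same in both setups, given that pytorch saturates 4 cores on CPU while the R session stays near one.

library(torch)

# Sketch: inspect and set libtorch's intra-op thread pool from R
torch_get_num_threads()   # threads libtorch may use for CPU ops
torch_set_num_threads(4)  # align with the 4 cores pytorch uses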

cregouby (Collaborator) commented Jan 31, 2022

Thanks a lot @gsgxnet for this investigation, very helpful.
May I propose to split this thread into as many topics as needed?

gsgxnet commented Jan 31, 2022

> Thanks a lot @gsgxnet for this investigation, very helpful. May I propose to split this thread into as many topics as needed?
>
> * [Dataloader cannot use num_workers>0L #78](https://github.com/mlverse/tabnet/issues/78)

Yes, please do.
