updated readme. Fix to eval
herobd committed Jun 19, 2021
1 parent 14195c5 commit 66aa9dc
Showing 3 changed files with 39 additions and 26 deletions.
47 changes: 30 additions & 17 deletions README.md
@@ -3,7 +3,6 @@

This is the code for our ICDAR 2021 paper "Visual FUDGE: Form Understanding via Dynamic Graph Editing" (http://arxiv.org/abs/2105.08194)

This is still getting cleaned up and tweaked. This README is not up to date. I'll probably be done in about a month.

This code is licensed under GNU GPL v3. If you would like it distributed to you under a different license, please contact me (briandavis@byu.net).

@@ -14,30 +13,30 @@ This code is licensed under GNU GPL v3. If you would like it distributed to you
* scikit-image
* pytorch-geometric https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html

## Model weights
## Pre-trained Model weights

Eventually I'll get the weights added as a release.
See "Releases"


## Reproducibility instructions


### Datasets
### Getting the datasets
NAF see https://github.com/herobd/NAF_dataset

FUNSD see https://guillaumejaume.github.io/FUNSD/

The configs expect the datasets to be at `../data/NAF_dataset/` and `../data/FUNSD/`
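
Before training, it can help to confirm that both directories are where the configs look for them. The following is a minimal sketch, not part of the repository: `missing_dataset_dirs` and `EXPECTED_DATASET_DIRS` are hypothetical names, and the only facts taken from the README are the two relative paths.

```python
from pathlib import Path

# The two relative paths quoted in the README; everything else is illustrative.
EXPECTED_DATASET_DIRS = ("../data/NAF_dataset", "../data/FUNSD")


def missing_dataset_dirs(repo_root=".", expected=EXPECTED_DATASET_DIRS):
    """Return the expected dataset directories that do not exist.

    Paths are resolved relative to repo_root, which is assumed to be the
    directory from which train.py is run.
    """
    root = Path(repo_root)
    return [rel for rel in expected if not (root / rel).is_dir()]
```

Calling `missing_dataset_dirs()` from the repository root returns an empty list when both datasets are in place, and otherwise the paths that still need to be created.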

### Pretraining the detector network
### Pretraining the detector networks
FUNSD: `python train.py -c configs/cf_FUNSDLines_detect_augR_staggerLighter.json`

NAF: `python train.py -c configs/cf_NAF_detect_augR_staggerLighter.json`

### Training the full network
### Training the full networks
FUNSD: `python train.py -c cf_FUNSDLines_pair_graph663rv_new.json`

Word-FUDGE: `python train.py -c cf_FUNSDWords_pair_graph663rv_new.json`
Word-FUDGE: `python train.py -c cf_FUNSDWords_pair_graph663rv_new.json` [ACTUALLY, THERE WAS AN ERROR IN MY ORIGINAL EVAL. NEW RESULTS FORTHCOMING (for Word-FUDGE)]

NAF: `python train.py -c cf_NAF_pair_graph663rv_new.json`

@@ -108,12 +107,16 @@ Evaluating pairing:
* `-a draw_verbosity=0-3`: Different ways of displaying the results.

## File Structure
This code is based on victoresque's pytorch template.

```
├── train.py - Training script
├── eval.py - Evaluation and display script
├── configs/ - where the config files are
├── base/ - abstract base classes
│ ├── base_data_loader.py - abstract base class for data loaders
│ ├── base_model.py - abstract base class for models
@@ -123,11 +126,12 @@ Evaluating pairing:
│ └── data_loaders.py - This provides access to all the dataset objects
├── datasets/ - default datasets folder
│ ├── box_detect.py - base class for detection
│ ├── box_detect.py - base class for detection datasets
│ ├── forms_box_detect.py - detection for NAF dataset
│ ├── forms_feature_pair.py - dataset for training non-visual classifier
│ ├── graph_pair.py - base class for pairing
│ ├── forms_graph_pair.py - pairing for NAD dataset
│ ├── funsd_box_detect.py - detection for FUNSD dataset
│ ├── graph_pair.py - base class for pairing datasets
│ ├── forms_graph_pair.py - pairing for NAF dataset
│ ├── funsd_graph_pair.py - pairing for FUNSD dataset
│ └── test*.py - scripts to test the datasets and display the images for visual inspection
├── logger/ - for training process logging
@@ -150,13 +154,22 @@ Evaluating pairing:
├── trainer/ - trainers
│ ├── box_detect_trainer.py - detector training code
│ ├── feature_pair_trainer.py - non-visual pairing training code
│ ├── graph_pair_trainer.py - pairing training code
│ └── trainer.py
│ └── graph_pair_trainer.py - pairing training code
├── evaluators/ - used to evaluate the models
│ ├── draw_graph.py - draws the predictions onto the image
│ ├── funsdboxdetect_eval.py - for detectors
│ └── funsdgraphpair_eval.py - for pairing networks
└── utils/
├── util.py
└── ...
├── augmentation.py - coloring and contrast augmentation
├── crop_transform.py - handles random cropping, especially tracking which text is cropped
├── forms_annotations.py - functions for processing NAF dataset
├── funsd_annotations.py - functions for processing FUNSD dataset
├── group_pairing.py - helper functions dealing with groupings
├── img_f.py - I originally used OpenCV, but was running into issues with the Anaconda installation. This wraps SciKit Image with OpenCV function signatures.
└── yolo_tools.py - Non-maximal suppression and pred-to-GT aligning functions
```

### Config file format
@@ -312,11 +325,11 @@ The config file is saved in the same folder. (as a reference only, the config is
```python
{
'arch': arch,
'epoch': epoch,
'iteration': iteration,
'logger': self.train_logger,
'state_dict': self.model.state_dict(),
'swa_state_dict': self.swa_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'monitor_best': self.monitor_best,
'config': self.config
}
```
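
When resuming or evaluating, a caller has to decide which set of weights to take from a dict shaped like the one above. The following is a minimal sketch, not the repository's loading code: `select_weights` is a hypothetical helper, it works on an already loaded dict, and it assumes only that `state_dict` is present and that `swa_state_dict` may or may not be.

```python
def select_weights(checkpoint, prefer_swa=True):
    """Pick the weights to load from an already loaded checkpoint dict.

    Falls back to 'state_dict' when 'swa_state_dict' is missing or None,
    or when the caller asks for the non-averaged weights.
    """
    swa_weights = checkpoint.get("swa_state_dict")
    if prefer_swa and swa_weights is not None:
        return swa_weights
    return checkpoint["state_dict"]
```

The returned mapping would then typically be passed to the model's `load_state_dict`.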
8 changes: 0 additions & 8 deletions eval.py
@@ -14,8 +14,6 @@
import pickle
#import requests
import warnings
from utils.saliency import SimpleFullGradMod
from utils.debug_graph import GraphChecker
import numpy as np

def update_status(name,message):
@@ -208,12 +206,6 @@ def main(resume,saveDir,numberOfImages,index,gpu=None, shuffle=False, setBatch=N
#saveFunc = eval(trainer_class+'_printer')
saveFunc = eval(config['data_loader']['data_set_name']+'_eval')

do_saliency_map = config['saliency'] if 'saliency' in config else False
do_graph_check_map = config['graph_check'] if 'graph_check' in config else False
if do_saliency_map:
trainer.saliency_model = SimpleFullGradMod(trainer.model)
if do_graph_check_map:
trainer.graph_check_model = GraphChecker(trainer.model)

step=5

10 changes: 9 additions & 1 deletion trainer/graph_pair_trainer.py
@@ -18,7 +18,7 @@

import torch.autograd.profiler as profile

#check for DocStruct hit@1 evaluation
#check that is part of the DocStruct hit@1 evaluation
def maxRelScoreIsHit(child_groups,parent_groups,edgeIndexes,edgePred):
max_score=-1
max_score_is_hit=False
@@ -549,15 +549,23 @@ def simplerAlignEdgePred(self,
#we compute True and False scenarios. If an edge doesn't fit either (e.g. it's impure) then it does not contribute to the loss
#remember: on boolean tensors * is 'and', + is 'or', and ~ is 'not'

#it is a relationship/link if both are aligned to GT bbs AND this isn't a merge AND they both are pure AND they don't belong to the same GT group AND there is a GT link between their aligned GT groups
wasRel = bothTarged*((g_len_L>1)+(g_len_R>1)+~same_ts_0)*bothPure*~same_GTGroups*gtGroupAdjMat

#there is no relationship/link if neither are aligned OR (they are aligned AND this isn't a merge AND they are pure AND they are not the same GT group AND there is no GT link)
wasNoRel = (badTarged+(bothTarged*((g_len_L>1)+(g_len_R>1)+~same_ts_0)*bothPure*~same_GTGroups*~gtGroupAdjMat))

#this is not a merge if they aren't aligned OR (they are aligned AND pure AND (either one has multiple BBs OR they aren't aligned to the same GT BB))
wasNoOverSeg = (badTarged+(bothTarged*bothPure*~((g_len_R==1)*(g_len_L==1)*same_ts_0)))
#This is a merge if they are both aligned, they both have only one BB and have the same GT BB
wasOverSeg = bothTarged*(g_len_R==1)*(g_len_L==1)*same_ts_0

#This is a group if they are aligned AND this isn't a merge AND they are pure AND are the same GT group
wasGroup = bothTarged*((g_len_L>1)+(g_len_R>1)+~same_ts_0)*bothPure*same_GTGroups
#This is not a group if they aren't aligned OR (this isn't a merge AND they are pure but they belong to different GT groups)
wasNoGroup = badTarged+(bothTarged*((g_len_L>1)+(g_len_R>1)+~same_ts_0)*bothPure*~same_GTGroups)

#The error occurs when something is impure (bad merge or group)
wasError = bothTarged*((purity_R<1)+(purity_L<1))
wasNoError = bothTarged*~((purity_R<1)+(purity_L<1))
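
The relationship, merge, and group masks above can be restated for a single edge with plain booleans, which makes it easier to see when an edge contributes to no loss term at all. This is a minimal sketch, not the trainer's code: `edge_targets` and its argument names are hypothetical, `neither_aligned` stands in for `badTarged` (whose definition is not shown in this diff), and the error masks are left out.

```python
def edge_targets(both_aligned, neither_aligned, len_l, len_r,
                 same_gt_bb, both_pure, same_gt_group, gt_link):
    """Scalar restatement of the mask expressions for one candidate edge.

    All flags are plain bools; len_l and len_r are the number of BBs in the
    left and right node. Argument names are illustrative only.
    """
    # Mirrors ((g_len_L>1)+(g_len_R>1)+~same_ts_0): the edge is not a merge candidate.
    not_merge = (len_l > 1) or (len_r > 1) or (not same_gt_bb)
    # Mirrors (g_len_R==1)*(g_len_L==1)*same_ts_0: two single BBs on the same GT BB.
    is_merge = (len_r == 1) and (len_l == 1) and same_gt_bb

    return {
        "rel": both_aligned and not_merge and both_pure
               and (not same_gt_group) and gt_link,
        "no_rel": neither_aligned or (both_aligned and not_merge and both_pure
                                      and (not same_gt_group) and (not gt_link)),
        "over_seg": both_aligned and is_merge,
        "no_over_seg": neither_aligned or (both_aligned and both_pure
                                           and (not is_merge)),
        "group": both_aligned and not_merge and both_pure and same_gt_group,
        "no_group": neither_aligned or (both_aligned and not_merge and both_pure
                                        and (not same_gt_group)),
    }
```

For an aligned but impure edge that is not a merge, every entry comes out `False`, which matches the comment that such an edge fits neither scenario and does not contribute to the loss.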

