-
Notifications
You must be signed in to change notification settings - Fork 27
/
test_fan.py
executable file
·128 lines (99 loc) · 5.46 KB
/
test_fan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import os
import sys
import json
import logging
import argparse
import numpy as np
from pathlib import Path
import helpers.utils
from helpers import fsutil, dataset, results_data
from training.validation import validate_fan
from compression import codec
from workflows import manipulation_classification
# Setup logging: configure the root logger at INFO level and create a
# script-level logger named 'test'.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger('test')
# Lower the verbosity of TensorFlow's native (C++) logging via its environment
# switch.
# NOTE(review): this variable is presumably read when TensorFlow is first
# imported; the project modules imported above may already have pulled
# TensorFlow in, in which case this assignment comes too late - confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def restore_flow(filename, isp, manipulations, jpeg_qf, jpeg_codec, dcn_model, patch_size):
    """Rebuild a manipulation-classification workflow from a saved training session.

    The session log (JSON) supplies the manipulation list, the distribution
    channel settings and the forensics arguments. Selected settings can be
    overridden by the caller before the workflow is constructed.

    Args:
        filename: Path to the session's training log (JSON). The trained FAN
            model is loaded from the ``models`` directory next to this file.
        isp: Imaging pipeline, passed through to the workflow constructor.
        manipulations: List of manipulation names to use instead of the ones
            recorded in the log, or ``None`` to use the logged list with
            ``'native'`` removed.
        jpeg_qf: If not ``None``, replaces ``compression_params['quality']``
            in the logged distribution settings.
        jpeg_codec: If not ``None``, replaces ``compression_params['codec']``.
        dcn_model: If not ``None``, replaces ``compression_params['dirname']``.
        patch_size: Forwarded to the workflow as ``patch_size``.

    Returns:
        A ``(flow, accuracy)`` tuple: the restored workflow with its FAN model
        loaded, and the last validation accuracy recorded in the log
        (``nan`` when the log does not provide one).
    """
    session_dir = os.path.split(filename)[0]

    with open(filename) as f:
        training_log = json.load(f)

    print('\n[{}]'.format(session_dir))

    # Setup manipulations
    if manipulations is None:
        manipulations = helpers.utils.get(training_log, 'manipulations')
        if 'native' in manipulations:
            manipulations.remove('native')
    else:
        print('info: overriding manipulation list with {}'.format(manipulations))

    # The expected accuracy is informational only, so a missing or empty
    # history must not abort the run. Catch Exception (not a bare except) so
    # that KeyboardInterrupt / SystemExit still propagate.
    try:
        accuracy = helpers.utils.get(training_log, 'forensics.performance.accuracy.validation')[-1]
    except Exception:
        accuracy = np.nan

    distribution = helpers.utils.get(training_log, 'distribution')

    # (key in compression_params, override value, label used in the message)
    overrides = (
        ('quality', jpeg_qf, 'JPEG quality'),
        ('codec', jpeg_codec, 'JPEG codec'),
        ('dirname', dcn_model, 'DCN model'),
    )
    for key, value, label in overrides:
        if value is not None:
            print('info: overriding {} with {}'.format(label, value))
            distribution['compression_params'][key] = value

    # Use the same lookup helper (dotted key path) as the other log queries;
    # the previous call went through an undefined `coreutils` name.
    forensics_args = helpers.utils.get(training_log, 'forensics.args')

    flow = manipulation_classification.ManipulationClassification(
        isp, manipulations, distribution, forensics_args, {}, patch_size=patch_size)
    flow.fan.load_model(os.path.join(session_dir, 'models'))

    return flow, accuracy
def main():
    """Validate trained FAN models found under a directory of training sessions.

    Parses command-line options, collects every ``training.json`` below
    ``--dir`` (optionally filtered by the ``--re`` regular expression), loads
    the validation images once, and for each selected session restores the
    workflow, runs validation and prints the validated vs. logged accuracy
    together with the confusion matrix.

    Exits with status 0 (after printing a message) when no training sessions
    are found.
    """
    parser = argparse.ArgumentParser(description='Test manipulation detection (FAN) on RGB images')

    group = parser.add_argument_group('General settings')
    group.add_argument('-p', '--patch', dest='patch', action='store', default=64, type=int,
                       help='patch size')
    group.add_argument('-i', '--images', dest='images', action='store', default=-1, type=int,
                       help='number of validation images (defaults to -1 - use all in the directory)')
    group.add_argument('--patches', dest='patches', action='store', default=1, type=int,
                       help='number of validation patches')
    group.add_argument('--data', dest='data', action='store', default='./data/rgb/native12k',
                       help='directory with test RGB images')
    group.add_argument('--isp', dest='isp', action='store', default='ONet',
                       help='test imaging pipeline')

    group = parser.add_argument_group('Training session selection')
    group.add_argument('--dir', dest='dir', action='store', default='./data/m/7-raw',
                       help='directory with training sessions')
    group.add_argument('--re', dest='re', action='store', default=None,
                       help='regular expression to filter training sessions')

    group = parser.add_argument_group('Override training settings')
    group.add_argument('-q', '--jpeg_qf', dest='jpeg_qf', action='store', default=None, type=int,
                       help='Override JPEG quality level (distribution channel)')
    group.add_argument('-c', '--codec', dest='jpeg_codec', action='store', default=None, type=str,
                       help='Override JPEG codec settings (libjpeg, soft, sin)')
    group.add_argument('--dcn', dest='dcn_model', action='store', default=None,
                       help='Override DCN model directory')
    group.add_argument('-m', '--manip', dest='manipulations', action='store', default=None,
                       help='Included manipulations, e.g., : {}'.format('sharpen,jpeg,resample,gaussian'))

    args = parser.parse_args()

    # Split manipulations
    if args.manipulations is not None:
        args.manipulations = args.manipulations.strip().split(',')

    json_files = sorted(str(f) for f in Path(args.dir).glob('**/training.json'))

    if not json_files:
        # Tell the user why nothing happened instead of exiting silently; the
        # exit status stays 0 so existing wrappers behave as before.
        print('No training sessions (training.json) found in {}'.format(args.dir))
        sys.exit(0)

    # Compile the session filter once, before the expensive data loading, so
    # that an invalid pattern is reported immediately.
    session_filter = re.compile(args.re) if args.re is not None else None

    # Load training / validation data. Only the `load` argument depends on the
    # selected pipeline.
    load_mode = 'y' if args.isp == 'ONet' else 'xy'
    data = dataset.Dataset(args.data, n_images=0, v_images=args.images, load=load_mode,
                           val_rgb_patch_size=2 * args.patch, val_n_patches=args.patches)

    print('Data: {}'.format(data.summary()))
    print('Found {} candidate training sessions ({})'.format(len(json_files), args.dir))

    for filename in json_files:
        if session_filter is not None and not session_filter.search(filename):
            print('Skipping {}...'.format(filename))
            continue

        flow, accuracy = restore_flow(filename, args.isp, args.manipulations, args.jpeg_qf,
                                      args.jpeg_codec, args.dcn_model, args.patch)
        print(flow.summary())

        _, conf = validate_fan(flow, data)

        # Validated accuracy is the mean of the confusion-matrix diagonal.
        print('Accuracy validated/expected: {:.4f} / {:.4f}'.format(np.mean(np.diag(conf)), accuracy))
        print(results_data.confusion_to_text((100*conf).round(0), flow._forensics_classes, filename, 'txt'))


if __name__ == "__main__":
    main()