-
Notifications
You must be signed in to change notification settings - Fork 1
/
Utils.hs
219 lines (160 loc) · 6.3 KB
/
Utils.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
module Utils
( roundToStr
, print_grads
, save_batch_to_file
, save_labels_to_file
, save_list_to_file
, save_vae_to_file
, load_vae_from_file
, get_random_batch
, get_random_batch_with_labels
, get_beta_KL
, get_contents_as_mat_from_fname
, get_contents_from_fname
) where
import Text.Printf
import Data.Array.IO
import Data.List
import Data.List.Split
import Control.Monad
import Data.Random
import System.Random
import Numeric.LinearAlgebra
import Data.Typeable
import Control.DeepSeq
import System.IO
import MatNNGradTypes
-- | Read the sample stored in the file @data_dir ++ show fname_ind ++ ".txt"@.
--
-- The path is built by plain string concatenation, so @data_dir@ must already
-- end with a path separator.
--
-- Reading and parsing are delegated to 'get_contents_from_fname' instead of
-- being duplicated here: the first line of the file is split on commas, the
-- first field is dropped and every remaining field is divided by 256.
get_contents_from_fname_ind :: String -> Int -> IO [Double]
get_contents_from_fname_ind data_dir fname_ind =
  get_contents_from_fname (data_dir ++ show fname_ind ++ ".txt")
-- | Like 'get_contents_from_fname_ind', but also returns the label.
--
-- The first line of @data_dir ++ show fname_ind ++ ".txt"@ is split on commas.
-- The first field is parsed as the 'Int' label; every remaining field is
-- parsed as a number and divided by 256.
--
-- Parsing uses 'read', so a malformed field raises an error when the
-- corresponding value is first evaluated.
get_contents_from_fname_ind_with_labels :: String -> Int -> IO ([Double], Int)
get_contents_from_fname_ind_with_labels data_dir fname_ind = do
  let fname = data_dir ++ show fname_ind ++ ".txt"
  -- withFile closes the handle even if hGetLine throws (e.g. on an empty
  -- file); a bare openFile/hClose pair would leak the handle in that case.
  line <- withFile fname ReadMode hGetLine
  -- Split once and take the label and the values from the same field list.
  case splitOn "," line of
    label_field : value_fields ->
      return (map to_unit value_fields, read label_field)
    [] ->
      -- splitOn is expected to return at least one field, even for "".
      error "get_contents_from_fname_ind_with_labels: splitOn returned no fields"
  where
    to_unit field = read field / 256
-- | Load a random batch of samples from @data_dir@.
--
-- @batch_size@ file indices are drawn independently from the inclusive range
-- @0 .. n_tot_samples - 1@ (so the same sample may appear more than once), and
-- each index is loaded with 'get_contents_from_fname_ind'. All indices are
-- drawn before the first file is read.
get_random_batch :: String -> Int -> Int -> IO [[Double]]
get_random_batch data_dir batch_size n_tot_samples =
  replicateM batch_size draw_index >>= mapM load_sample
  where
    draw_index = randomRIO (0, n_tot_samples - 1)
    load_sample = get_contents_from_fname_ind data_dir
-- | Same as 'get_random_batch', but each sample is loaded together with its
-- label via 'get_contents_from_fname_ind_with_labels'.
--
-- Returns the samples and the labels as two lists in matching order.
get_random_batch_with_labels :: String -> Int -> Int -> IO ([[Double]], [Int])
get_random_batch_with_labels data_dir batch_size n_tot_samples = do
  indices <- replicateM batch_size (randomRIO (0, n_tot_samples - 1))
  labelled <- mapM (get_contents_from_fname_ind_with_labels data_dir) indices
  return (unzip labelled)
-- | Read one sample from the file @fname@.
--
-- Only the first line of the file is read. It is split on commas, the first
-- field is dropped, and every remaining field is parsed as a number and
-- divided by 256.
--
-- The line is fetched with 'hGetLine' rather than 'readFile' so that no lazy
-- I/O is involved. Parsing uses 'read', so a malformed field raises an error
-- when the corresponding value is first evaluated.
get_contents_from_fname :: String -> IO [Double]
get_contents_from_fname fname = do
  -- withFile closes the handle even if hGetLine throws (e.g. on an empty
  -- file); a bare openFile/hClose pair would leak the handle in that case.
  line <- withFile fname ReadMode hGetLine
  -- splitOn yields at least one field, so drop 1 matches the former tail.
  return (map to_unit (drop 1 (splitOn "," line)))
  where
    to_unit field = read field / 256
-- | Same as 'get_contents_from_fname', but wraps the values in a matrix built
-- from a single row.
get_contents_as_mat_from_fname :: String -> IO (Matrix R)
get_contents_as_mat_from_fname fname =
  as_single_row <$> get_contents_from_fname fname
  where
    as_single_row values = fromLists [values]
--------------------- Saving data to file
-- | Render a list of integers as one comma-separated line (no trailing
-- newline). The empty list yields the empty string.
list_to_csv_string :: [Int] -> String
list_to_csv_string = intercalate "," . map show
-- | Render a matrix as text: one line per matrix row, entries separated by
-- commas and formatted with 3 decimal places via 'roundToStr'.
--
-- NOTE(review): this is also the format used by 'write_matrix_to_file', so
-- anything written through it keeps only 3 decimals -- confirm that this
-- precision is intended for saved network weights.
matrix_to_string :: Matrix R -> String
matrix_to_string mat =
  unlines [intercalate "," (map (roundToStr 3) entries) | entries <- toLists mat]
-- | Write the matrix wrapped by a 'Batch' to @fname@ in the text format
-- produced by 'matrix_to_string'.
save_batch_to_file :: Batch -> String -> IO ()
save_batch_to_file (Batch b) fname = writeFile fname (matrix_to_string b)
-- | Write the labels to @fname@ as a single comma-separated line (see
-- 'list_to_csv_string').
save_labels_to_file :: [Int] -> String -> IO ()
save_labels_to_file labels fname = writeFile fname (list_to_csv_string labels)
-- | Write a matrix to @fname@ in the text format produced by
-- 'matrix_to_string'.
write_matrix_to_file :: Matrix R -> String -> IO ()
write_matrix_to_file m fname = writeFile fname (matrix_to_string m)
-- | Read a matrix from a text file with one matrix row per line and
-- comma-separated entries (the format written by 'write_matrix_to_file').
--
-- 'readFile' is lazy: on its own it would keep the file handle open until the
-- returned matrix is eventually forced. The contents are therefore forced in
-- full before returning, so the handle is released by the time this action
-- completes.
--
-- Entries are parsed with 'read'; a malformed entry raises an error when the
-- matrix is evaluated.
read_matrix_from_file :: String -> IO (Matrix R)
read_matrix_from_file fname = do
  contents <- readFile fname
  let parse_row = map read . splitOn ","
      num_lists = map parse_row (lines contents)
  -- Forcing the length consumes the whole file, which lets the lazy read
  -- reach end-of-file and close the handle before we return.
  length contents `seq` return (fromLists num_lists)
-- | Write a list of numbers to @fname@ as a single comma-separated line, each
-- value formatted with 5 decimal places via 'roundToStr'.
save_list_to_file :: [Double] -> String -> IO ()
save_list_to_file l fname = writeFile fname csv_line
  where
    csv_line = intercalate "," (map (roundToStr 5) l)
--------------------- Saving and loading NNs and VAE to and from file
-- | Save both networks of a 'VAE' to disk, printing progress messages.
--
-- The first network is saved under the base name @fname_base ++ "_front_nn"@
-- and the second under @fname_base ++ "_back_nn"@ (see 'save_nn_to_file' for
-- the files created per network).
save_vae_to_file :: VAE -> String -> IO ()
save_vae_to_file (VAE (front_half, back_half)) fname_base = do
  putStrLn "\nSaving VAE to file..."
  save_nn_to_file front_half (fname_base ++ "_front_nn")
  save_nn_to_file back_half (fname_base ++ "_back_nn")
  putStrLn "Done!\n"
-- | Save a network to disk.
--
-- Writes the number of layers to @fname_base ++ "_info.txt"@, then writes
-- layer @i@ (counting from 0) to @fname_base ++ "_layer_" ++ show i ++ ".txt"@
-- via 'save_layer_to_file'.
save_nn_to_file :: NN -> String -> IO ()
save_nn_to_file (WeightMatList nn) fname_base = do
  writeFile (fname_base ++ "_info.txt") (show (length nn))
  -- zipWithM_ runs the writes for their effect only, instead of collecting a
  -- list of unit results that is immediately thrown away.
  zipWithM_ save_layer_to_file nn layer_fnames
  where
    -- Infinite list of layer file names; zipWithM_ stops at the last layer.
    layer_fnames =
      [fname_base ++ "_layer_" ++ show i ++ ".txt" | i <- [0 :: Int ..]]
-- | Write the matrix wrapped by a 'Layer' to @fname@ (see
-- 'write_matrix_to_file').
save_layer_to_file :: Layer -> String -> IO ()
save_layer_to_file (Layer m) fname = write_matrix_to_file m fname
-- | Load a single 'Layer' from @fname@ using 'read_matrix_from_file', printing
-- the size of the matrix that was read.
--
-- (The spelling "frome" is kept because other definitions call this name.)
load_layer_frome_file :: String -> IO Layer
load_layer_frome_file fname =
  read_matrix_from_file fname >>= \mat -> do
    putStrLn ("Reading in mat of size " ++ show (size mat))
    return (Layer mat)
-- | Load a 'VAE' from disk, printing progress messages.
--
-- The first network is loaded from the base name
-- @fname_base ++ "_front_nn"@ and the second from @fname_base ++ "_back_nn"@,
-- mirroring the naming used by 'save_vae_to_file'.
load_vae_from_file :: String -> IO VAE
load_vae_from_file fname_base = do
  putStrLn "\nLoading VAE from file..."
  front <- load_nn_from_file (fname_base ++ "_front_nn")
  back <- load_nn_from_file (fname_base ++ "_back_nn")
  putStrLn "Done!\n"
  return (VAE (front, back))
-- | Load a network previously written by 'save_nn_to_file'.
--
-- The layer count is taken from the first line of
-- @fname_base ++ "_info.txt"@; layers @0 .. count - 1@ are then loaded from
-- @fname_base ++ "_layer_" ++ show i ++ ".txt"@ with 'load_layer_frome_file'.
--
-- The info file is read strictly and parsed up front, so an empty or
-- malformed info file fails immediately with an 'IOError' naming the file,
-- rather than surfacing later as an anonymous @head@ or @read@ failure.
load_nn_from_file :: String -> IO NN
load_nn_from_file fname_base = do
  let info_fname = fname_base ++ "_info.txt"
  info <- readFile info_fname
  -- readFile is lazy; forcing the length consumes the file so the handle is
  -- released before any layer file is touched.
  length info `seq` return ()
  n_layers <- case lines info of
    first_line : _
      -- Accept surrounding whitespace, as 'read' did.
      | [(n, rest)] <- reads first_line
      , null (words rest) ->
          return (n :: Int)
    _ ->
      ioError
        (userError
          ("load_nn_from_file: cannot parse layer count from " ++ info_fname))
  let fnames =
        [ fname_base ++ "_layer_" ++ show i ++ ".txt"
        | i <- [0 .. n_layers - 1]
        ]
  WeightMatList <$> mapM load_layer_frome_file fnames
-- | Choose the value of beta_KL for the current epoch.
--
-- The active scheme is constant: the result is always @beta_KL_max@, and the
-- current epoch and total number of epochs are ignored. They stay in the
-- signature so that an epoch-dependent scheme can be swapped in without
-- touching callers.
get_beta_KL :: Int -> Int -> Double -> Double
get_beta_KL _cur_epoch _n_epochs beta_KL_max = beta_KL_max
{- Disabled alternative: 0 for the first half of training, then a linear ramp.
get_beta_KL cur_epoch n_epochs beta_KL_max
  | cur_epoch < div n_epochs 2 = 0
  | otherwise = beta_KL_max*(fromIntegral $ cur_epoch - div n_epochs 2)/(fromIntegral $ div n_epochs 2)
-}
------------------------------ Misc
-- | Format a number in fixed-point notation with the given number of digits
-- after the decimal point, using the 'printf' format @"%0.*f"@ (the precision
-- is taken from the first argument).
roundToStr :: (PrintfArg a, Floating a) => Int -> a -> String
roundToStr n_decimals value = printf "%0.*f" n_decimals value
--