Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
bpaul4 committed Sep 25, 2023
1 parent bd57653 commit 1a3d3af
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,29 @@ def create_model(x_train, z_train, grad_train):
# we have x_train = (n_m, n_x), z_train = (n_m, n_y) and grad_train = (n_y, n_m, n_x)
n_m, n_x = np.shape(x_train)
_, n_y = np.shape(z_train)

# check dimensions using grad_train
assert np.shape(grad_train) == (n_y, n_m, n_x)

# reshape arrays
X = np.reshape(x_train, (n_x, n_m))
Y = np.reshape(z_train, (n_y, n_m))
J = np.reshape(grad_train, (n_y, n_x, n_m))

# set up and train model

# Train neural net
model = Model.initialize(X.shape[0], Y.shape[0], deep=2, wide=6) # 2 hidden layers with 6 neurons each
model = Model.initialize(
X.shape[0], Y.shape[0], deep=2, wide=6
) # 2 hidden layers with 6 neurons each
model.train(
X=X, # input data
Y=Y, # output data
J=J, # gradient data
num_iterations=25, # number of optimizer iterations per mini-batch
mini_batch_size=int(np.floor(n_m/5)), # used to divide data into training batches (use for large data sets)
mini_batch_size=int(
np.floor(n_m / 5)
), # used to divide data into training batches (use for large data sets)
num_epochs=20, # number of passes through data
alpha=0.15, # learning rate that controls optimizer step size
beta1=0.99, # tuning parameter to control ADAM optimization
Expand All @@ -79,6 +83,7 @@ def create_model(x_train, z_train, grad_train):

return model


# Main code

# import data
Expand All @@ -95,7 +100,7 @@ def create_model(x_train, z_train, grad_train):

xmax, xmin = xdata.max(axis=0), xdata.min(axis=0)
zmax, zmin = zdata.max(axis=0), zdata.min(axis=0)
xdata, zdata = np.array(xdata), np.array(zdata) # (n_m, n_x) and (n_m, n_y)
xdata, zdata = np.array(xdata), np.array(zdata) # (n_m, n_x) and (n_m, n_y)
gdata = np.stack([np.array(grad0_data), np.array(grad1_data)]) # (2, n_m, n_x)

model_data = np.concatenate(
Expand All @@ -110,7 +115,7 @@ def create_model(x_train, z_train, grad_train):
model = create_model(x_train=xdata, z_train=zdata, grad_train=gdata)

with open("mea_column_model_smt.pkl", "wb") as file:
pickle.dump(model, file)
pickle.dump(model, file)

# load model as pickle format
with open("mea_column_model_smt.pkl", "rb") as file:
Expand Down
37 changes: 23 additions & 14 deletions foqus_lib/framework/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,7 @@ def sub_symbols(i):
)[0]
elif self.trainer == "smt":
self.scaled_outputs = self.model.evaluate(
np.reshape(
np.array(self.scaled_inputs, ndmin=2),
(self.model._n_x, 1)
)
np.reshape(np.array(self.scaled_inputs, ndmin=2), (self.model._n_x, 1))
)
else: # this shouldn't occur, adding failsafe just in case
raise AttributeError(
Expand Down Expand Up @@ -1226,22 +1223,28 @@ def setSim(self, newType=None, newModel=None, force=False, ids=None):
elif (
extension == ".pkl"
): # use importlib/pickle loading syntax for SciKitLearn models
pickle_loaded = False # use a flag so we don't overload the model unnecessarily
try: # try Scikitlearn first
pickle_loaded = (
False # use a flag so we don't overload the model unnecessarily
)
try: # try Scikitlearn first
if not pickle_loaded:
with open(str(self.modelName) + extension, "rb") as file:
self.model = skl_pickle_load(file)
pickle_loaded = True
except ModuleNotFoundError as e:
_logger.info(e) # will print that sklearn is not installed but won't just fail
_logger.info(
e
) # will print that sklearn is not installed but won't just fail

try: # try SMT next
try: # try SMT next
if not pickle_loaded:
with open(str(self.modelName) + extension, "rb") as file:
self.model = smt_pickle_load(file)
pickle_loaded = True
except ModuleNotFoundError as e:
_logger.info(e) # will print that smt is not installed but won't just fail
_logger.info(
e
) # will print that smt is not installed but won't just fail

# now check which model type was unpickled
if "sklearn" in str(type(self.model)):
Expand Down Expand Up @@ -1793,22 +1796,28 @@ def runPymodelMLAI(self):
elif (
extension == ".pkl"
): # use importlib/pickle loading syntax for SciKitLearn models
pickle_loaded = False # use a flag so we don't overload the model unnecessarily
try: # try Scikitlearn first
pickle_loaded = (
False # use a flag so we don't overload the model unnecessarily
)
try: # try Scikitlearn first
if not pickle_loaded:
with open(str(self.modelName) + extension, "rb") as file:
self.model = skl_pickle_load(file)
pickle_loaded = True
except ModuleNotFoundError as e:
_logger.info(e) # will print that sklearn is not installed but won't just fail
_logger.info(
e
) # will print that sklearn is not installed but won't just fail

try: # try SMT next
try: # try SMT next
if not pickle_loaded:
with open(str(self.modelName) + extension, "rb") as file:
self.model = smt_pickle_load(file)
pickle_loaded = True
except ModuleNotFoundError as e:
_logger.info(e) # will print that smt is not installed but won't just fail
_logger.info(
e
) # will print that smt is not installed but won't just fail

# now check which model type was unpickled
if "sklearn" in str(type(self.model)):
Expand Down
5 changes: 1 addition & 4 deletions foqus_lib/gui/tests/test_ml_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,7 @@ def test_load_and_run_measmt(self, active_session, simnode):
pytest.importorskip("sympy", reason="sympy not installed")
# set sim name and confirm it's the correct model
simnode.simNameBox.setCurrentIndex(8)
assert (
simnode.simNameBox.currentText()
== "mea_column_model_smt"
)
assert simnode.simNameBox.currentText() == "mea_column_model_smt"

def test_flowsheet_run_successful(
self,
Expand Down
8 changes: 6 additions & 2 deletions foqus_lib/unit_test/node_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def test_import_sklearn_success(self):
# check that the returned functions expect the correct input as a way
# of confirming that the class (function) types are correct
with pytest.raises(TypeError):
skl_pickle_load(None) # should fail to find 'read' and 'readline' attributes
skl_pickle_load(
None
) # should fail to find 'read' and 'readline' attributes

def test_import_smt_success(self):
# skip this test if smt is not available
Expand All @@ -187,7 +189,9 @@ def test_import_smt_success(self):
# check that the returned functions expect the correct input as a way
# of confirming that the class (function) types are correct
with pytest.raises(TypeError):
smt_pickle_load(None) # should fail to find 'read' and 'readline' attributes
smt_pickle_load(
None
) # should fail to find 'read' and 'readline' attributes


# ----------------------------------------------------------------------------
Expand Down

0 comments on commit 1a3d3af

Please sign in to comment.