Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preliminary changes for release (WIP) #107

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 93 additions & 80 deletions bamt/builders/builders_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Dict, List, Optional, Tuple, Callable, TypedDict, Sequence, Union
from typing import Dict, List, Optional, Tuple, Callable, Union

from pandas import DataFrame

Expand All @@ -11,15 +11,9 @@
from bamt.nodes.gaussian_node import GaussianNode
from bamt.nodes.logit_node import LogitNode
from bamt.nodes.mixture_gaussian_node import MixtureGaussianNode
from bamt.utils import GraphUtils as gru
from bamt.utils import graph_utils as gru


class ParamDict(TypedDict, total=False):
init_edges: Optional[Sequence[str]]
init_nodes: Optional[List[str]]
remove_init_edges: bool
white_list: Optional[Tuple[str, str]]
bl_add: Optional[List[str]]
from bamt.checkers.enums import continuous_nodes


class StructureBuilder(object):
Expand All @@ -28,19 +22,25 @@ class StructureBuilder(object):
It can restrict nodes defined by RESTRICTIONS
"""

def __init__(self, descriptor: Dict[str, Dict[str, str]]):
def __init__(self, descriptor: Dict[str, Dict[str, str]], **kwargs):
"""
:param descriptor: a dict with types and signs of nodes
Attributes:
black_list: a list with restricted connections;
"""
self.skeleton = {"V": [], "E": []}
# name "nodes" in builders' space is forbidden!
self.vertices = []
# name "edges" in builders' space is forbidden!
self.structure = []
self.descriptor = descriptor

self.has_logit = bool

self.black_list = None

self.checker_descriptor = kwargs["checkers_rules"]["descriptor"]
self.is_restricted_pair = kwargs["checkers_rules"]["restriction_rule"]

def restrict(
self,
data: DataFrame,
Expand All @@ -52,16 +52,14 @@ def restrict(
:param init_nodes: nodes to begin with (thus they have no parents)
:param bl_add: additional vertices
"""
node_type = self.descriptor["types"]
blacklist = []
datacol = data.columns.to_list()

if not self.has_logit:
# Has_logit flag allows BN building edges between cont and disc
RESTRICTIONS = [("cont", "disc"), ("cont", "disc_num")]
for x, y in itertools.product(datacol, repeat=2):
if x != y:
if (node_type[x], node_type[y]) in RESTRICTIONS:
if self.is_restricted_pair(x, y):
blacklist.append((x, y))
else:
self.black_list = []
Expand All @@ -73,42 +71,38 @@ def restrict(

def get_family(self):
"""
A function that updates a skeleton;
A function that updates each node accordingly structure;
"""
if not self.skeleton["V"]:
if not self.vertices:
logger_builder.error("Vertex list is None")
return None
if not self.skeleton["E"]:
if not self.structure:
logger_builder.error("Edges list is None")
return None
for node_instance in self.skeleton["V"]:
node = node_instance.name
children = []
parents = []
for edge in self.skeleton["E"]:
if node in edge:
if edge.index(node) == 0:
children.append(edge[1])
if edge.index(node) == 1:
parents.append(edge[0])

disc_parents = []
cont_parents = []
for parent in parents:
if self.descriptor["types"][parent] in ["disc", "disc_num"]:
disc_parents.append(parent)
else:
cont_parents.append(parent)

id = self.skeleton["V"].index(node_instance)
self.skeleton["V"][id].disc_parents = disc_parents
self.skeleton["V"][id].cont_parents = cont_parents
self.skeleton["V"][id].children = children
node_mapping = {
node_instance.name: {"disc_parents": [], "cont_parents": [], "children": []}
for node_instance in self.vertices
}

for edge in self.structure:
parent, child = edge[0], edge[1]
node_mapping[parent]["children"].append(child)
if self.checker_descriptor["types"][parent].is_cont:
node_mapping[child]["cont_parents"].append(parent)
else:
node_mapping[child]["disc_parents"].append(parent)

for node, data in node_mapping.items():
node_instance = next(n for n in self.vertices if n.name == node)
node_instance.disc_parents = data["disc_parents"]
node_instance.cont_parents = data["cont_parents"]
node_instance.children = data["children"]

ordered = gru.toporder(self.skeleton["V"], self.skeleton["E"])
not_ordered = [node.name for node in self.skeleton["V"]]
ordered = gru.toporder(self.vertices, self.structure)
not_ordered = [node.name for node in self.vertices]
mask = [not_ordered.index(name) for name in ordered]
self.skeleton["V"] = [self.skeleton["V"][i] for i in mask]
self.vertices = [self.vertices[i] for i in mask]


class VerticesDefiner(StructureBuilder):
Expand All @@ -117,29 +111,40 @@ class VerticesDefiner(StructureBuilder):
"""

def __init__(
self, descriptor: Dict[str, Dict[str, str]], regressor: Optional[object]
self,
descriptor: Dict[str, Dict[str, str]],
regressor: Optional[object],
**kwargs,
):
"""
Automatically creates a list of nodes
"""
super(VerticesDefiner, self).__init__(descriptor=descriptor)
# Notice that vertices are used only by Builders
self.vertices = []
super().__init__(descriptor=descriptor, **kwargs)
self.regressor = regressor

def init_nodes(self):
vertices = []
node = None
# LEVEL 1: Define a general type of node: Discrete or Gaussian
for vertex, type in self.descriptor["types"].items():
if type in ["disc_num", "disc"]:
for vertex, vertex_checker in self.checker_descriptor["types"].items():
if not vertex_checker.is_cont:
node = DiscreteNode(name=vertex)
elif type == "cont":
node = GaussianNode(name=vertex, regressor=regressor)
elif vertex_checker.is_cont:
node = GaussianNode(name=vertex, regressor=self.regressor)
else:
msg = f"""First stage of automatic vertex detection failed on {vertex} due TypeError ({type}).
Set vertex manually (by calling set_nodes()) or investigate the error."""
Set vertex manually (by calling set_nodes()) or investigate the error."""
logger_builder.error(msg)
continue

self.vertices.append(node)
vertices.append(node)
return vertices


class EdgesDefiner(StructureBuilder):
def __init__(self, vertices, descriptor: Dict[str, Dict[str, str]], **kwargs):
super().__init__(descriptor, **kwargs)
self.vertices = vertices
self.structure = []

def overwrite_vertex(
self,
Expand All @@ -158,57 +163,58 @@ def overwrite_vertex(
for node_instance in self.vertices:
node = node_instance
if has_logit:
if "Discrete" in node_instance.type:
if node_instance.cont_parents:
if not node_instance.disc_parents:
node = LogitNode(
name=node_instance.name, classifier=classifier
)

elif node_instance.disc_parents:
node = ConditionalLogitNode(
name=node_instance.name, classifier=classifier
)

if use_mixture:
if "Gaussian" in node_instance.type:
if node_instance.cont_parents:
if not node_instance.disc_parents:
node = LogitNode(name=node_instance.name, classifier=classifier)

elif node_instance.disc_parents:
node = ConditionalLogitNode(
name=node_instance.name, classifier=classifier
)

if node_instance.type in continuous_nodes:
if use_mixture:
if not node_instance.disc_parents:
node = MixtureGaussianNode(name=node_instance.name)
elif node_instance.disc_parents:
node = ConditionalMixtureGaussianNode(name=node_instance.name)
else:
continue
else:
if "Gaussian" in node_instance.type:
else:
if node_instance.disc_parents:
node = ConditionalGaussianNode(
name=node_instance.name, regressor=regressor
)
else:
continue
# redefine default node with new regressor
node = GaussianNode(
name=node_instance.name, regressor=regressor
)

node_checker = self.checker_descriptor["types"][node.name].evolve(
node.type, node_instance.cont_parents, node_instance.disc_parents
)

self.checker_descriptor["types"][node.name] = node_checker

if node_instance == node:
continue

id = self.skeleton["V"].index(node_instance)
id = self.vertices.index(node_instance)
node.disc_parents = node_instance.disc_parents
node.cont_parents = node_instance.cont_parents
node.children = node_instance.children
self.skeleton["V"][id] = node


class EdgesDefiner(StructureBuilder):
def __init__(self, descriptor: Dict[str, Dict[str, str]]):
super(EdgesDefiner, self).__init__(descriptor)
self.vertices[id] = node


class BaseDefiner(VerticesDefiner, EdgesDefiner):
class BaseDefiner(EdgesDefiner, VerticesDefiner):
def __init__(
self,
data: DataFrame,
vertices: list,
descriptor: Dict[str, Dict[str, str]],
scoring_function: Union[Tuple[str, Callable], Tuple[str]],
regressor: Optional[object] = None,
checkers_rules: dict,
):
self.scoring_function = scoring_function
self.params = {
Expand All @@ -218,5 +224,12 @@ def __init__(
"white_list": None,
"bl_add": None,
}
super().__init__(descriptor, regressor=regressor)

super().__init__(
descriptor=descriptor,
vertices=vertices,
checkers_rules=checkers_rules,
regressor=None,
)

self.optimizer = None # will be defined in subclasses
6 changes: 3 additions & 3 deletions bamt/builders/composite_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from bamt.nodes.composite_discrete_node import CompositeDiscreteNode
from bamt.nodes.discrete_node import DiscreteNode
from bamt.nodes.gaussian_node import GaussianNode
from bamt.utils import EvoUtils as evo
from bamt.utils import evo_utils as evo
from bamt.utils.composite_utils import CompositeGeneticOperators
from bamt.utils.composite_utils.CompositeModel import CompositeModel, CompositeNode

Expand Down Expand Up @@ -136,11 +136,11 @@ def overwrite_vertex(
node = DiscreteNode(name=node_instance.name)
else:
continue
id_node = self.skeleton["V"].index(node_instance)
id_node = self.vertices.index(node_instance)
node.disc_parents = node_instance.disc_parents
node.cont_parents = node_instance.cont_parents
node.children = node_instance.children
self.skeleton["V"][id_node] = node
self.vertices[id_node] = node

def build(
self,
Expand Down
2 changes: 1 addition & 1 deletion bamt/builders/evo_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pandas import DataFrame

from bamt.builders.builders_base import BaseDefiner
from bamt.utils import EvoUtils as evo
from bamt.utils import evo_utils as evo


class EvoDefiner(BaseDefiner):
Expand Down
Loading
Loading