Skip to content

Commit

Permalink
working initial test for command processor
Browse files Browse the repository at this point in the history
  • Loading branch information
JosephMontoya-TRI committed Nov 14, 2019
1 parent 784acf2 commit caea10e
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 17 deletions.
3 changes: 2 additions & 1 deletion taburu/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ def __init__(self, filename):
self._filename = filename

def publish(self, event):
data = jsanitize(event, strict=True)
data = json.dumps(jsanitize(event, strict=True))
CHANNEL_THREAD_LOCK.acquire()
with open(self._filename, "a+") as f:
f.write(data)
f.write("\n")
CHANNEL_THREAD_LOCK.release()

def subscribe(self, iterations=None, poll_time=10):
Expand Down
47 changes: 35 additions & 12 deletions taburu/command.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from taburu.state import uProjectState
from taburu.state import TaburuState
from taburu.event import Event, ParametersAdded, MethodAdded
from monty.json import MSONable


class AddParameterCommanded(ParametersAdded):
"""The event schema is the same aside from the name"""
pass
class AddParameterCommanded(Event, MSONable):
def __init__(self, table_name, parameters, time=None):
self.table_name = table_name
self.parameters = parameters
super(AddParameterCommanded, self).__init__(time)

def as_dict(self):
return {
"@class": self.__class__.__name__,
"@module": self.__class__.__module__,
"table_name": self.table_name,
"parameters": self.parameters,
"time": self._time
}


class AddMethodCommanded(Event, MSONable):
Expand All @@ -15,16 +26,25 @@ def __init__(self, name, parameter_names, time=None):
self.parameter_names = parameter_names
super(AddMethodCommanded, self).__init__(time)

def as_dict(self):
return {
"@class": self.__class__.__name__,
"@module": self.__class__.__module__,
"name": self.name,
"parameter_names": self.parameter_names,
"time": self._time
}


class uProjectCommandProcessor(object):
def __init__(self, input_channel, output_channel=None,
class TaburuCommandProcessor(object):
def __init__(self, command_channel, event_channel=None,
state=None):
self.state = state or uProjectState()
self.input_channel = input_channel
self.output_channel = output_channel
self.state = state or TaburuState()
self.command_channel = command_channel
self.event_channel = event_channel

def publish(self, event):
self.output_channel.publish(event)
self.event_channel.publish(event)

def process_command(self, command):
if isinstance(command, AddParameterCommanded):
Expand All @@ -48,6 +68,9 @@ def process_command(self, command):
self.publish(method_added)
return True

def run(self):
for command in self.input_channel.subscribe():
def run(self, iterations=None, poll_time=10):
commands = self.command_channel.subscribe(
iterations=iterations, poll_time=poll_time)
for command in commands:
self.process_command(command)
print(self.state)
6 changes: 5 additions & 1 deletion taburu/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ def __init__(self, time=None):


# TODO: more descriptive name, populate space?
class ParametersAdded(Event, MSONable, time=None):
class ParametersAdded(Event, MSONable):
def __init__(self, table_name, parameters, time=None):
self.table_name = table_name
self.parameters = parameters
super(ParametersAdded, self).__init__(time)

def as_dict(self):
return {
"@class": self.__class__.__name__,
"@module": self.__class__.__module__,
"table_name": self.table_name,
"parameters": self.parameters,
"time": self._time
Expand All @@ -31,6 +33,8 @@ def __init__(self, name, parameter_indices, time=None):

def as_dict(self):
return {
"@class": self.__class__.__name__,
"@module": self.__class__.__module__,
"name": self.name,
"parameters": self.parameter_indices,
"time": self._time
Expand Down
18 changes: 15 additions & 3 deletions taburu/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def apply(self, event):
pass


class uProjectState(State):
class TaburuState(State):
def __init__(self):
self.parameters = IndexedOrderedDict()
self.methods = IndexedOrderedDict()
Expand All @@ -29,5 +29,17 @@ def apply(self, event):
return False
return True



def __str__(self):
return self.__repr__()

def __repr__(self):
parameters = ['{} {}: {}'.format(n, parameter_name, len(table))
for n, (parameter_name, table) in
enumerate(self.parameters.items())]
methods = ["{}: {}".format(method_name, parameters)
for method_name, parameters in self.methods.items()]
output = "Parameters:\n "
output += "\n ".join(parameters)
output += "\nMethods:\n "
output += "\n ".join(methods)
return output
55 changes: 55 additions & 0 deletions taburu/tests/test_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2019 Toyota Research Institute. All rights reserved.

import unittest
from monty.tempfile import ScratchDir
from taburu.channel import FileChannel
from taburu.command import TaburuCommandProcessor, AddMethodCommanded, \
AddParameterCommanded


class CommandTest(unittest.TestCase):
def test_command_processor(self):
with ScratchDir('.'):
command_channel = FileChannel("commands.txt")
event_channel = FileChannel("events.txt")
processor = TaburuCommandProcessor(command_channel, event_channel)

# Dump some commands
add_parameter_1 = AddParameterCommanded(
table_name="agent",
parameters=[
{
"@class": ["camd.agent.agents.QBCStabilityAgent"],
"n_query": [4, 6, 8],
"n_members": list(range(2, 5)),
"hull_distance": [0.1, 0.2],
"training_fraction": [0.4, 0.5, 0.6],
},
]
)
add_parameter_2 = AddParameterCommanded(
table_name="chemsys",
parameters=[
{
"element_1": ["Fe", "Mn"],
"element_2": ["O", "S"],
},
]
)
add_method_1 = AddMethodCommanded(
name="CamdAgentSimulation",
parameter_names=["agent", "chemsys"]
)
command_channel.publish(add_parameter_1)
command_channel.publish(add_parameter_2)
command_channel.publish(add_method_1)
# Run for 6 seconds
processor.run(iterations=2, poll_time=0.5)
self.assertEqual(
processor.state.methods['CamdAgentSimulation'], (0, 1))
self.assertEqual(
len(processor.state.parameters['agent']), 54)


if __name__ == '__main__':
unittest.main()

0 comments on commit caea10e

Please sign in to comment.