diff --git a/tagilmo/VereyaPython/xml_util.py b/tagilmo/VereyaPython/xml_util.py index 8f8b928..8972af9 100644 --- a/tagilmo/VereyaPython/xml_util.py +++ b/tagilmo/VereyaPython/xml_util.py @@ -74,3 +74,10 @@ def str2xml(xml: str) -> Element: _, _, el.tag = el.tag.rpartition('}') # strip ns root = it.root return root + + +def remove_namespaces(el): + if el.tag.startswith('{'): + _, _, el.tag = el.tag.rpartition('}') + for child in el: + remove_namespaces(child) \ No newline at end of file diff --git a/tagilmo/utils/mission_builder.py b/tagilmo/utils/mission_builder.py index 592ec99..46e76f2 100755 --- a/tagilmo/utils/mission_builder.py +++ b/tagilmo/utils/mission_builder.py @@ -1,5 +1,7 @@ from xml.sax.saxutils import quoteattr from typing import Optional +import xml.etree.ElementTree as ET +from tagilmo.VereyaPython.xml_util import remove_namespaces #skipped: # # 10 @@ -19,7 +21,37 @@ def xml(self): _xml += ""; _xml += '\n\n' return _xml + + def from_xml(self, aboutRoot = None): + if aboutRoot is None or len(list(aboutRoot)) == 0: + return + + summary = aboutRoot.find("Summary") + if summary is not None: + self.summary = summary.text + +class ModSettings: + + def __init__(self, ms_per_tick=None): + self.ms_per_tick = ms_per_tick + + def xml(self): + _xml = '\n' + if self.ms_per_tick: + _xml += ""+self.ms_per_tick+"" + else: + _xml += "" + _xml += '\n\n' + return _xml + + def from_xml(self, msRoot = None): + if msRoot is None or len(list(msRoot)) == 0: + return + + msTick = msRoot.find("MsPerTick") + if msTick is not None: + self.ms_per_tick = msTick.text class ServerInitialConditions: @@ -61,6 +93,33 @@ def xml(self): _xml += ''+self.allowedmobs+'\n' #e.g. "Pig Sheep" _xml += '\n' return _xml + + def _time_from_xml(self, timeRoot = None): + if timeRoot is None or len(list(timeRoot)) == 0: + return + + if timeRoot.find("StartTime") is not None: + self.time_start = timeRoot.find("StartTime").text + + if timeRoot.find("AllowPassageOfTime") is not None: + self.time_pass = timeRoot.find("AllowPassageOfTime").text + + def from_xml(self, initConditionsRoot = None): + if initConditionsRoot is None or len(list(initConditionsRoot)) == 0: + return + + self._time_from_xml(initConditionsRoot.find("Time")) + + weather = None + if initConditionsRoot.find("Weather") is not None: + weather = initConditionsRoot.find("Weather").text + self.weather = weather + + allow_spawning = "true" + if initConditionsRoot.find("AllowSpawning") is not None: + allow_spawning = initConditionsRoot.find("AllowSpawning").text + self.spawning = allow_spawning + def flatworld(generatorString, forceReset="false", seed=''): @@ -90,64 +149,6 @@ def fileworld(uri2save, forceReset="false"): return str -class ServerHandlers: - - def __init__(self, worldgenerator_xml=defaultworld(), alldecorators_xml=None, - bQuitAnyAgent=False, timeLimitsMs_string=None, drawingdecorator = None): - self.worldgenerator = worldgenerator_xml - self.alldecorators = alldecorators_xml - self.bQuitAnyAgent = bQuitAnyAgent - self.timeLimitsMs = timeLimitsMs_string - self.drawingdecorator = drawingdecorator - - def xml(self): - _xml = '\n' + self.worldgenerator + '\n' - if self.drawingdecorator: - _xml += '\n' + self.drawingdecorator.xml() + '\n' - # -- - # -- - if self.alldecorators: - _xml += self.alldecorators + '\n' - if self.bQuitAnyAgent: - _xml += '\n' - if self.timeLimitsMs: - _xml += '\n' - _xml += '\n' - return _xml - - -class ServerSection: - - def __init__(self, handlers=ServerHandlers(), initial_conditions=ServerInitialConditions()): - self.handlers = handlers - self.initial_conditions = initial_conditions - - def xml(self): - return '\n'+self.initial_conditions.xml()+self.handlers.xml()+'\n' - - -class DrawingDecorator: - def __init__(self, decorators = []): - """ - Draw all given Draw objects - - Parameters: - decorators (List[Union[DrawBlock, DrawCuboid, DrawItem, DrawLine]]) : a list of objects to be drawn.\nEach object can be one of the following types: - - DrawBlock: represents a block. - - DrawCuboid: represents a cuboid. - - DrawItem: represents an item. - - DrawLine: represents a line between two points. - """ - self.decorators = decorators - - def xml(self): - _xml = "" - for decorator in self.decorators: - _xml += decorator.xml() - return _xml - - class DrawBlock: def __init__(self, x, y, z, blockType): """ @@ -260,6 +261,105 @@ def __init__(self, x, y, z, radius, blockType): def xml(self): return f'\n' +class DrawingDecorator: + def __init__(self, decorators = []): + """ + Draw all given Draw objects + + Parameters: + decorators (List[Union[DrawBlock, DrawCuboid, DrawItem, DrawLine]]) : a list of objects to be drawn.\nEach object can be one of the following types: + - DrawBlock: represents a block. + - DrawCuboid: represents a cuboid. + - DrawItem: represents an item. + - DrawLine: represents a line between two points. + """ + self.decorators = decorators + + def xml(self): + _xml = "" + for decorator in self.decorators: + _xml += decorator.xml() + return _xml + + def from_xml(self, drawingDecoratorRoot = None): + if drawingDecoratorRoot is None or len(list(drawingDecoratorRoot)) == 0: + return + drawing_elements = list(drawingDecoratorRoot) + for el in drawing_elements: + match el.tag: + case "DrawCuboid": + self.decorators.append(DrawCuboid(el.attrib["x1"], el.attrib["y1"], el.attrib["z1"], + el.attrib["x2"], el.attrib["y2"], el.attrib["z2"], + el.attrib["type"])) + case "DrawBlock": + self.decorators.append(DrawBlock(el.attrib["x"], el.attrib["y"], el.attrib["z"], + el.attrib["type"])) + case "DrawLine": + self.decorators.append(DrawLine(el.attrib["x1"], el.attrib["y1"], el.attrib["z1"], + el.attrib["x2"], el.attrib["y2"], el.attrib["z2"], + el.attrib["type"])) + case "DrawItem": + self.decorators.append(DrawItem(el.attrib["x"], el.attrib["y"], el.attrib["z"], + el.attrib["type"])) + case _: + continue + + +class ServerHandlers: + + def __init__(self, worldgenerator_xml=defaultworld(), alldecorators_xml=None, + bQuitAnyAgent=False, timeLimitsMs_string=None, drawingdecorator = DrawingDecorator()): + self.worldgenerator = worldgenerator_xml + self.alldecorators = alldecorators_xml + self.bQuitAnyAgent = bQuitAnyAgent + self.timeLimitsMs = timeLimitsMs_string + self.drawingdecorator = drawingdecorator + + def xml(self): + _xml = '\n' + self.worldgenerator + '\n' + if self.drawingdecorator: + _xml += '\n' + self.drawingdecorator.xml() + '\n' + # -- + # -- + if self.alldecorators: + _xml += self.alldecorators + '\n' + if self.bQuitAnyAgent: + _xml += '\n' + if self.timeLimitsMs: + _xml += '\n' + _xml += '\n' + return _xml + + def from_xml(self, handlersRoot = None): + if handlersRoot is None or len(list(handlersRoot)) == 0: + return + + world_generator = defaultworld() + flat_gen = handlersRoot.find("FlatWorldGenerator") + if flat_gen is not None: + world_generator = flatworld(flat_gen.attrib["generatorString"]) + self.worldgenerator = world_generator + + self.drawingdecorator.from_xml(handlersRoot.find("DrawingDecorator")) + + if handlersRoot.find("ServerQuitFromTimep") is not None: + time_limit = handlersRoot.find("ServerQuitFromTimep").attrib["timeLimitMs"] + self.timeLimitsMs = time_limit + + +class ServerSection: + + def __init__(self, handlers=ServerHandlers(), initial_conditions=ServerInitialConditions()): + self.handlers = handlers + self.initial_conditions = initial_conditions + + def xml(self): + return '\n'+self.initial_conditions.xml()+self.handlers.xml()+'\n' + + def from_xml(self, serverSectionRoot = None): + self.initial_conditions.from_xml(serverSectionRoot.find("ServerInitialConditions")) + self.handlers.from_xml(serverSectionRoot.find('ServerHandlers')) class Commands: @@ -296,6 +396,12 @@ def xml(self): # -- # -- return _xml + + def from_xml(commandsRoot= None): + if commandsRoot is None: + return + # TODO: parsing non-empty root + return class Observations: @@ -405,20 +511,95 @@ def xml(self): '.format(width=self.width, height=self.height) +class Block: + def __init__(self, reward=None, blockType=None, behaviour=None): + self.reward = reward + self.blockType = blockType + self.behaviour = behaviour + + def xml(self): + _xml = "\n' + + +class AgentQuitFromTouchingBlockType: + def __init__(self, quitBlocks = []): + self.quitBlocks = quitBlocks + + def xml(self): + if len(self.quitBlocks) == 0: + return "" + _xml = "\n" + for quitBlock in self.quitBlocks: + _xml += quitBlock.xml() + _xml += "\n" + return _xml + + def from_xml(self, agentQuitFromTouchingBlockType = None): + if agentQuitFromTouchingBlockType is None or len(list(agentQuitFromTouchingBlockType)) == 0: + return + blocks = agentQuitFromTouchingBlockType.findall("Block") + for block in blocks: + self.quitBlocks.append(Block(blockType=block.attrib['type'])) + class AgentHandlers: def __init__(self, commands=Commands(), observations=Observations(), - all_str='', video_producer=None, colourmap_producer=None): + all_str='', video_producer=None, colourmap_producer=None, + rewardForTouchingBlockType = RewardForTouchingBlockType(), rewardForSendingCommand=RewardForSendingCommand(), + agentQuitFromTouchingBlockType = AgentQuitFromTouchingBlockType()): self.commands = commands self.observations = observations self.all_str = all_str self.video_producer = video_producer self.colourmap_producer = colourmap_producer + self.rewardForTouchingBlockType = rewardForTouchingBlockType + self.rewardForSendingCommand = rewardForSendingCommand + self.agentQuitFromTouchingBlockType = agentQuitFromTouchingBlockType def xml(self): _xml = '\n' _xml += self.commands.xml() + _xml += '' if self.rewardForTouchingBlockType is None else self.rewardForTouchingBlockType.xml() + _xml += '' if self.rewardForSendingCommand is None else self.rewardForSendingCommand.xml() + _xml += '' if self.agentQuitFromTouchingBlockType is None else self.agentQuitFromTouchingBlockType.xml() _xml += self.observations.xml() _xml += self.all_str _xml += '' if self.video_producer is None else self.video_producer.xml() @@ -430,6 +611,40 @@ def xml(self): # ... return _xml + def from_xml(self, agentHandlersRoot = None): + if agentHandlersRoot is None: + return + + if agentHandlersRoot.find("ObservationFromFullStats") is not None: + obsFullStats = True + self.observations = Observations(bFullStats=obsFullStats) + + video_producer = VideoProducer() + if agentHandlersRoot.find("VideoProducer") is not None: + producer = agentHandlersRoot.find("VideoProducer") + video_producer = VideoProducer(height=int(producer.find("Height").text), + width=int(producer.find("Width").text), + want_depth=producer.attrib["want_depth"] == "true") + self.video_producer = video_producer + + self.commands.from_xml() + + rewards_for_touching = None + if agentHandlersRoot.find("RewardForTouchingBlockType") is not None: + rewards_for_touching = RewardForTouchingBlockType().from_xml(agentHandlersRoot.find("RewardForTouchingBlockType")) + self.rewardForTouchingBlockType = rewards_for_touching + + reward_for_sending_command = None + if agentHandlersRoot.find("RewardForSendingCommand") is not None: + reward = agentHandlersRoot.find("RewardForSendingCommand").attrib['reward'] + reward_for_sending_command = RewardForSendingCommand(reward=reward) + self.rewardForSendingCommand = reward_for_sending_command + + agent_quit = None + if agentHandlersRoot.find('AgentQuitFromTouchingBlockType') is not None: + agent_quit = AgentQuitFromTouchingBlockType().from_xml(agentHandlersRoot.find('AgentQuitFromTouchingBlockType')) + self.agentQuitFromTouchingBlockType = agent_quit + def hasVideo(self): if self.video_producer is None: return False @@ -495,6 +710,34 @@ def hasSegmentation(self): if self.agenthandlers.hasSegmentation(): return True return False + + def from_xml(self, agentSectionRoot = None): + + if agentSectionRoot is None or len(agentSectionRoot) == 0: + return [AgentSection()] + + sections = [] + for agent in agentSectionRoot: + agent_section = AgentSection() + if agent.find("mode") is not None: + agent_section.mode = agent.attrib["mode"] + + if agent.find("Name") is not None: + agent_section.name = agent.find("Name").text + + if agent.find("AgentStart") is not None: + xyzp = None + if agent.find("AgentStart").find("Placement") is not None: + placement = agent.find("AgentStart").find("Placement") + xyzp = [float(v) for _,v in placement.items()][:4] #yaw is not in constructor + agent_start = AgentStart(place_xyzp=xyzp) + agent_section.agentstart = agent_start + + agent_section.agenthandlers.from_xml(agent.find("AgentHandlers")) + + sections.append(agent_section) + + return sections class MinecraftServerConnection: @@ -508,11 +751,15 @@ def xml(self): class MissionXML: - def __init__(self, about=About(), serverSection=ServerSection(), agentSections=[AgentSection()], namespace=None): + def __init__(self, xml_path = None, about=About(), modSettings = ModSettings(), + serverSection=ServerSection(), agentSections=[AgentSection()], namespace=None): self.namespace = namespace self.about = about + self.modSettings = modSettings self.serverSection = serverSection self.agentSections = agentSections + if xml_path is not None: + self.from_xml(xml_path) def hasVideo(self): for section in self.agentSections: @@ -561,8 +808,21 @@ def xml(self): '''.format(namespace) _xml += self.about.xml() + _xml += self.modSettings.xml() _xml += self.serverSection.xml() for agentSection in self.agentSections: _xml += agentSection.xml() _xml += '' return _xml + + def from_xml(self, xml_path): + tree = ET.parse(xml_path) + missionRoot = tree.getroot() + remove_namespaces(missionRoot) + if missionRoot is None or len(list(missionRoot)) == 0: + return + + self.about.from_xml(missionRoot.find("About")) + self.modSettings.from_xml(missionRoot.find("ModSettings")) + self.serverSection.from_xml(missionRoot.find("ServerSection")) + self.agentSections = AgentSection().from_xml(missionRoot.findall("AgentSection")) \ No newline at end of file diff --git a/tagilmo/utils/vereya_wrapper.py b/tagilmo/utils/vereya_wrapper.py index 5f800f7..c325d2b 100644 --- a/tagilmo/utils/vereya_wrapper.py +++ b/tagilmo/utils/vereya_wrapper.py @@ -471,6 +471,16 @@ def _sendMotionCommand(self, command, value, agentId=None): def strafe(self, value, agentId=None): return self._sendMotionCommand('strafe', value, agentId) + def discreteMove(self, value, agentId=None): + """Moves the agent one block along one of the cardinal directions. + + Args: + value (string): "west" | "east" | "north" | "south" + + agentId (int, optional): id of an agent + """ + return self._sendMotionCommand('move'+value, 1, agentId) + def move(self, value, agentId=None): return self._sendMotionCommand('move', value, agentId) diff --git a/tests/vereya/run_tests.py b/tests/vereya/run_tests.py index 7712a05..57daca4 100644 --- a/tests/vereya/run_tests.py +++ b/tests/vereya/run_tests.py @@ -29,7 +29,9 @@ def main(): test_files = ['test_motion_vereya', 'test_craft', 'test_inventory', 'test_quit', 'test_observation', 'test_placement', 'test_image', - 'test_consistency', 'test_motion_mob', 'test_mob', 'test_agent', 'test_flat_world', 'test_draw'] + 'test_consistency', 'test_motion_mob', 'test_mob', + 'test_agent', 'test_flat_world', 'test_draw', + 'test_discrete_motion'] res = run_tests(test_files) if not res.wasSuccessful(): sys.exit(1) diff --git a/tests/vereya/test_discrete_motion.py b/tests/vereya/test_discrete_motion.py new file mode 100644 index 0000000..f0dc041 --- /dev/null +++ b/tests/vereya/test_discrete_motion.py @@ -0,0 +1,85 @@ +import logging +import json +import time +import unittest +import tagilmo.utils.mission_builder as mb +from tagilmo.utils.vereya_wrapper import MCConnector, RobustObserver +from base_test import BaseTest +from tagilmo import VereyaPython +from common import init_mission + + +logger = logging.getLogger(__file__) + +class TestDiscreteMotion(BaseTest): + + mc = None + obs = None + + @classmethod + def setUpClass(cls, *args, **kwargs): + start = (-151.0, -213.0) + mc, obs = init_mission(None, start_x=start[0], start_z=start[1], start_y=66, forceReset='true', seed='43') + cls.mc = mc + cls.obs = obs + assert mc.safeStart() + time.sleep(4) + + def test_discrete_motion(self): + obs = self.obs + mc = self.mc + pos0 = obs.getCachedObserve('getAgentPos') + logger.info('position ' + str(pos0)) + print('sending command movenorth 1') + mc.discreteMove("north") + time.sleep(1) + + pos_move = obs.getCachedObserve('getAgentPos') + logger.info('position ' + str(pos_move)) + #check if the z-coordinate has changed (-1) + self.assertEqual(pos0[2] - 1, pos_move[2], 'movenorth 1 sent but position is the same') + pos0 = pos_move + print('sending command movesouth 1') + mc.discreteMove("south") + time.sleep(1) + + pos_move = obs.getCachedObserve('getAgentPos') + logger.info('position ' + str(pos_move)) + #check if the z-coordinate has changed (+1) + self.assertEqual(pos0[2] + 1, pos_move[2], 'movesouth 1 sent but position is the same') + pos0 = pos_move + print('sending command movewest 1') + mc.discreteMove("west") + time.sleep(1) + + pos_move = obs.getCachedObserve('getAgentPos') + logger.info('position ' + str(pos_move)) + #check if the x-coordinate has changed (-1) + self.assertEqual(pos0[0] - 1, pos_move[0], 'movewest 1 sent but position is the same') + pos0 = pos_move + print('sending command moveeast 1') + mc.discreteMove("east") + time.sleep(1) + + pos_move = obs.getCachedObserve('getAgentPos') + logger.info('position ' + str(pos_move)) + #check if the z-coordinate has changed (+1) + self.assertEqual(pos0[0] + 1, pos_move[0], 'moveeast 1 sent but position is the same') + + + def teadDown(self): + self.obs.stopMove() + super().tearDown() + + def setUp(self): + super().setUp() + self.obs.stopMove() + + @classmethod + def tearDownClass(cls, *args, **kwargs): + cls.mc.stop() + + +if __name__ == '__main__': + VereyaPython.setupLogger() + unittest.main()