Skip to content

Commit

Permalink
Release changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlvaroHG committed Aug 22, 2024
1 parent 5a75cd2 commit 020bb5b
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 326 deletions.
42 changes: 26 additions & 16 deletions ai2thor/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ def __init__(
server_timeout: Optional[float] = 100.0,
server_start_timeout: float = 300.0,
# objaverse_asset_ids=[], TODO add and implement when objaverse.load_thor_objects is available
action_hook_runner=None,
metadata_hook: Optional[MetadataHook] = None,
before_action_callback=None,
metadata_callback: Optional[MetadataHook] = None,
**unity_initialization_parameters,
):
self.receptacle_nearest_pivot_points = {}
Expand Down Expand Up @@ -443,18 +443,28 @@ def __init__(
)
)

self.action_hook_runner = action_hook_runner
self.action_hooks = (
if "action_hook_runner" in unity_initialization_parameters:
raise ValueError(
f"Deprecated argument 'action_hook_runner'. Use 'before_action_callback' instead."
)

if "metadata_hook" in unity_initialization_parameters:
raise ValueError(
f"Deprecated argument 'metadata_hook'. Use 'metadata_callback' instead."
)

self.before_action_callback = before_action_callback
self.action_callbacks = (
{
func
for func in dir(action_hook_runner)
if callable(getattr(action_hook_runner, func)) and not func.startswith("__")
for func in dir(before_action_callback)
if callable(getattr(before_action_callback, func)) and not func.startswith("__")
}
if self.action_hook_runner is not None
if self.before_action_callback is not None
else None
)

self.metadata_hook = metadata_hook
self.metadata_callback = metadata_callback

if self.gpu_device is not None:
# numbers.Integral works for numpy.int32/64 and Python int
Expand Down Expand Up @@ -971,11 +981,11 @@ def multi_step_physics(self, action, timeStep=0.05, max_steps=20):

return events

def run_action_hook(self, action):
if self.action_hooks is not None and action["action"] in self.action_hooks:
def run_before_action_callback(self, action):
if self.action_callbacks is not None and action["action"] in self.action_callbacks:
try:
# print(f"action hooks: {self.action_hooks}")
method = getattr(self.action_hook_runner, action["action"])
method = getattr(self.before_action_callback, action["action"])
event = method(action, self)
if isinstance(event, list):
self.last_event = event[-1]
Expand All @@ -984,18 +994,18 @@ def run_action_hook(self, action):
except AttributeError:
traceback.print_stack()
raise NotImplementedError(
"Action Hook Runner `{}` does not implement method `{}`,"
"Action Callback `{}` does not implement method `{}`,"
" actions hooks are meant to run before an action, make sure that `action_hook_runner`"
" passed to the controller implements a method for the desired action.".format(
self.action_hook_runner.__class__.__name__, action["action"]
self.before_action_callback.__class__.__name__, action["action"]
)
)
return True
return False

def run_metadata_hook(self, metadata: MetadataWrapper) -> bool:
if self.metadata_hook is not None:
out = self.metadata_hook(metadata=metadata, controller=self)
if self.metadata_callback is not None:
out = self.metadata_callback(metadata=metadata, controller=self)
assert (
out is None
), "`metadata_hook` must return `None` and change the metadata in place."
Expand Down Expand Up @@ -1043,7 +1053,7 @@ def step(self, action: Union[str, Dict[str, Any]] = None, **action_args):
# not deleting to allow for older builds to continue to work
# del action[old]

self.run_action_hook(action)
self.run_before_action_callback(action)

self.server.send(action)
try:
Expand Down
73 changes: 51 additions & 22 deletions ai2thor/hooks/procedural_asset_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
load_existing_thor_asset_file,
)

from objathor.dataset import load_assets_path, DatasetSaveConfig

logger = logging.getLogger(__name__)

EXTENSIONS_LOADABLE_IN_UNITY = {
Expand Down Expand Up @@ -248,7 +250,7 @@ def create_assets_if_not_exist(
# return evt


class ProceduralAssetHookRunner:
class ProceduralAssetActionCallback:
def __init__(
self,
asset_directory,
Expand All @@ -268,6 +270,7 @@ def __init__(
self.target_dir = target_dir
self.extension = extension
self.verbose = verbose
self.last_asset_id_set = set()

def Initialize(self, action, controller):
if self.asset_limit > 0:
Expand All @@ -278,6 +281,10 @@ def Initialize(self, action, controller):
def CreateHouse(self, action, controller):
house = action["house"]
asset_ids = get_all_asset_ids_recursively(house["objects"], [])
asset_ids_set = set(asset_ids)
if not asset_ids_set.issubset(self.last_asset_id_set):
controller.step(action="DeleteLRUFromProceduralCache", assetLimit=0)
self.last_asset_id_set = set(asset_ids)
return create_assets_if_not_exist(
controller=controller,
asset_ids=asset_ids,
Expand Down Expand Up @@ -320,27 +327,49 @@ def GetHouseFromTemplate(self, action, controller):
)


class ObjaverseAssetHookRunner(object):
def __init__(self):
import objaverse

self.objaverse_uid_set = set(objaverse.load_uids())
class DownloadObjaverseActionCallback(object):
def __init__(
self,
asset_dataset_version,
asset_download_path,
target_dir="processed_models",
asset_symlink=True,
load_file_in_unity=False,
stop_if_fail=False,
asset_limit=-1,
extension=None,
verbose=True,
):
self.asset_download_path = asset_download_path
self.asset_symlink = asset_symlink
self.stop_if_fail = stop_if_fail
self.asset_limit = asset_limit
self.load_file_in_unity = load_file_in_unity
self.target_dir = target_dir
self.extension = extension
self.verbose = verbose
self.last_asset_id_set = set()
dsc = DatasetSaveConfig(
VERSION=asset_dataset_version,
BASE_PATH=asset_download_path,
)
self.asset_path = load_assets_path(dsc)

def CreateHouse(self, action, controller):
raise NotImplemented("Not yet implemented.")

house = action["house"]
asset_ids = list(set(obj["assetId"] for obj in house["objects"]))
evt = controller.step(action="AssetsInDatabase", assetIds=asset_ids)
asset_in_db = evt.metadata["actionReturn"]
assets_not_created = [asset_id for (asset_id, in_db) in asset_in_db.items() if in_db]
not_created_set = set(assets_not_created)
not_objeverse_not_created = not_created_set.difference(self.objaverse_uid_set)
if len(not_created_set):
raise ValueError(
f"Invalid asset ids are not in THOR AssetDatabase or part of objeverse: {not_objeverse_not_created}"
)

# TODO when transformed assets are in objaverse download them and create them
# objaverse.load_thor_objects
# create_assets()
asset_ids = get_all_asset_ids_recursively(house["objects"], [])
asset_ids_set = set(asset_ids)
if not asset_ids_set.issubset(self.last_asset_id_set):
controller.step(action="DeleteLRUFromProceduralCache", assetLimit=0)
self.last_asset_id_set = set(asset_ids)
return create_assets_if_not_exist(
controller=controller,
asset_ids=asset_ids,
asset_directory=self.asset_path,
copy_to_dir=os.path.join(controller._build.base_dir, self.target_dir),
asset_symlink=self.asset_symlink,
stop_if_fail=self.stop_if_fail,
load_file_in_unity=self.load_file_in_unity,
extension=self.extension,
verbose=self.verbose,
)
12 changes: 6 additions & 6 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4727,10 +4727,10 @@ def test_create_prefab(ctx, json_path):
def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetHookRunner
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback
from objathor.asset_conversion.util import view_asset_in_thor

hook_runner = ProceduralAssetHookRunner(
hook_runner = ProceduralAssetActionCallback(
asset_directory=asset_dir,
asset_symlink=True,
verbose=True,
Expand All @@ -4747,7 +4747,7 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
height=300,
server_class=ai2thor.fifo_server.FifoServer,
visibilityScheme="Distance",
action_hook_runner=hook_runner,
before_action_callback=hook_runner,
)

# TODO bug why skybox is not changing? from just procedural pipeline
Expand Down Expand Up @@ -4817,9 +4817,9 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_limit=1):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetHookRunner
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback

hook_runner = ProceduralAssetHookRunner(
hook_runner = ProceduralAssetActionCallback(
asset_directory=asset_dir, asset_symlink=True, verbose=True, asset_limit=1
)
controller = ai2thor.controller.Controller(
Expand All @@ -4834,7 +4834,7 @@ def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_
height=300,
server_class=ai2thor.wsgi_server.WsgiServer,
visibilityScheme="Distance",
action_hook_runner=hook_runner,
before_action_callback=hook_runner,
)
asset_ids = asset_ids.split(",")
with open(house_path, "r") as f:
Expand Down
Loading

0 comments on commit 020bb5b

Please sign in to comment.