diff --git a/ai2thor/hooks/procedural_asset_hook.py b/ai2thor/hooks/procedural_asset_callback.py similarity index 63% rename from ai2thor/hooks/procedural_asset_hook.py rename to ai2thor/hooks/procedural_asset_callback.py index 0d7b1cbfc6..281112cba6 100644 --- a/ai2thor/hooks/procedural_asset_hook.py +++ b/ai2thor/hooks/procedural_asset_callback.py @@ -6,12 +6,23 @@ controller.step to locally run some local code """ +import concurrent.futures import logging import os -import warnings import pathlib +import tarfile +import warnings +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +from tempfile import TemporaryDirectory +from typing import Dict, Any, List, TYPE_CHECKING, Sequence -from typing import Dict, Any, List +import requests +import tqdm +from filelock import FileLock + +if TYPE_CHECKING: + from ai2thor.controller import Controller from objathor.asset_conversion.util import ( get_existing_thor_asset_file_path, @@ -24,7 +35,8 @@ from objathor.dataset import load_assets_path, DatasetSaveConfig -logger = logging.getLogger(__name__) +logger = logging.getLogger(os.path.basename(__file__)) +logger.setLevel(logging.INFO) EXTENSIONS_LOADABLE_IN_UNITY = { ".json", @@ -49,12 +61,12 @@ def get_all_asset_ids_recursively(objects: List[Dict[str, Any]], asset_ids: List def create_asset( - thor_controller, - asset_id, - asset_directory, + thor_controller: "Controller", + asset_id: str, + asset_directory: str, copy_to_dir=None, verbose=False, - load_file_in_unity=False, + load_file_in_unity=True, extension=None, raise_for_failure=True, fail_if_not_unity_loadable=False, @@ -78,7 +90,7 @@ def create_assets( assets_dir: str, copy_to_dir=None, verbose=False, - load_file_in_unity=False, + load_file_in_unity=True, extension=None, fail_if_not_unity_loadable=False, raise_for_failure=True, @@ -253,10 +265,10 @@ def create_assets_if_not_exist( class ProceduralAssetActionCallback: def __init__( self, - asset_directory, + asset_directory: str, target_dir="processed_models", asset_symlink=True, - load_file_in_unity=False, + load_file_in_unity=True, stop_if_fail=False, asset_limit=-1, extension=None, @@ -373,3 +385,182 @@ def CreateHouse(self, action, controller): extension=self.extension, verbose=self.verbose, ) + + +def download_with_progress_bar(save_path: str, url: str, verbose: bool = False): + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + with open(save_path, "wb") as f: + if verbose: + print(f"Downloading to {save_path}") + + response = requests.get(url, stream=True) + total_length = response.headers.get("content-length") + + content_type = response.headers.get("content-type") + if content_type is not None and content_type.startswith("text/html"): + raise ValueError(f"Invalid URL: {url}") + + if total_length is None: # no content length header + f.write(response.content) + else: + dl = 0 + total_length = int(total_length) + + with ( + tqdm.tqdm( + total=total_length, + unit="B", + unit_scale=True, + desc=f"Downloading asset {url}", + ) + if verbose + else nullcontext() + ) as pbar: + for data in response.iter_content(chunk_size=4096): + dl += len(data) + f.write(data) + if verbose: + pbar.update(len(data)) + + +def download_missing_asset( + asset_id: str, + asset_directory: str, + base_url: str, + verbose: bool = False, +) -> str: + final_save_dir = os.path.join(asset_directory, asset_id) + + if os.path.exists(final_save_dir): + if any(f"{asset_id}." in p for p in os.listdir(final_save_dir)): + return final_save_dir + else: + print( + f"Directory {final_save_dir} exists but could not find" + f" asset {asset_id} in it. Will attempt to redownload." + ) + + url = f"{base_url.strip('/')}/{asset_id}.tar" + + td = TemporaryDirectory() + with td as td_name: + save_path = os.path.join(td_name, f"{asset_id}.tar") + + download_with_progress_bar(save_path=save_path, url=url, verbose=verbose) + + os.makedirs(asset_directory, exist_ok=True) + + # Loop through all the files in the tar file and extract them one by one + # to the asset directory keeping the directory structure + with FileLock(os.path.join(os.path.expanduser("~"), ".ai2thor", "asset_extraction.lock")): + with tarfile.open(save_path, "r") as tar: + # Here we sort the members so that the . file is last to ensure that the object + # file is the last thing to be saved to the final location. We do this because + # we check for the existence of the . file to determine if the asset + # has been successfully downloaded previously and we want to avoid partial downloads. + for member in sorted(tar.getmembers(), key=lambda x: f"{asset_id}." in x.name): + # if "_renders" not in member.name and "success.txt" not in member.name: + tar.extract(member=member, path=asset_directory) + return final_save_dir + + +def wait_for_futures_and_raise_errors( + futures: Sequence[concurrent.futures.Future], +) -> Sequence[Any]: + results = [] + concurrent.futures.wait(futures) + for future in futures: + try: + results.append(future.result()) # This will re-raise any exceptions + except Exception: + raise + return results + + +def download_missing_assets( + asset_ids: Sequence[str], + asset_directory: str, + base_url: str, + verbose: bool = True, + threads: int = 1, +): + if verbose and threads > 1: + print(f"Downloading assets with {threads} threads. Will NOT log progress bars.") + + asset_ids = sorted(set(asset_ids)) + + with ThreadPoolExecutor(max_workers=threads) as executor: + futures = [ + executor.submit( + download_missing_asset, + asset_id=asset_id, + asset_directory=asset_directory, + base_url=base_url, + verbose=verbose and (threads == 1), + ) + for asset_id in asset_ids + ] + wait_for_futures_and_raise_errors(futures) + + +class WebProceduralAssetActionCallback(ProceduralAssetActionCallback): + def __init__( + self, + asset_directory: str, + base_url: str, + target_dir: str, + asset_symlink=True, + load_file_in_unity=True, + stop_if_fail=False, + asset_limit=-1, + extension=None, + verbose=True, + ): + super().__init__( + asset_directory=asset_directory, + target_dir=target_dir, + asset_symlink=asset_symlink, + load_file_in_unity=load_file_in_unity, + stop_if_fail=stop_if_fail, + asset_limit=asset_limit, + extension=extension, + verbose=verbose, + ) + self.base_url = base_url + + def _download_missing_assets(self, controller: "Controller", asset_ids: Sequence[str]): + asset_in_db = controller.step( + action="AssetsInDatabase", assetIds=asset_ids, updateProceduralLRUCache=False + ).metadata["actionReturn"] + assets_not_created = [asset_id for (asset_id, in_db) in asset_in_db.items() if not in_db] + download_missing_assets( + asset_ids=assets_not_created, + asset_directory=self.asset_directory, + base_url=self.base_url, + ) + + def Initialize(self, action, controller): + if self.asset_limit > 0: + return controller.step( + action="DeleteLRUFromProceduralCache", assetLimit=self.asset_limit + ) + + def CreateHouse(self, action: Dict[str, Any], controller: "Controller"): + house = action["house"] + asset_ids = get_all_asset_ids_recursively(house["objects"], []) + self._download_missing_assets(controller=controller, asset_ids=asset_ids) + + return super().CreateHouse(action=action, controller=controller) + + def SpawnAsset(self, action, controller): + self._download_missing_assets(controller=controller, asset_ids=[action["assetId"]]) + + return super().SpawnAsset(action=action, controller=controller) + + def GetHouseFromTemplate(self, action, controller): + template = action["template"] + asset_ids = get_all_asset_ids_recursively([v for (k, v) in template["objects"].items()], []) + self._download_missing_assets(controller=controller, asset_ids=asset_ids) + + super().GetHouseFromTemplate(action=action, controller=controller) diff --git a/tasks.py b/tasks.py index 4321fbc4c2..15941c81f0 100644 --- a/tasks.py +++ b/tasks.py @@ -4724,10 +4724,10 @@ def test_create_prefab(ctx, json_path): @task -def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""): +def procedural_asset_callback_test(ctx, asset_dir, house_path, asset_id=""): import json import ai2thor.controller - from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback + from ai2thor.hooks.procedural_asset_callback import ProceduralAssetActionCallback from objathor.asset_conversion.util import view_asset_in_thor hook_runner = ProceduralAssetActionCallback( @@ -4817,7 +4817,7 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""): def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_limit=1): import json import ai2thor.controller - from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback + from ai2thor.hooks.procedural_asset_callback import ProceduralAssetActionCallback hook_runner = ProceduralAssetActionCallback( asset_directory=asset_dir, asset_symlink=True, verbose=True, asset_limit=1 diff --git a/unity/Assets/Scripts/BaseFPSAgentController.cs b/unity/Assets/Scripts/BaseFPSAgentController.cs index 1ad6e89011..5a52738ef4 100644 --- a/unity/Assets/Scripts/BaseFPSAgentController.cs +++ b/unity/Assets/Scripts/BaseFPSAgentController.cs @@ -3,11 +3,11 @@ using System.Collections; using System.Collections.Generic; using System.IO; -using System.IO; using System.IO.Compression; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Threading.Tasks; using MessagePack; using MessagePack.Formatters; using MessagePack.Resolvers; @@ -22,6 +22,7 @@ using UnityStandardAssets.CrossPlatformInput; using UnityStandardAssets.ImageEffects; using UnityStandardAssets.Utility; +using Diagnostics = System.Diagnostics; using Random = UnityEngine.Random; namespace UnityStandardAssets.Characters.FirstPerson { @@ -7602,7 +7603,7 @@ string backTexturePath actionFinished(true); } - public ActionFinished CreateRuntimeAsset(ProceduralAsset asset) { + public ActionFinished CreateRuntimeAsset(ProceduralAsset asset, bool returnObject = false) { var assetDb = GameObject.FindObjectOfType(); if (assetDb.ContainsAssetKey(asset.name)) { return new ActionFinished( @@ -7627,27 +7628,20 @@ public ActionFinished CreateRuntimeAsset(ProceduralAsset asset) { annotations: asset.annotations, receptacleCandidate: asset.receptacleCandidate, yRotOffset: asset.yRotOffset, + returnObject: returnObject, serializable: asset.serializable, parentTexturesDir: asset.parentTexturesDir ); return new ActionFinished { success = true, actionReturn = assetData }; } - public ActionFinished CreateRuntimeAsset( + private static async Task LoadAssetAsync( string id, string dir, string extension = ".msgpack.gz", ObjectAnnotations annotations = null, bool serializable = false ) { - var assetDb = GameObject.FindObjectOfType(); - if (assetDb.ContainsAssetKey(id)) { - return new ActionFinished( - success: false, - errorMessage: $"'{id}' already exists in ProceduralAssetDatabase, trying to create procedural object twice, call `SpawnAsset` instead.", - toEmitState: true - ); - } var validDirs = new List() { Application.persistentDataPath, @@ -7663,20 +7657,14 @@ public ActionFinished CreateRuntimeAsset( extension = !extension.StartsWith(".") ? $".{extension}" : extension; extension = extension.Trim(); if (!supportedExtensions.Contains(extension)) { - return new ActionFinished( - success: false, - errorMessage: $"Unsupported extension `{extension}`. Only supported: {string.Join(", ", supportedExtensions)}", - actionReturn: null + throw new ArgumentException( + $"Unsupported extension `{extension}`. Only supported: {string.Join(", ", supportedExtensions)}" ); } var filename = $"{id}{extension}"; var filepath = Path.GetFullPath(Path.Combine(dir, id, filename)); if (!File.Exists(filepath)) { - return new ActionFinished( - success: false, - actionReturn: null, - errorMessage: $"Asset fiile '{filepath}' does not exist." - ); + throw new FileNotFoundException($"Asset file '{filepath}' does not exist."); } // to support different @@ -7716,43 +7704,46 @@ public ActionFinished CreateRuntimeAsset( ObjectCreationHandling = ObjectCreationHandling.Replace }; var json = reader.ReadToEnd(); - // procAsset = Newtonsoft.Json.JsonConvert.DeserializeObject(reader.ReadToEnd(), serializer); procAsset = JsonConvert.DeserializeObject(json); } else { - return new ActionFinished( - success: false, - errorMessage: $"Unexpected error with extension `{extension}`, filepath: `{filepath}`, compression stages: {string.Join(".", presentStages)}. Only supported: {string.Join(", ", supportedExtensions)}", - actionReturn: null + throw new ArgumentException( + $"Unexpected error with extension `{extension}`, filepath: `{filepath}`, compression stages: {string.Join(".", presentStages)}. Only supported: {string.Join(", ", supportedExtensions)}" ); } procAsset.parentTexturesDir = Path.Combine(dir, id); - var assetData = ProceduralTools.CreateAsset( - procAsset.vertices, - procAsset.normals, - procAsset.name, - procAsset.triangles, - procAsset.uvs, - procAsset.albedoTexturePath, - procAsset.metallicSmoothnessTexturePath, - procAsset.normalTexturePath, - procAsset.emissionTexturePath, - procAsset.colliders, - procAsset.physicalProperties, - procAsset.visibilityPoints, - procAsset.annotations ?? annotations, - procAsset.receptacleCandidate, - procAsset.yRotOffset, - returnObject: true, - serializable: serializable, - parent: null, - addAnotationComponent: false, - parentTexturesDir: procAsset.parentTexturesDir - ); + return procAsset; + } + + public ActionFinished CreateRuntimeAsset( + string id, + string dir, + string extension = ".msgpack.gz", + ObjectAnnotations annotations = null, + bool serializable = false + ) { + var assetDb = GameObject.FindObjectOfType(); + if (assetDb.ContainsAssetKey(id)) { + return new ActionFinished( + success: false, + errorMessage: $"'{id}' already exists in ProceduralAssetDatabase, trying to create procedural object twice, call `SpawnAsset` instead.", + toEmitState: true + ); + } + + var procAsset = LoadAssetAsync( + id: id, + dir: dir, + extension: extension, + annotations: annotations, + serializable: serializable + ).Result; + procAsset.serializable = serializable; + procAsset.annotations = procAsset.annotations ?? annotations; // Debug.Log($"root is null? {parent == null} - {parent}"); - return new ActionFinished(success: true, actionReturn: assetData); + return CreateRuntimeAsset(asset: procAsset, returnObject: true); } public class UnityLoadableAsset { @@ -7767,18 +7758,51 @@ public ActionFinished CreateRuntimeAssets( List assets, string dir = null ) { - foreach (var asset in assets) { - var actionFinished = CreateRuntimeAsset( - id: asset.id, - dir: dir ?? asset.dir, - extension: asset.extension, - annotations: asset.annotations + try { +#if UNITY_EDITOR + Diagnostics.Stopwatch stopWatch = new Diagnostics.Stopwatch(); + stopWatch.Start(); +#endif + // Load assets in parallel + var loadTasks = assets + .Select(asset => + LoadAssetAsync( + id: asset.id, + dir: dir ?? asset.dir, + extension: asset.extension, + annotations: asset.annotations + ) + ) + .ToList(); + Task.WhenAll(loadTasks).Wait(); + + var loadedAssets = loadTasks.Select(t => t.Result).ToList(); + +#if UNITY_EDITOR + stopWatch.Stop(); + Debug.Log( + $"LoadAssetAsync took {stopWatch.ElapsedMilliseconds} ms, per asset time {stopWatch.ElapsedMilliseconds / assets.Count} ms" ); - if (!actionFinished.success) { - return actionFinished; + stopWatch.Restart(); +#endif + // Create assets serially + foreach (var (asset, procAsset) in assets.Zip(loadedAssets, (a, p) => (a, p))) { + var actionFinished = CreateRuntimeAsset(asset: procAsset); + if (!actionFinished.success) { + return actionFinished; + } } +#if UNITY_EDITOR + stopWatch.Stop(); + Debug.Log( + $"CreateRuntimeAsset loop took {stopWatch.ElapsedMilliseconds} ms, per asset time {stopWatch.ElapsedMilliseconds / assets.Count} ms" + ); +#endif + + return ActionFinished.Success; + } catch (Exception ex) { + return new ActionFinished(success: false, errorMessage: ex.Message); } - return ActionFinished.Success; } public void GetStreamingAssetsPath() { diff --git a/unity/Assets/Scripts/UtilityFunctions.cs b/unity/Assets/Scripts/UtilityFunctions.cs index 10d14ea27d..c9c0157a62 100644 --- a/unity/Assets/Scripts/UtilityFunctions.cs +++ b/unity/Assets/Scripts/UtilityFunctions.cs @@ -484,7 +484,11 @@ public static List GetLightPropertiesOfScene() { return allOfTheLights; } - public static bool ArePositionsApproximatelyEqual(Vector3 position1, Vector3 position2, float epsilon = Vector3.kEpsilon) { + public static bool ArePositionsApproximatelyEqual( + Vector3 position1, + Vector3 position2, + float epsilon = Vector3.kEpsilon + ) { // Compare each component (x, y, z) of the two positions to see if they are approximately equal via the epsilon value return Mathf.Abs(position1.x - position2.x) < epsilon && Mathf.Abs(position1.y - position2.y) < epsilon