Skip to content

Commit

Permalink
Merge pull request #1244 from allenai/web-assets-main
Browse files Browse the repository at this point in the history
Web Assets
  • Loading branch information
AlvaroHG authored Nov 1, 2024
2 parents 1a0a95c + 9cdbe76 commit d9ff8cf
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@
controller.step to locally run some local code
"""
import concurrent.futures
import logging
import os
import warnings
import pathlib
import tarfile
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from tempfile import TemporaryDirectory
from typing import Dict, Any, List, TYPE_CHECKING, Sequence

from typing import Dict, Any, List
import requests
import tqdm
from filelock import FileLock

if TYPE_CHECKING:
from ai2thor.controller import Controller

from objathor.asset_conversion.util import (
get_existing_thor_asset_file_path,
Expand All @@ -24,7 +35,8 @@

from objathor.dataset import load_assets_path, DatasetSaveConfig

logger = logging.getLogger(__name__)
logger = logging.getLogger(os.path.basename(__file__))
logger.setLevel(logging.INFO)

EXTENSIONS_LOADABLE_IN_UNITY = {
".json",
Expand All @@ -49,12 +61,12 @@ def get_all_asset_ids_recursively(objects: List[Dict[str, Any]], asset_ids: List


def create_asset(
thor_controller,
asset_id,
asset_directory,
thor_controller: "Controller",
asset_id: str,
asset_directory: str,
copy_to_dir=None,
verbose=False,
load_file_in_unity=False,
load_file_in_unity=True,
extension=None,
raise_for_failure=True,
fail_if_not_unity_loadable=False,
Expand All @@ -78,7 +90,7 @@ def create_assets(
assets_dir: str,
copy_to_dir=None,
verbose=False,
load_file_in_unity=False,
load_file_in_unity=True,
extension=None,
fail_if_not_unity_loadable=False,
raise_for_failure=True,
Expand Down Expand Up @@ -253,10 +265,10 @@ def create_assets_if_not_exist(
class ProceduralAssetActionCallback:
def __init__(
self,
asset_directory,
asset_directory: str,
target_dir="processed_models",
asset_symlink=True,
load_file_in_unity=False,
load_file_in_unity=True,
stop_if_fail=False,
asset_limit=-1,
extension=None,
Expand Down Expand Up @@ -373,3 +385,182 @@ def CreateHouse(self, action, controller):
extension=self.extension,
verbose=self.verbose,
)


def download_with_progress_bar(save_path: str, url: str, verbose: bool = False):
    """Stream the content found at ``url`` into the file ``save_path``.

    The parent directory of ``save_path`` is created if needed. The output
    file is only opened once the response has been validated, so a rejected
    download does not leave an empty file behind.

    Args:
        save_path: Destination file path.
        url: URL to download.
        verbose: If True, print the destination and show a ``tqdm`` progress
            bar while downloading.

    Raises:
        ValueError: If the server answers with a ``text/html`` content type,
            which is treated as an invalid asset URL.
        requests.HTTPError: If the server answers with an error status code.
    """
    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(save_dir, exist_ok=True)

    if verbose:
        print(f"Downloading to {save_path}")

    # Using the response as a context manager releases the connection even
    # when validation or writing fails.
    with requests.get(url, stream=True) as response:
        content_type = response.headers.get("content-type")
        if content_type is not None and content_type.startswith("text/html"):
            raise ValueError(f"Invalid URL: {url}")

        # Without this check a non-HTML error body (e.g. an XML/JSON 404)
        # would be written to disk as if it were the asset.
        response.raise_for_status()

        total_length = response.headers.get("content-length")
        progress = (
            tqdm.tqdm(
                # total=None (no content-length header) makes tqdm show a
                # running byte count instead of a bar.
                total=int(total_length) if total_length is not None else None,
                unit="B",
                unit_scale=True,
                desc=f"Downloading asset {url}",
            )
            if verbose
            else nullcontext()
        )

        with open(save_path, "wb") as f, progress as pbar:
            # Always stream in chunks so large assets are never held fully in
            # memory, regardless of whether content-length was provided.
            for data in response.iter_content(chunk_size=4096):
                f.write(data)
                if pbar is not None:
                    pbar.update(len(data))


def download_missing_asset(
    asset_id: str,
    asset_directory: str,
    base_url: str,
    verbose: bool = False,
) -> str:
    """Download and unpack ``<asset_id>.tar`` unless the asset is already present.

    The asset counts as present when ``<asset_directory>/<asset_id>`` exists
    and contains an entry whose name includes ``"<asset_id>."``. Otherwise the
    archive at ``<base_url>/<asset_id>.tar`` is downloaded into a temporary
    directory and extracted into ``asset_directory`` while holding a file lock.

    Args:
        asset_id: Identifier of the asset; also the archive's base name.
        asset_directory: Directory into which the archive is extracted.
        base_url: URL prefix under which ``<asset_id>.tar`` is hosted.
        verbose: Forwarded to ``download_with_progress_bar``.

    Returns:
        The path ``<asset_directory>/<asset_id>``.

    Raises:
        ValueError: If an archive member would be written outside of
            ``asset_directory`` (or if the download itself is rejected).
    """
    final_save_dir = os.path.join(asset_directory, asset_id)

    if os.path.exists(final_save_dir):
        if any(f"{asset_id}." in p for p in os.listdir(final_save_dir)):
            return final_save_dir
        else:
            print(
                f"Directory {final_save_dir} exists but could not find"
                f" asset {asset_id} in it. Will attempt to redownload."
            )

    url = f"{base_url.strip('/')}/{asset_id}.tar"

    lock_path = os.path.join(os.path.expanduser("~"), ".ai2thor", "asset_extraction.lock")
    extraction_root = os.path.abspath(asset_directory)

    with TemporaryDirectory() as td_name:
        save_path = os.path.join(td_name, f"{asset_id}.tar")

        download_with_progress_bar(save_path=save_path, url=url, verbose=verbose)

        os.makedirs(asset_directory, exist_ok=True)
        # The lock file's directory may not exist yet on a fresh machine.
        os.makedirs(os.path.dirname(lock_path), exist_ok=True)

        with FileLock(lock_path):
            with tarfile.open(save_path, "r") as tar:
                # Here we sort the members so that the <asset_id>.<extension> file is last
                # to ensure that the object file is the last thing to be saved to the final
                # location. We do this because we check for the existence of the
                # <asset_id>.<extension> file to determine if the asset has been
                # successfully downloaded previously and we want to avoid partial downloads.
                members = sorted(tar.getmembers(), key=lambda x: f"{asset_id}." in x.name)

                # The archive comes from the network: refuse members whose name
                # (absolute path or ".." components) points outside of the
                # extraction directory. This checks names only, not link targets.
                for member in members:
                    destination = os.path.abspath(os.path.join(extraction_root, member.name))
                    if os.path.commonpath([extraction_root, destination]) != extraction_root:
                        raise ValueError(
                            f"Refusing to extract {member.name!r} from {url}:"
                            f" it would be written outside of {asset_directory}."
                        )

                # Extract the files one by one, keeping the directory structure.
                for member in members:
                    tar.extract(member=member, path=asset_directory)

    return final_save_dir


def wait_for_futures_and_raise_errors(
    futures: Sequence[concurrent.futures.Future],
) -> Sequence[Any]:
    """Block until every future has finished and return their results in order.

    Results are collected in the order of ``futures``. The first future (in
    that order) that finished with an exception has it re-raised here.
    """
    concurrent.futures.wait(futures)
    # ``Future.result()`` re-raises whatever exception the task ended with.
    return [finished.result() for finished in futures]


def download_missing_assets(
    asset_ids: Sequence[str],
    asset_directory: str,
    base_url: str,
    verbose: bool = True,
    threads: int = 1,
):
    """Fetch every asset in ``asset_ids`` using a pool of ``threads`` workers.

    Duplicate ids are dropped and the remainder processed in sorted order,
    one ``download_missing_asset`` call per id. Progress bars are only
    requested when running with a single thread. Any exception raised by a
    download is re-raised once all downloads have finished.
    """
    if verbose and threads > 1:
        print(f"Downloading assets with {threads} threads. Will NOT log progress bars.")

    unique_ids = sorted(set(asset_ids))
    show_progress = verbose and (threads == 1)

    with ThreadPoolExecutor(max_workers=threads) as pool:
        pending = []
        for unique_id in unique_ids:
            pending.append(
                pool.submit(
                    download_missing_asset,
                    asset_id=unique_id,
                    asset_directory=asset_directory,
                    base_url=base_url,
                    verbose=show_progress,
                )
            )
        wait_for_futures_and_raise_errors(pending)


class WebProceduralAssetActionCallback(ProceduralAssetActionCallback):
    """Procedural asset callback that first downloads missing assets from the web.

    Before delegating ``CreateHouse``, ``SpawnAsset`` and
    ``GetHouseFromTemplate`` to the parent class, the assets referenced by the
    action that the controller reports as not being in its database are
    downloaded from ``base_url`` into ``asset_directory``.
    """

    def __init__(
        self,
        asset_directory: str,
        base_url: str,
        target_dir: str,
        asset_symlink=True,
        load_file_in_unity=True,
        stop_if_fail=False,
        asset_limit=-1,
        extension=None,
        verbose=True,
    ):
        """Store ``base_url`` and forward all other arguments to the parent class."""
        super().__init__(
            asset_directory=asset_directory,
            target_dir=target_dir,
            asset_symlink=asset_symlink,
            load_file_in_unity=load_file_in_unity,
            stop_if_fail=stop_if_fail,
            asset_limit=asset_limit,
            extension=extension,
            verbose=verbose,
        )
        self.base_url = base_url

    def _download_missing_assets(self, controller: "Controller", asset_ids: Sequence[str]):
        """Download the assets among ``asset_ids`` that the controller does not have.

        The controller's ``AssetsInDatabase`` action is queried and its
        ``actionReturn`` is read as a mapping from asset id to a flag telling
        whether the asset is already in the database.
        """
        asset_in_db = controller.step(
            action="AssetsInDatabase", assetIds=asset_ids, updateProceduralLRUCache=False
        ).metadata["actionReturn"]
        assets_not_created = [asset_id for (asset_id, in_db) in asset_in_db.items() if not in_db]
        if not assets_not_created:
            # Nothing to fetch; avoid spinning up a download executor.
            return
        download_missing_assets(
            asset_ids=assets_not_created,
            asset_directory=self.asset_directory,
            base_url=self.base_url,
            # Honor the verbosity this callback was configured with instead of
            # always falling back to the downloader's default.
            verbose=self.verbose,
        )

    def Initialize(self, action, controller):
        """Trim the procedural cache to ``asset_limit`` entries when a limit is set."""
        if self.asset_limit > 0:
            return controller.step(
                action="DeleteLRUFromProceduralCache", assetLimit=self.asset_limit
            )

    def CreateHouse(self, action: Dict[str, Any], controller: "Controller"):
        """Download the house's missing assets, then defer to the parent class."""
        house = action["house"]
        asset_ids = get_all_asset_ids_recursively(house["objects"], [])
        self._download_missing_assets(controller=controller, asset_ids=asset_ids)

        return super().CreateHouse(action=action, controller=controller)

    def SpawnAsset(self, action, controller):
        """Download the requested asset if missing, then defer to the parent class."""
        self._download_missing_assets(controller=controller, asset_ids=[action["assetId"]])

        return super().SpawnAsset(action=action, controller=controller)

    def GetHouseFromTemplate(self, action, controller):
        """Download the template's missing assets, then defer to the parent class."""
        template = action["template"]
        asset_ids = get_all_asset_ids_recursively(list(template["objects"].values()), [])
        self._download_missing_assets(controller=controller, asset_ids=asset_ids)

        # Propagate the parent's result, matching CreateHouse and SpawnAsset.
        return super().GetHouseFromTemplate(action=action, controller=controller)
6 changes: 3 additions & 3 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4724,10 +4724,10 @@ def test_create_prefab(ctx, json_path):


@task
def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
def procedural_asset_callback_test(ctx, asset_dir, house_path, asset_id=""):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback
from ai2thor.hooks.procedural_asset_callback import ProceduralAssetActionCallback
from objathor.asset_conversion.util import view_asset_in_thor

hook_runner = ProceduralAssetActionCallback(
Expand Down Expand Up @@ -4817,7 +4817,7 @@ def procedural_asset_hook_test(ctx, asset_dir, house_path, asset_id=""):
def procedural_asset_cache_test(ctx, asset_dir, house_path, asset_ids="", cache_limit=1):
import json
import ai2thor.controller
from ai2thor.hooks.procedural_asset_hook import ProceduralAssetActionCallback
from ai2thor.hooks.procedural_asset_callback import ProceduralAssetActionCallback

hook_runner = ProceduralAssetActionCallback(
asset_directory=asset_dir, asset_symlink=True, verbose=True, asset_limit=1
Expand Down
Loading

0 comments on commit d9ff8cf

Please sign in to comment.