Skip to content

Commit

Permalink
sprites fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Jun 9, 2022
1 parent 9510e66 commit fe8b66d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ on the GPU or CPU at different points in the pipeline.
<p align="center"><img height="192" src="docs/img/traversals/traversal-transpose__shapes3d.jpg" alt="Shapes3D Dataset Factor Traversals"></p>
</details>

+ <details>
<summary>🏹 <a href="https://github.com/YingzhenLi/Sprites" target="_blank">Sprites (custom)</a></summary>
<p align="center"><img height="192" src="docs/img/traversals/traversal-transpose__sprites.jpg" alt="Sprites (Custom) Dataset Factor Traversals"></p>
</details>

+ <details open>
<summary>
🧵 <u>dSpritesImagenet</u>:
Expand Down
15 changes: 13 additions & 2 deletions disent/dataset/data/_groundtruth__sprites.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,22 @@
# ========================================================================= #


SPRITES_REPO = 'https://github.com/YingzhenLi/Sprites'
SPRITES_REPO_COMMIT_SHA = '3ce4048c5227802bd8f1888e293fd3afdba91c0c'


def fetch_sprite_components() -> Tuple[np.array, np.array]:
try:
import git
except ImportError:
logging.error('GitPython not found! Please install it: `pip install GitPython`')
log.error('GitPython not found! Please install it: `pip install GitPython`')
exit(1)
# store files in a temporary directory
with TemporaryDirectory(suffix='sprites') as temp_dir:
# clone the files into the temp dir
git.Repo.clone_from('https://github.com/YingzhenLi/Sprites', temp_dir)
log.info(f'Generating sprites data, temporarily cloning: {SPRITES_REPO} to {temp_dir}`')
repo = git.Repo.clone_from(SPRITES_REPO, temp_dir, no_checkout=True)
repo.git.checkout(SPRITES_REPO_COMMIT_SHA)
# get all the components!
component_sheets: List[np.ndarray] = []
component_names = ['bottomwear', 'topwear', 'hair', 'eyes', 'shoes', 'body']
Expand Down Expand Up @@ -111,6 +117,11 @@ def _prepare(self, out_dir: str, out_file: str) -> NoReturn:


class SpritesAllData(DiskGroundTruthData):
"""
Custom version of sprites, with the data obtained from:
https://github.com/YingzhenLi/Sprites
"""

name = 'sprites'
factor_names = ('bottomwear', 'topwear', 'hair', 'eyes', 'shoes', 'body', 'action', 'rotation', 'frame')
factor_sizes = (7, 7, 10, 5, 3, 7, 5, 4, 6) # 6_174_000
Expand Down
1 change: 1 addition & 0 deletions disent/dataset/util/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def main(progress=True, num_workers=32, batch_size=2048): # try changing worker
# mean: [0.046146392822265625]
# std: [0.2096506119375896]

# TODO -- REGENERATE:
# Mpi3dData - mpi3d_toy - {'subset': 'toy', 'in_memory': True}:
# mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702]
# std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426]
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit fe8b66d

Please sign in to comment.