From da4b0865ecad072db606541daedfb81ac6e4a605 Mon Sep 17 00:00:00 2001 From: ashawkey Date: Fri, 1 Dec 2023 18:38:45 +0800 Subject: [PATCH] add kiui.seed_everything --- .gitignore | 2 +- kiui/env.py | 17 +++++++++++----- kiui/utils.py | 49 ++++++++++++++++++++++++++++++++++++++++++---- setup.py | 2 +- tests/test_seed.py | 24 +++++++++++++++++++++++ 5 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 tests/test_seed.py diff --git a/.gitignore b/.gitignore index d2d5b6a..82c63ae 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ __pycache__/ build/ dist/ -tmp_* \ No newline at end of file +tmp* \ No newline at end of file diff --git a/kiui/env.py b/kiui/env.py index f8ff84f..ec182d3 100644 --- a/kiui/env.py +++ b/kiui/env.py @@ -59,7 +59,8 @@ def retrieve_globals(verbose=False): if "kiui" in g: G = g if verbose: - print(f"[KiuiKit-INFO] located global frame at {frame_id}") + print(f"[INFO] located global frame at {frame_id}") + # print(G) break frame_id += 1 if G is None: @@ -67,6 +68,12 @@ def retrieve_globals(verbose=False): "Cannot locate global frame, make sure you called exactly `import kiui`!" ) +def is_imported(target, verbose=False): + + if G is None: + retrieve_globals(verbose) + + return target in G def try_import(target, sources, verbose=False): @@ -75,7 +82,7 @@ def try_import(target, sources, verbose=False): if target in G: if verbose: - print(f"[KiuiKit-INFO] {target} is already present, skipped.") + print(f"[INFO] {target} is already present, skipped.") return if not isinstance(sources, list): @@ -84,7 +91,7 @@ def try_import(target, sources, verbose=False): for source in sources: try: if verbose: - print(f"[KiuiKit-INFO] try to import {source}") + print(f"[INFO] try to import {source}") # (module, component) or ("module", component) if isinstance(source, tuple): @@ -102,11 +109,11 @@ def try_import(target, sources, verbose=False): G[target] = source if verbose: - print(f"[KiuiKit-INFO] succeed to import {source} as {target}") + print(f"[INFO] succeed to import {source} as {target}") break except ImportError as e: - print(f"[KiuiKit-WARN] failed to import {source} as {target}: {str(e)}") + print(f"[WARN] failed to import {source} as {target}: {str(e)}") def import_libs(pack, verbose=False): diff --git a/kiui/utils.py b/kiui/utils.py index 70efbe5..b09c5d3 100644 --- a/kiui/utils.py +++ b/kiui/utils.py @@ -1,17 +1,24 @@ import os +import sys import glob import tqdm -import cv2 import json import pickle import varname +from objprint import objstr +from rich.console import Console + +import cv2 from PIL import Image -import torch import numpy as np -from objprint import objstr +import torch -from rich.console import Console +from kiui.env import is_imported + +''' utils +All functions will be automatically imported as kiui. +''' # inspect array like object x and report stats def lo(*xs, verbose=0): @@ -72,6 +79,40 @@ def _lo(x, name): _lo(x, name) +def seed_everything(seed=42, verbose=False, strict=False): + + os.environ['PYTHONHASHSEED'] = str(seed) + + if is_imported('random'): + import random # still need to import it here + random.seed(seed) + if verbose: print(f'[INFO] set random.seed = {seed}') + else: + if verbose: print(f'[INFO] random not imported, skip setting seed') + + # assume numpy is imported as np + if is_imported('np'): + import numpy as np + np.random.seed(seed) + if verbose: print(f'[INFO] set np.random.seed = {seed}') + else: + if verbose: print(f'[INFO] numpy not imported, skip setting seed') + + if is_imported('torch'): + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + if verbose: print(f'[INFO] set torch.manual_seed = {seed}') + + if strict: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + if verbose: print(f'[INFO] set strict deterministic mode for torch.') + else: + if verbose: print(f'[INFO] torch not imported, skip setting seed') + + def read_json(path): with open(path, "r") as f: return json.load(f) diff --git a/setup.py b/setup.py index 745dbff..fa5819d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ if __name__ == "__main__": setup( name="kiui", - version="0.1.10", + version="0.1.11", description="A toolkit for 3D vision", long_description=open("README.md", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 0000000..715f806 --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,24 @@ +import kiui + +kiui.seed_everything(42, True) + +import random +kiui.seed_everything(42, True) +a = random.random() +kiui.seed_everything(42, True) +b = random.random() +assert a == b + +import numpy as np +kiui.seed_everything(42, True) +a = np.random.randn(10) +kiui.seed_everything(42, True) +b = np.random.randn(10) +assert np.allclose(a, b) + +import torch +kiui.seed_everything(42, True) +a = torch.randn(10) +kiui.seed_everything(42, True) +b = torch.randn(10) +assert torch.allclose(a, b) \ No newline at end of file