From 06f02c6b253095760090f00213d065e96a679e3f Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 16 Dec 2023 18:12:52 -0500 Subject: [PATCH] [zeta Module CLEAN UP OPERATIO] --- zeta/__init__.py | 49 ++++++++++------------------------- zeta/ops/__Init__.py | 3 --- zeta/utils/__init__.py | 3 ++- zeta/utils/disable_logging.py | 31 ++++++++++++++++++++++ 4 files changed, 46 insertions(+), 40 deletions(-) create mode 100644 zeta/utils/disable_logging.py diff --git a/zeta/__init__.py b/zeta/__init__.py index 5fbcfce8..31ae3141 100644 --- a/zeta/__init__.py +++ b/zeta/__init__.py @@ -1,36 +1,13 @@ -import logging -import os -import warnings - -# disable warnings - -warnings.filterwarnings("ignore") - -# disable tensorflow warnings - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" - -# disable bnb warnings and others - -logging.getLogger().setLevel(logging.WARNING) - - -class CustomFilter(logging.Filter): - def filter(self, record): - msg = "Created a temporary directory at" - return msg not in record.getMessage() - - -logger = logging.getLogger() -f = CustomFilter() -logger.addFilter(f) - -from zeta.nn import * -from zeta.models import * -from zeta.utils import * -from zeta.training import * -from zeta.tokenizers import * -from zeta.rl import * -from zeta.optim import * -from zeta.ops import * -from zeta.quant import * +from zeta.utils.disable_logging import disable_warnings_and_logs + +disable_warnings_and_logs() + +from zeta.nn import * # noqa: F403, E402 +from zeta.models import * # noqa: F403, E402 +from zeta.utils import * # noqa: F403, E402 +from zeta.training import * # noqa: F403, E402 +from zeta.tokenizers import * # noqa: F403, E402 +from zeta.rl import * # noqa: F403, E402 +from zeta.optim import * # noqa: F403, E402 +from zeta.ops import * # noqa: F403, E402 +from zeta.quant import * # noqa: F403, E402 diff --git a/zeta/ops/__Init__.py b/zeta/ops/__Init__.py index 0597d52f..e8310817 100644 --- a/zeta/ops/__Init__.py +++ b/zeta/ops/__Init__.py @@ -1,7 +1,4 @@ -from zeta.ops.main import * -from zeta.ops.softmax import * from zeta.ops.unitwise_norm import unitwise_norm -from zeta.ops.mos import MixtureOfSoftmaxes from zeta.ops.softmax import ( standard_softmax, diff --git a/zeta/utils/__init__.py b/zeta/utils/__init__.py index 2edf7a54..1e2293a7 100644 --- a/zeta/utils/__init__.py +++ b/zeta/utils/__init__.py @@ -7,11 +7,12 @@ print_cuda_memory_usage, save_memory_snapshot, ) - +from zeta.utils.disable_logging import disable_warnings_and_logs __all__ = [ "track_cuda_memory_usage", "benchmark", "print_cuda_memory_usage", "save_memory_snapshot", + "disable_warnings_and_logs", ] diff --git a/zeta/utils/disable_logging.py b/zeta/utils/disable_logging.py new file mode 100644 index 00000000..c4bcc12c --- /dev/null +++ b/zeta/utils/disable_logging.py @@ -0,0 +1,31 @@ +import logging +import os +import warnings + + +def disable_warnings_and_logs(): + """Disable warnings and logs. + + Returns: + _type_: _description_ + """ + # disable warnings + warnings.filterwarnings("ignore") + + # disable tensorflow warnings + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + + # disable bnb warnings and others + logging.getLogger().setLevel(logging.WARNING) + + class CustomFilter(logging.Filter): + def filter(self, record): + msg = "Created a temporary directory at" + return msg not in record.getMessage() + + logger = logging.getLogger() + f = CustomFilter() + logger.addFilter(f) + + +disable_warnings_and_logs()