diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..3107e4c --- /dev/null +++ b/.coveragerc @@ -0,0 +1,24 @@ +[html] +directory = coverage + +[run] +source = + instageo +omit = + */__init__.py + *_test.py + +[report] +omit = + __init__.py + *_test.py + *app.py + +exclude_lines = + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: diff --git a/.github/workflows/tests_and_linters.yaml b/.github/workflows/tests_and_linters.yaml new file mode 100644 index 0000000..49255d0 --- /dev/null +++ b/.github/workflows/tests_and_linters.yaml @@ -0,0 +1,37 @@ +name: Tests and Linters 🧪 + +on: [push, pull_request] + +jobs: + tests-and-linters: + name: "Python 3.10 on instadeep-ci" + runs-on: instadeep-ci + container: + image: python:3.10 + + steps: + - name: Install dependencies for viewer test + run: apt-get update && apt-get install -y xvfb + - name: Checkout your repo 📦 + uses: actions/checkout@v3 + + - name: Install python dependencies 🔧 + run: | + pip install --upgrade pip + pip install -r requirements.txt \ + -r instageo/model/requirements.txt \ + -r instageo/data/requirements.txt \ + -r instageo/apps/requirements.txt + + - run: git config --system --add safe.directory $GITHUB_WORKSPACE + - name: Run linters 🖌️ + run: pre-commit run --all-files --verbose + - name: Set PYTHONPATH + run: echo "PYTHONPATH=$PYTHONPATH:$(pwd)" >> $GITHUB_ENV + - name: Run tests 🧪 + env: + EARTHDATA_USERNAME: ${{ secrets.EARTHDATA_USERNAME }} + EARTHDATA_PASSWORD: ${{ secrets.EARTHDATA_PASSWORD }} + run: pytest --cov --cov-config=.coveragerc --cov-report=html --cov-report=term-missing --cov-fail-under=50 -m "not auth" +# - name: Test build docs 📖 +# run: mkdocs build --verbose --site-dir docs_public diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f96319a --- /dev/null +++ b/.gitignore @@ -0,0 +1,147 
@@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# PyCharm stuff +.idea/ + +# VSCode stuff +.vscode + +# Docker stuff +.devcontainer + +# MacBook Finder +.DS_Store + +# Experiment results +outputs/ + + +*.pyc +__pycache__/ +*lightning_logs +*prithvi/ +*.ipynb +nul/ +notebooks/ +!notebooks/*.ipynb \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d9dd83f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,88 @@ +# Pre-commit hooks for repo. + +# Packages: +# pre-commit: General pre-commits for formatting. +# black: Python code strict formatting. +# pyupgrade: Upgrade syntax for newer versions of the language. +# isort: Sorts imports. +# flake8: Checks code follows PEP8 standard. +# mypy: Static typing. +# conventional-pre-commit: commit format checker. +# blacken-docs: Checks docs follow black format standard. +# pydocstyle: Checking docstring style. 
+ +default_stages: ["commit", "commit-msg", "push"] +default_language_version: + python: python3.10 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + exclude_types: [image] + - id: check-merge-conflict + - id: debug-statements + - id: mixed-line-ending + - id: check-yaml + + - repo: https://github.com/psf/black + rev: 23.11.0 + hooks: + - id: black + exclude_types: [image] + language_version: python3 + + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + + - repo: https://github.com/timothycrosley/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + args: + [ + "--max-line-length=100", + "--extend-ignore=E203,BLK100", + "--exclude=*tests*", + ] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.7.0 + hooks: + - id: mypy + exclude: ^docs/|test + args: [--config-file=setup.cfg] + + - repo: https://github.com/compilerla/conventional-pre-commit + rev: v3.0.0 + hooks: + - id: conventional-pre-commit + stages: [commit-msg] + + - repo: https://github.com/asottile/blacken-docs + rev: 1.16.0 + hooks: + - id: blacken-docs + additional_dependencies: [black>=22.1.0] + language_version: python3 + + - repo: https://github.com/pycqa/pydocstyle + rev: 6.3.0 + hooks: + - id: pydocstyle + name: Checking docstring style. + args: + [ + "--convention=google", + "--match=^((?!test).)*$", + ] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..41ccd39 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,136 @@ +# Contributing to InstaGeo + +First off, thank you for considering contributing to InstaGeo. It's people like you that make InstaGeo such a great tool. + +### What should I know before I get started? + +- Code of Conduct: InstaGeo adheres to a [Code of Conduct](#code-of-conduct). 
By participating, you are expected to uphold this code. +- Project Structure: Familiarize yourself with the layout of the project's repository. This can help you understand how the project is organized and where to make your changes. + +### Ways to Contribute + +There are many ways to contribute to InstaGeo, from writing tutorials or blog posts, improving the documentation, submitting bug reports and feature requests, or writing code which can be incorporated into InstaGeo itself. + +#### Reporting Bugs + +- Use the issue tracker to report bugs. +- Describe the bug and include additional details to help maintainers reproduce the problem. +- Provide a detailed description of the expected behavior. +- Include any relevant screenshots or error messages. + +#### Suggesting Enhancements + +- Use the issue tracker to suggest feature enhancements. +- Clearly describe the suggestion and include additional documentation or screenshots as necessary. + +**!!** Before opening a new issue, make sure to search for keywords in the issues filtered by the +`"type::"` label and verify the issue you're about to submit isn't a duplicate. + +#### Pull Requests + +- Fork the repository and create your branch from `main`. +- If you've added code that should be tested, add tests. +- Ensure your code adheres to the existing style of the project to increase the chance of your changes being merged directly. +- Write a convincing description of your PR and why we should land it. + + +### Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone. + +## Contributing Code + +All submissions, including submissions by project members, require review. We use GitHub pull +requests for this purpose. 
Consult +[GitHub Help](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) +for more information on using pull requests. + +Before sending your pull request for review, make sure your changes are consistent with the +guidelines and follow the coding style and code of conduct. + +### General guidelines and philosophy for contribution + +- Include unit tests when you contribute new features, as they help to a) prove that your code works + correctly, and b) guard against future breaking changes to lower the maintenance cost. +- When you contribute a new feature to InstaGeo, the maintenance burden is (by default) transferred + to the InstaGeo team. This means that the benefit of the contribution must be compared against the + cost of maintaining the feature. +- Keep API compatibility in mind when you change code. Non-backward-compatible API changes will not + be made if they don't greatly improve the library. +- As every PR requires CI testing, we discourage submitting PRs to fix one typo, one warning, etc. + We recommend fixing the same issue at the file level at least (e.g.: fix all typos in a file, fix + all compiler warnings in a file, etc.) + +### Coding Style + +To guarantee the quality and uniformisation of the code, we use various linters: + +- [Black](https://black.readthedocs.io/en/stable/#) is a deterministic code formatter that is + compliant with PEP8 standards. +- [Isort](https://pycqa.github.io/isort/) sorts imports alphabetically and separates them into + sections. +- [Flake8](https://flake8.pycqa.org/en/latest/) is a library that wraps PyFlakes and PyCodeStyle. It + is a great toolkit for checking your codebase against coding style (PEP8), programming, and syntax + errors. Flake8 also benefits from an ecosystem of plugins developed by the community that extend + its capabilities. 
You can read more about Flake8 plugins on the documentation and find a curated + list of plugins here. +- [MyPy](https://mypy.readthedocs.io/en/stable/#) is a static type checker that can help you detect + inconsistent typing of variables. + +### Pre-Commit + +To help in automating the quality of the code, we use [pre-commit](https://pre-commit.com/), a +framework that manages the installation and execution of git hooks that will be run before every +commit. These hooks help to automatically point out issues in code such as formatting mistakes, +unused variables, trailing whitespace, debug statements, etc. By pointing these issues out before +code review, it allows a code reviewer to focus on the architecture of a change while not wasting +time with trivial style nitpicks. Each commit should be preceded by a call to pre-commit to ensure +code quality and formatting. The configuration is in .pre-commit-config.yaml and includes Black, +Flake8, MyPy and checks for the yaml formatting, trimming trailing whitespace, etc. Try running: +`pre-commit run --all-files`. All linters must pass before committing your change. + +### Docstrings + +Public modules, functions, classes, and methods must be documented using Python docstrings. These +docstrings must have sections for Arguments, Returns, and Raises (if applicable). For every argument +of a function, the docstring must explain precisely what the argument does, what data type it +expects, whether or not it is optional, and any requirements for the range of values it expects. The +same goes for the returns. Use existing docstrings as templates. Non-public functions and methods +must also be documented for defining the API contract. In addition to being useful for generating +documentation, docstrings add clarity when looking through the source code or using the built-in +help system, and can be leveraged in autocompletion by IDEs. 
+ +Please see [PEP 257](https://peps.python.org/pep-0257/) for details on semantics and conventions +associated with Python docstrings. Docstrings must follow +[Google style docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). +This docstring style is more pleasant to read when browsing the source. + +If you are using Visual Studio Code, you can install the +[autoDocstring](https://marketplace.visualstudio.com/items?itemName=njpwerner.autodocstring) +extension to quickly generate docstring snippets. + +### Code of Conduct + +Examples of behavior that contributes to creating a positive environment include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +### Testing + +Please make sure that your PR passes all tests by running +[pytest](https://docs.pytest.org/en/latest/) on your local machine. Also, you can run only tests +that are affected by your code changes, but you will need to select them manually. + +### Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License Agreement. You (or your +employer) retain the copyright to your contribution, this simply gives us permission to use and +redistribute your contributions as part of the project. Head over to +https://cla-assistant.io/instadeepai/ to see your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one (even if it was for +a different project), you probably don't need to do it again. 
diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cbe5ad1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. 
More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. 
Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. 
The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. 
You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. 
+ + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. 
The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. 
Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/README.md b/README.md new file mode 100644 index 0000000..7925cd3 --- /dev/null +++ b/README.md @@ -0,0 +1,185 @@ + + + Logo + + +## Overview + +InstaGeo is a geospatial deep learning Python package designed to facilitate geospatial machine learning using satellite imagery data from [Harmonized Landsat and Sentinel-2 (HLS)](https://hls.gsfc.nasa.gov/) Data Product and [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M) geospatial foundational model. It consists of three core components: Data, Model, and Apps, each tailored to support various aspects of geospatial data retrieval, manipulation, preprocessing, model training, and inference serving. + +### Components + +1. [**Data**](./instageo/data/README.md): Focuses on retrieving, manipulating, and processing Harmonized Landsat Sentinel-2 (HLS) data for classification and segmentation tasks such as disaster mapping, crop classification, and breeding ground prediction. +2. [**Model**](./instageo/model/README.md): Centers around data loading, training, and evaluating models, particularly leveraging the Prithvi model for various modeling tasks. It includes a sliding-window feature that allows inference to be run on large inputs. +3. 
[**Apps**](./instageo/apps/README.md): Aims to operationalize models developed in the Model component for practical applications. + +## Installation + +To get started with InstaGeo, ensure you have Python installed on your system. Then, execute the following command in your terminal or command prompt to install InstaGeo: + +```bash +pip install instageo +``` +This command will download and install the latest version of InstaGeo along with its required dependencies. + +## Running Tests +After installation, you may want to verify that InstaGeo has been correctly installed and is functioning as expected. To do this, run the included test suite with the following commands: + +```bash +export PYTHONPATH=$PYTHONPATH:$(pwd) +pytest --verbose . +``` +## Usage + +### Data Component + +- **HLS Data Retrieval**: InstaGeo efficiently searches for and downloads Sentinel-2 and Landsat 8/9 multi-spectral earth observation images from the HLS data product. + +- **Create Chips and Segmentation Maps**: InstaGeo breaks down large satellite image tiles into smaller, manageable patches (referred to as "chips") suitable for deep learning model training. It also generates segmentation maps, which serve as targets for training, by categorizing each pixel in the chips. + +### Model Component + +- **Training Custom Models**: Utilize the Prithvi geospatial foundational model as a backbone to develop custom models tailored for precise geospatial applications. These applications include, but are not limited to, flood mapping for emergency response planning, crop classification for agricultural management, and locust breeding ground prediction to address food security. + +- **Inference on Large-scale Geospatial Data**: Perform inference using the models that have been trained on 'chips' (typically measuring 224 x 224 pixels) on expansive HLS tiles, which measure 3660 x 3660 pixels. 
+ +### Apps Component + +- **Operationalize Models**: Once data has been created and a model trained, deploy the model using the Apps component. HLS tile predictions can be overlaid and visualized on interactive maps. + +### Putting It All Together - Locust Breeding Ground Prediction +See [InstaGeo_Demo](notebooks/InstaGeo_Demo.ipynb) notebook for an end-to-end demo. +#### Download locust breeding ground observation records. +```bash +mkdir locust_breeding +gsutil -m cp -r gs://instageo/data/locust_breeding/records locust_breeding +``` + +#### Download HLS tiles, create chips and segmentation maps + +**Note: Ensure that you have up to **1.5TB** free disk space** + +Create output directory for each split +```bash +mkdir locust_breeding/train locust_breeding/val locust_breeding/test +``` + +- Train Split +```bash +python -m "instageo.data.chip_creator" \ + --dataframe_path="locust_breeding/records/train.csv" \ + --output_directory="locust_breeding/train" \ + --min_count=1 \ + --chip_size=224 \ + --no_data_value=-1 \ + --temporal_tolerance=3 \ + --temporal_step=30 \ + --mask_cloud=False \ + --num_steps=3 +``` + +- Validation Split +```bash +python -m "instageo.data.chip_creator" \ + --dataframe_path="locust_breeding/records/val.csv" \ + --output_directory="locust_breeding/val" \ + --min_count=1 \ + --chip_size=224 \ + --no_data_value=-1 \ + --temporal_tolerance=3 \ + --temporal_step=30 \ + --mask_cloud=False \ + --num_steps=3 +``` + +- Test Split +```bash +python -m "instageo.data.chip_creator" \ + --dataframe_path="locust_breeding/records/test.csv" \ + --output_directory="locust_breeding/test" \ + --min_count=1 \ + --chip_size=224 \ + --no_data_value=-1 \ + --temporal_tolerance=3 \ + --temporal_step=30 \ + --mask_cloud=False \ + --num_steps=3 +``` + +#### Launch Training + +Before launching training, modify the path to chips and segmentation maps in each split +```python +for split in ["train", "val", "test"]: + root_dir = "locust_breeding" + chips = [ + 
chip.replace("chip", f"{split}/chips/chip") + for chip in os.listdir(os.path.join(root_dir, f"{split}/chips")) + ] + seg_maps = [ + chip.replace("chip", f"{split}/seg_maps/seg_map") for chip in chips + ] + + df = pd.DataFrame({"Input": chips, "Label": seg_maps}) + df.to_csv(os.path.join(root_dir, f"{split}.csv")) +``` + +```bash +python -m instageo.model.run --config-name=locust \ + root_dir='locust_breeding' \ + train_filepath="locust_breeding/train.csv" \ + valid_filepath="locust_breeding/val.csv" +``` + +#### Run Evaluation +```bash +python -m instageo.model.run --config-name=locust \ + root_dir='locust_breeding' \ + test_filepath="locust_breeding/test.csv" \ + train.batch_size=16 \ + checkpoint_path='instageo-data/outputs/2024-03-01/09-16-30/instageo_epoch-10-val_iou-0.70.ckpt' \ + mode=eval +``` + +#### Run Inference on Africa Continent +- Download HLS tiles +```bash +python -m "instageo.data.chip_creator" \ + --dataframe_path="gs://instageo/utils/africa_prediction_template.csv" \ + --output_directory="inference/2023-01" \ + --min_count=1 \ + --no_data_value=-1 \ + --temporal_tolerance=3 \ + --temporal_step=30 \ + --num_steps=3 \ + --download_only +``` + +- Inference +```bash +python -m instageo.model.run --config-name=locust \ + root_dir='inference/2023-01' \ + test_filepath='hls_dataset.json' \ + train.batch_size=16 \ + checkpoint_path='instageo-data/outputs/2024-03-01/09-16-30/instageo_epoch-10-val_iou-0.70.ckpt' \ + mode=predict +``` + +#### Visualize Predictions +- Run InstaGeo Serve +```bash +cd instageo/apps +streamlit run app.py +``` +- Specify the directory containing the predictions. + +![InstaGeo Serve](assets/instageo_serve.png) + + +## Contributing + +We welcome contributions to InstaGeo. Please follow the [contribution guidelines](./CONTRIBUTING.md) for submitting pull requests and reporting issues to help us improve the package. + +## License + +This project is licensed under the [CC BY-NC-SA 4.0](LICENSE). 
diff --git a/assets/instageo_serve.png b/assets/instageo_serve.png new file mode 100644 index 0000000..601735a Binary files /dev/null and b/assets/instageo_serve.png differ diff --git a/assets/logo-dark.png b/assets/logo-dark.png new file mode 100644 index 0000000..119de29 Binary files /dev/null and b/assets/logo-dark.png differ diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000..c344dcb Binary files /dev/null and b/assets/logo.png differ diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100755 index 0000000..40fa293 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,62 @@ +FROM mambaorg/micromamba:0.24.0 as conda + +USER root + +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git pip swig cmake libpython3.8 golang && \ + rm -rf /var/lib/apt/lists/* + +## Speed up the build, and avoid unnecessary writes to disk +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 +ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1 + +COPY --chown=$MAMBA_USER:$MAMBA_USER docker/environment.yaml /tmp/environment.yaml + +ARG ACCELERATOR +RUN if [ "${ACCELERATOR}" = "cuda" ]; then \ + sed -i "s/jax==/jax[cuda]==/g" /tmp/environment.yaml ; \ + elif [ "${ACCELERATOR}" = "tpu" ]; then \ + sed -i "s/jax==/jax[tpu]==/g" /tmp/environment.yaml ; \ + sed -i "s/jax_cuda_releases.html/libtpu_releases.html/g" /tmp/environment.yaml ; \ + elif [ "${ACCELERATOR}" = "cpu" ]; then \ + sed -i "s/jax==/jax[cpu]==/g" /tmp/environment.yaml ; \ + else \ + echo "ACCELERATOR should be cpu, cuda or tpu, but got: \"${ACCELERATOR}\"" && exit 1; \ + fi + +RUN micromamba create -y --file /tmp/environment.yaml \ + && micromamba clean --all --yes \ + && find /opt/conda/ -follow -type f -name '*.pyc' -delete + +FROM debian:bullseye-slim as test-image + +# Let's have git installed in the container. +RUN apt update +RUN apt install -y git curl + +COPY --from=conda /opt/conda/envs/. 
/opt/conda/envs/ +ENV PATH=/opt/conda/envs/ml-research-template/bin/:$PATH APP_FOLDER=/app +ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH +ENV TZ=Europe/London +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +WORKDIR $APP_FOLDER + +ARG USER_ID=1000 +ARG GROUP_ID=1000 +ENV USER=eng +ENV GROUP=eng +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/envs/ml-research-template/lib + +COPY . $APP_FOLDER +RUN pip install --user -e ./ + +FROM test-image as run-image + +# N.B: Had to relink this library so that TensorFlow has access to IFUNC symbol clock_gettime +# when cpu or tpu accelerator used. +ENV LD_PRELOAD="${LD_PRELOAD}:/lib/x86_64-linux-gnu/librt.so.1" + +# The run-image (default) is the same as the dev-image with the some files directly +# copied inside diff --git a/docker/environment.yaml b/docker/environment.yaml new file mode 100644 index 0000000..680fee7 --- /dev/null +++ b/docker/environment.yaml @@ -0,0 +1,12 @@ +name: ml-research-template +channels: + - defaults + - conda-forge +dependencies: + - python=3.10 + - pip + - pip: + - --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + - jax==0.3.24 + - jaxlib==0.3.24 + - -r ../requirements.txt diff --git a/instageo/__init__.py b/instageo/__init__.py new file mode 100644 index 0000000..be1596a --- /dev/null +++ b/instageo/__init__.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. 
You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. +# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""InstaGeo: A Geospatial Data Analysis Library. + +This package provides tools for processing and analyzing geospatial data, +with components for data handling (instageo.data), modeling (instageo.model), +and applications (instageo.apps). +""" + +from __future__ import annotations + +import pathlib + +# Package information +__version__ = "0.0.1" + +# Package paths +INSTAGEO_PATH = pathlib.Path(__file__).parent.resolve() +"""Path to InstaGeo package.""" + +INSTAGEO_APPS_PATH = INSTAGEO_PATH / "apps" +"""Path to the InstaGeo apps sub-package.""" + +INSTAGEO_DATA_PATH = INSTAGEO_PATH / "data" +"""Path to the InstaGeo data sub-package.""" + +INSTAGEO_MODEL_PATH = INSTAGEO_PATH / "model" +"""Path to the InstaGeo model sub-package.""" + +INSTAGEO_PARENT_PATH = INSTAGEO_PATH.parent +"""Path to the parent directory containing InstaGeo.""" + +INSTAGEO_TEST_PATH = INSTAGEO_PARENT_PATH / "tests" +"""Path to InstaGeo test directory.""" diff --git a/instageo/apps/README.md b/instageo/apps/README.md new file mode 100644 index 0000000..65b18b5 --- /dev/null +++ b/instageo/apps/README.md @@ -0,0 +1,31 @@ +# InstaGeo - Apps (Serve) + +## Overview +InstaGeo Serve is a web application that enables the visualization of HLS tile predictions created using InstaGeo Model component. + +To visualize the predictions, a user simply needs to provide the absolute path of the parent directory containing the predictions in a specific format. 
+ +![InstaGeo Serve](../../assets/instageo_serve.png) + + +## Getting Started +The predictions need to be stored in a directory having the format `parent/directory/year/month`, where `month` is an integer representing the month, for example January is `1`. + +1. **Launch the Application**: Initiate the Streamlit app by executing the command `streamlit run app.py` in your terminal. This will start the InstaGeo Serve application. + +2. **Specify Options**: Upon starting the app, you'll be required to input the parent directory containing the prediction in the specified format. Also specify date information and country of interest. + +3. **Explore Your Predictions on the Map**: After specifying the options, click on `Generate Map` to fetch predictions for the specified date and country, which will be overlaid on an interactive map for exploration. + +By streamlining the process of visualizing GeoTiff files, InstaGeo Serve aims to empower users with an efficient and accessible tool for visualizing model predictions. + +## Dependencies +- Python 3.x +- Plotly +- Rasterio +- Xarray +- Datashader +- Streamlit + +## Contribution +Contributions to enhance InstaGeo Serve are encouraged. Follow standard GitHub pull request procedures for collaboration. diff --git a/instageo/apps/__init__.py b/instageo/apps/__init__.py new file mode 100644 index 0000000..62b5402 --- /dev/null +++ b/instageo/apps/__init__.py @@ -0,0 +1,7 @@ +"""Applications and Utilities for InstaGeo. + +The 'apps' subpackage provides a collection of applications, tools, and utilities +built upon the InstaGeo framework. These may include command-line tools, web interfaces, +or other practical implementations that utilize the data handling and modeling +capabilities of InstaGeo for specific geospatial analysis tasks. 
# ------------------------------------------------------------------------------
# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0
# International (CC BY-NC-SA 4.0) License.
#
# You are free to:
# - Share: Copy and redistribute the material in any medium or format
# - Adapt: Remix, transform, and build upon the material
#
# Under the following terms:
# - Attribution: You must give appropriate credit, provide a link to the license,
#   and indicate if changes were made. You may do so in any reasonable manner,
#   but not in any way that suggests the licensor endorses you or your use.
# - NonCommercial: You may not use the material for commercial purposes.
# - ShareAlike: If you remix, transform, or build upon the material, you must
#   distribute your contributions under the same license as the original.
#
# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/
# ------------------------------------------------------------------------------

"""InstaGeo Serve Module.

InstaGeo Serve is a web application that enables the visualisation of GeoTIFF files in an
interactive map.
"""

import glob
import json
import os
from pathlib import Path

import streamlit as st

from instageo import INSTAGEO_APPS_PATH
from instageo.apps.viz import create_map_with_geotiff_tiles


def generate_map(
    directory: str, year: int, month: int, country_tiles: list[str]
) -> None:
    """Generate the plotly map.

    Arguments:
        directory (str): Directory containing GeoTiff files.
        year (int): Selected year.
        month (int): Selected month formatted as an integer in the range 1-12.
        country_tiles (list[str]): List of MGRS tiles for the selected country.

    Returns:
        None.
    """
    try:
        if not directory or not Path(directory).is_dir():
            raise ValueError("Invalid directory path.")

        prediction_tiles = glob.glob(os.path.join(directory, f"{year}/{month}/*.tif"))
        # The MGRS tile id is the second underscore-separated token of the
        # filename, prefixed with "T" (e.g. "T38PKV"). Use removeprefix rather
        # than strip("T"): strip would also drop a *trailing* "T" and corrupt
        # tile ids such as "34JFT", silently excluding valid predictions.
        tiles_to_consider = [
            tile
            for tile in prediction_tiles
            if os.path.basename(tile).split("_")[1].removeprefix("T") in country_tiles
        ]

        if not tiles_to_consider:
            raise FileNotFoundError(
                "No GeoTIFF files found for the given year, month, and country."
            )

        fig = create_map_with_geotiff_tiles(tiles_to_consider)
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        # Covers the ValueError/FileNotFoundError raised above as well as any
        # unexpected failure. The original tuple
        # (ValueError, FileNotFoundError, Exception) was redundant because
        # Exception already subsumes its subclasses; behavior is unchanged.
        st.error(f"An error occurred: {str(e)}")


def main() -> None:
    """Instageo Serve Main Entry Point."""
    st.set_page_config(layout="wide")
    st.title("InstaGeo Serve")

    st.sidebar.subheader(
        "This application enables the visualisation of GeoTIFF files on an interactive map.",
        divider="rainbow",
    )
    st.sidebar.header("Settings")
    # Static lookup mapping ISO 3166-1 alpha-2 country codes to the MGRS tiles
    # covering that country; shipped with the apps sub-package.
    with open(
        INSTAGEO_APPS_PATH / "utils/country_code_to_mgrs_tiles.json"
    ) as json_file:
        countries_to_tiles_map = json.load(json_file)

    with st.sidebar.container():
        directory = st.sidebar.text_input(
            "GeoTiff Directory:",
            help="Write the path to the directory containing your GeoTIFF files",
        )
        country_code = st.sidebar.selectbox(
            "ISO 3166-1 Alpha-2 Country Code:",
            options=list(countries_to_tiles_map.keys()),
        )
        year = st.sidebar.number_input("Select Year", 2023, 2024)
        month = st.sidebar.number_input("Select Month", 1, 12)

    if st.sidebar.button("Generate Map"):
        country_tiles = countries_to_tiles_map[country_code]
        generate_map(directory, year, month, country_tiles)
    else:
        # No request yet: render an empty base map so the page is not blank.
        st.plotly_chart(
            create_map_with_geotiff_tiles(tiles_to_overlay=[]), use_container_width=True
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,7 @@ +absl-py +plotly +rasterio +xarray +datashader +matplotlib +streamlit==1.31.1 diff --git a/instageo/apps/utils/country_code_to_mgrs_tiles.json b/instageo/apps/utils/country_code_to_mgrs_tiles.json new file mode 100644 index 0000000..0761832 --- /dev/null +++ b/instageo/apps/utils/country_code_to_mgrs_tiles.json @@ -0,0 +1 @@ +{"AE": ["39QYF", "39QWD", "40RBM", "40RBN", "40QBJ", "39QXG", "40QBK", "39QVE", "39QXD", "40QCM", "40QCL", "39QWF", "39QXE", "39QZG", "40QBM"], "AF": ["42STF", "41SPV", "41RPQ", "41RQP", "41SQS", "41SPT", "41RLP", "41SNV", "41SMU", "42STB", "41RNP", "41SPR", "41SQB", "42SVB", "42RTA", "42SUB", "41SPS", "42SWD", "41SQT", "41SNU", "41SNT", "42SXF", "42SUF", "42SXE", "41SQA", "42STA", "41SLR", "42SWF", "42SUD", "41SMS", "41SQU", "41SLV", "41RMN", "42SUA", "41SLT", "41SNR", "41SPA", "42SVD", "41RLQ", "41SQV", "42RTV", "42RTU", "41RPR", "41SQR", "43SBA", "42RUA", "42SWE", "42SVC", "42SVE", "41SLS", "41RQQ", "42SXC"], "AO": ["33LXJ", "33LVL", "34LBK", "34LDK", "33LWF", "33LVH", "33LUD", "34LCS", "33LWJ", "34LEL", "33LWE", "34LFP", "34LFM", "33MVN", "33KXC", "34LBQ", "33LVK", "33LYL", "33MXM", "33MTN", "34LCJ", "33LYK", "34LCP", "33KXB", "33LYH", "33LYJ", "33MUM", "34LEM", "34LFQ", "34LCR", "33LWC", "33LWK", "34LCK", "34LDN", "34LAJ", "33LTC", "33LXK", "34LEP", "33LYC", "34LBR", "34LDS", "34LDR", "33LXG", "33LTD", "34MES", "34LDM", "33LYG", "33MWM", "33MTP", "34LCN", "34LCQ", "33LXH", "32KQG", "33LXD", "34LGP", "33LTE", "33LSD", "33MUN", "33LUC", "33LVE", "33MVP"], "BF": ["30PYU", "30PVV", "30PXV", "30PYV", "31PBQ", "30PWV", "31PCP", "31PBP", "30PVS", "31PBR", "30PWA", "30PXT", "30PYB", "30PUU", "30PXA", "31PCQ", "30PUT", "31PBN", "30PZU", "30PVU", "31PDP"], "BI": ["36MSC", "35MRS", "35MQS"], "BJ": ["31PCM", "31PCJ", "31PDL", "31NCH", "31PCL", "31PDN", "31PCK", "31PBM", "31PEM", "31PDP", "31PDK"], "BW": ["34JFT", "34JFS", "34JGT", "34KFD", "34KEE", "34KDU", "35KPQ", "35KNS", "35KPR", "34JGS", "35KLT", "35KJR", "35KMT", "34JDS", 
"34KEA", "34KFE", "35JKM", "35JLM", "34KFC", "35KMP", "35KMS", "34KDV", "35KLP", "34KGV", "35KLS", "35KLB", "34KDB", "34JFR", "35KMR", "35KLV", "34JDU", "35KKA", "34KGU", "34KDD", "34JER", "35KKS", "35KKQ", "34KGE", "35KQS", "35KMQ", "35KLU", "34KEV", "35KNQ", "34KFA", "35JLP", "34JGU", "34KDC"], "CD": ["34NGH", "34MGV", "35MPS", "35MLP", "35MPV", "35MKV", "35MMV", "35LML", "34MGA", "34NEF", "34MEU", "34NHG", "35MJT", "35NKC", "35MPT", "35MJM", "34NBF", "35NPB", "33MYM", "34MHT", "34MFD", "34NCH", "35MNS", "35MLS", "34LGQ", "34NEJ", "35NQD", "35NMC", "35LMJ", "34NBK", "34NDF", "35MLQ", "34MDD", "35MNT", "35MMU", "33MZQ", "35NMD", "35MJR", "34MCB", "34MED", "33MXS", "33NYA", "35NKE", "34NDH", "34MAU", "35NLE", "35NQE", "33MXR", "35MKT", "36NTH", "35MMM", "35NLD", "34NCJ", "33NZA", "34MCV", "35NLA", "33MYT", "35MMP", "34MBE", "33MXP", "33LZL", "34MHU", "33MWP", "35NLB", "34NGJ", "33MWQ", "35MNQ", "34MDC", "33MVP", "35MLT", "34MAA", "34MGC", "34NGK", "33MZV", "35MKP", "33MYR", "35MMQ", "34MEC", "35MKS", "34MEV", "35MLV", "34NCF", "34NHJ", "34LGS", "34MDA", "34MBA", "34NCG", "35NPD", "35LNH", "34MDV", "34MFV", "35MQN", "35NLC", "33MYQ", "34NEH", "34MCA", "34MHE", "34LFR", "34MEB", "34MHC", "34MCD", "35MMT", "35MMS", "33MXT", "34MFE", "35LNJ", "34MBT", "33MYP", "35MNU", "34NGF", "35MLU", "34MEA", "34MHV", "35NQB", "35NPC", "34MBV", "34MGU", "35NNA", "34MEE", "35MMN", "35MKN", "34MCU", "35NKA", "35MQP", "33MXU", "34MAE", "35NND", "34MFB", "34MBB"], "CF": ["34PES", "33NZJ", "34PER", "35NME", "33NWH", "34NGP", "35NLG", "33NXD", "35NMH", "34PDR", "33NXJ", "34NEM", "34PEQ", "34PFT", "34PGR", "34PGQ", "34NFM", "34NBJ", "34NDP", "35NKG", "34PGT", "34NDN", "34PHQ", "33NYG", "33NZH", "35NJF", "35NMG", "35NJJ", "33NXE", "34PFP", "34NDM", "33NYH", "34NCL", "34PCQ", "34PDQ", "34NFL", "33NWE", "34NCP", "34NCN", "34NGL", "34NEN", "34PFR", "35NKF", "34NEP", "34NBM", "33NWF"], "CG": ["33NVC", "33MWU", "33NUB", "33MVS", "33MWV", "33MUS", "33MVR", "33NXB", "33NUC", "33MTS", "33NYC", 
"34NAH", "33NXA", "33NYB", "33MVV", "33NZC", "32MRA", "34NBG", "33NWA", "33MSQ", "34NBH", "33MWT", "33MUR"], "CI": ["30NUM", "29NRG", "30PTR", "30PWR", "29NPJ", "29NPH", "29PQK", "29PQM", "30NWP", "30NVN", "30NUN", "29NQG", "30PUR", "30PVQ", "30PWP", "30PUQ", "30NVM", "29NQH", "30PTQ", "30NUP"], "CM": ["33PVK", "33PVM", "32NPJ", "33NVG", "32NQJ", "33PUL", "33NVH", "33NUJ", "33NSF", "33NVJ", "33PTK", "32NQM", "33NVE", "33NUD", "33NTH", "33PVN", "33NVF", "33NTF", "33PUK", "32NPM", "33PSJ", "33NUF", "32NNK", "32NRJ", "32NPK", "32NNL", "33NUH"], "CN": ["43SEB", "43SFA"], "CY": ["36SVD", "36SWD"], "DJ": ["38PKS", "38PKT", "38PJS", "37PHM"], "DZ": ["31RCK", "32SLB", "30RYN", "32RKP", "31RGP", "29RQK", "31RFP", "31RFL", "31SES", "31RGM", "31REK", "31SFT", "31SEA", "31REJ", "31SDU", "31RDK", "31RFG", "31SBV", "32RLU", "29RNK", "31SDT", "31SDR", "31QGC", "30RYU", "31RDH", "31RGL", "31SEV", "30RWT", "32RLM", "31SGS", "31SDV", "31RDL", "31REH", "30RYV", "31RCN", "31SCU", "31RFN", "31REQ", "32SKC", "31QEG", "32SLD", "30RVU", "32RKU", "30RWR", "32RMR", "29RPL", "31SCV", "31SBU", "32QKM", "30SYA", "30RXT", "29RQJ", "29RPH", "31RFM", "32RLS", "30RXS", "31QBG", "31SEU", "32RKT", "31RCM", "31SFS", "30SYC", "30RTQ", "30SXD", "32RLR", "31SGV", "31RDQ", "31RGH", "30QYM", "31SGT", "31RBP", "30SXC", "31RBM", "30RXU", "31RBH", "31SCA", "31RDM", "31REP", "31RBL", "32RLN", "31REG", "30RXR", "31QGD", "29RMK", "31SCS", "31REM", "31RCR", "31SFR", "31RDN", "32QKL", "31SFV", "31QGE", "32RMP", "30RYP", "31RCH", "30RYS", "32SLA", "29RPK", "31RGJ", "32QLM", "31SFU", "30RYT", "30RXV", "32RNR", "31QEE", "31RDR", "30RTP", "32RNQ", "29RMJ", "31SGR", "31RFJ", "31QFD", "32RPR", "31SCT", "29RQL", "31SBS", "29RNL", "29RQH", "31RCQ", "31QDG", "32SKE", "32SKD", "30RZN", "31RCP", "32QLJ", "31RGK", "32SLC", "32RLQ", "31QDF", "30SXB"], "EG": ["35RMN", "36RUT", "36RVR", "35RLL", "35RMQ", "35RPK", "36RVM", "36RVV", "36RVS", "36QUJ", "35QPG", "35RQK", "36RXS", "36QTL", "36QVL", "35RMP", "36RXP", "35RNM", "35RKL", 
"35RLP", "36RUU", "36RWQ", "35RMJ", "36RWT", "36RWV", "35RPP", "36RTT", "35RQL", "36RVT", "35RNG", "35RLK", "35RNH", "36RUR", "35QQF", "36RUV", "35RNN", "36RVN", "35RNL", "35RPJ", "36QVH", "36QVK", "36RYS", "36RWR", "35RPH", "35RQG", "36QVJ", "36RTM", "35RKM", "35QNG", "36QTK", "36RXM", "36RXR", "36QUM", "35RMM", "36RTP", "36QXM", "36QXK", "36RTU", "36RWP", "35RLM", "36RWS", "35RKK", "36RTQ", "35RQP"], "EH": ["28REN", "29RMK", "29RLJ", "28RGQ", "29RLK", "28QEL", "28QDK", "28QCL", "28QDL", "28QDM", "28QCK", "28QEK", "28RFP"], "ER": ["38PKV", "37PHQ", "37QDT", "37PCR", "37QDU", "37QET", "37PET", "37PES", "37QDV", "37PGS", "37PFS", "37QEU", "37PDT", "37QCU", "37PCS"], "ES": ["30SWG", "29SQB", "30SVG", "28RDR", "28RFS"], "ET": ["36PZR", "37PEM", "37PGM", "36PXS", "37PCM", "36PYS", "37NDJ", "36NYL", "37PCL", "36PZP", "37PCQ", "37PGK", "36PZQ", "37PCR", "37PDK", "37NBJ", "36NYP", "37NEJ", "37PAL", "37NGF", "36PYP", "37PAK", "37NAG", "37PBK", "37PFP", "37PBL", "37NBH", "37NGJ", "37NCG", "37NFF", "37NCD", "37PFR", "37PEQ", "37NBG", "37NDF", "37NBF", "37PFN", "37NEF", "38NKN", "37PGL", "37NDE", "37PEP", "37PER", "37NHH", "37PFQ", "38PLQ", "38PKQ", "36PYR", "37PCK", "37PCP", "37PDQ", "37NBE", "36NWP", "37PDM", "36NWN", "36PXR", "37PCN"], "GA": ["33MTV", "32MPD", "33MVU", "33MVV", "32MNE", "32NQG", "32MND", "33MTU", "32NPF", "33MTT", "32MQE", "32MRE", "32MNC", "33NTB", "32MRD", "32MQB", "32MQD", "32NNF"], "GH": ["31NAG", "30NXP", "30NZN", "31NBG", "30NWN", "31NBH", "30PZQ", "30NYN", "30NWL", "31NAJ", "30PXQ", "30NXL", "30PYR"], "GM": ["28PFV", "28PEA", "28PCV", "28PEV", "28PCA"], "GN": ["29PKM", "28PFS", "28PET", "28PGT", "29PLL", "28PFU", "29PJP", "28PGS", "29PNM", "29PLM", "28PHU", "29PPK", "28PGV"], "GQ": ["32NQH", "32NNH"], "GR": ["35SMA", "35SNV", "34SGG", "35SKV", "35SLV"], "GW": ["28PEU", "28PDU"], "IL": ["36RXU", "36SYB", "36RXV", "36RXT", "36RXA"], "IN": ["43REN", "43SFV", "43QFU", "43QEU", "43QBE", "43RFN", "43QEV", "43PGS", "43SFU", "43SDT", "43RDG", "43PFS", 
"43RFK", "43PDQ", "43QGG", "43QDU", "43SES", "43SEU", "43SFR", "43RFG", "43PFP", "43RFM", "43QBF", "43RGL", "43QGD", "43RFP", "43QBC", "43QFC", "43QEF", "43QCV", "43QAD", "43QDA", "43RDH", "43SFT", "42QXJ", "43RFL", "43PDT", "43QFV", "43QCD", "43PGK", "43PCT", "43QFB", "43QDF", "43QCE", "43PES", "43QCU", "43RGN", "43QGF", "43SFS", "43QCB", "43PDR", "43QCA", "43RGK", "43SER", "43PDS", "43RFJ", "43QFA", "43QDV", "43PGQ", "43RBK", "43RCL", "42QWM", "43RDL", "42RYN", "43RDK", "43RDM", "43REM", "42RXR", "43RDN", "43RCK", "42QXL", "43RCN", "42RXN", "43RBJ", "43RCM", "42QWL", "43RDP", "42QVM", "43REL", "42RXP", "42RYP", "43RBL", "42RYQ", "43RCP", "43RCJ", "42QXM", "43RBM", "42RXQ", "43QGE", "43QCF", "43RGM", "43RDJ", "43QEE", "43SGU", "43QDB", "43QDE", "43RFQ", "43PGP", "43QGC", "43PFQ", "43QCC"], "IQ": ["38SQA", "38SKC", "37SGS", "37SFS", "38SMD", "38SMB", "38RQU", "38RKV", "37SES", "38SLC", "38SNC", "38SNB", "38SLE", "38RPU", "38SKE", "38RLT", "38SKB", "37SGT", "38SLD", "38SKD", "38SLA", "38RMU", "37SGU", "38RNV", "38RMV", "39RTP", "38RMA", "38RNU", "38RLV", "38SME", "37SFR", "38SNA", "38SLB"], "IR": ["39SVT", "40SFE", "41RMQ", "38SNC", "40RBT", "39RUP", "40SDG", "39RXR", "40SBF", "40SDC", "39SWT", "39SUT", "40RGV", "38SQC", "39SYU", "40SEA", "41RLN", "40SCC", "38SQG", "40REU", "40RFU", "39RVQ", "38SPF", "39SVA", "38SQD", "40SED", "40SEG", "40SDD", "39SWV", "39STV", "40SCB", "39RWP", "38SPD", "39SXS", "41SLV", "38SQB", "41RKM", "40SDE", "39SXT", "39SXA", "40RDS", "40RDV", "40SCF", "38SQE", "41SKU", "38SPB", "40RCV", "38SPC", "39SUA", "40SCE", "39RVP", "41RLJ", "40REP", "39RXL", "39RWM", "41RKK", "41RKJ", "40RFP", "39RYK", "39RTP", "40RBQ", "40RCR", "40RBR", "40RFR", "39RYL", "40REQ", "41RKL", "40RFQ", "41SKS", "41RLK", "41RKR", "41RKQ", "40RGA", "40RDU", "41RLM", "39RXM", "41RKP", "40RCQ", "41SKT", "40RGQ", "41RLL", "39RYM", "39RVN", "39RWQ", "38SPE", "40REV", "40RBV", "40SDB", "38SNE", "40RGS", "39SVS", "39RYQ", "40SDF", "39SVU", "38SNG", "40REA", "41RLQ", "40SBC", 
"40SBA"], "IT": ["33SVB"], "JO": ["37SBR", "37RBN", "37RBQ", "36RYU", "37RBP"], "KE": ["37NDC", "36MZC", "37NGF", "37MDT", "37MET", "37NFA", "37MBV", "37MFV", "37MCT", "37NDB", "37NBB", "37MEV", "36MYD", "37MER", "37NFF", "37NBD", "37NCB", "37NFB", "36MYE", "37NAB", "37MGU", "37MCV", "37NBC", "37NCD", "37MFT", "37NCC", "37MDV", "37MGT", "37MDS", "37MDU", "36NYH", "36NXF", "37MBT", "37MEU", "37NGB", "36NYJ", "37MCU", "38NJK", "37NHE", "37NDD", "37NDA"], "KW": ["38RQS"], "LB": ["37SBU", "36SYC"], "LR": ["29NNG", "29NLJ", "29PMK", "29NMH", "29NPF", "29NNF", "29NMF"], "LS": ["35JPG", "35JNG", "35JPH"], "LY": ["34RGP", "34SEB", "34RCP", "33RWJ", "35RKN", "34RFN", "33RXQ", "33RUQ", "34RFT", "34RFP", "33RYM", "33RWK", "35RKH", "34QFL", "35RLH", "32RQP", "32QNK", "32QML", "32RPQ", "35QLC", "33RVP", "33RXJ", "33RTK", "33RVH", "33QUE", "33RVM", "34QFK", "33QTF", "33RYL", "33RWH", "32QQL", "34RBU", "32RPT", "33RYJ", "33RUG", "33STR", "32RMS", "34RES", "32RNM", "34SEA", "34RFA", "33RTH", "34RGT", "33QXG", "34RGQ", "33RXN", "34RCT", "34RCQ", "35RKR", "34SDA", "35RLP", "34QEJ", "34RBP", "34REU", "33RUR", "34RBR", "33RXK", "33RUK", "33RVJ", "33RUN", "33RTM", "33RYP", "34RER", "32QNM", "33QVF", "34RCR", "33QTG", "34RDT", "35QMF", "34QGK", "34REV", "35QKC", "34QEM", "34RET", "34RDV", "33RTR", "33RYK", "34RFV", "34REA", "34RGU", "32RQR", "33SVR", "32RNT", "32RMT", "33RTQ", "34RDS", "34RDU", "33RWM", "35QLF", "34RDP", "32RQM", "34RDM", "34RFQ", "34QGL", "35QKE", "33RVN", "32RPR", "33RWN", "33RVK", "34RBQ", "34RGS", "32RMV", "33RYN", "34REN", "32RNS", "32RNU", "35QKF", "33RTJ", "34RCS", "34RGR", "34RFU", "32RQS", "33RTL", "32RQV", "32RPV", "33RVQ", "33RWQ", "34SFA", "33RUJ", "32QNL", "34SCA", "35QLE", "32RPN", "35RKM", "32RQT", "32RQQ", "32RLV"], "MA": ["29RPP", "29SPS", "30SVC", "30STC", "29RMQ", "30RTR", "30SUC", "29RQQ", "29RLH", "30STD", "29RNN", "29SQU", "30RVR", "30RUQ", "29RKG", "30RVS", "30SUB", "30RYV", "30RUS", "30SUD", "29RNR", "30RTT", "29RPM", "29RPN", "29RMN", "29RMG", 
"29SNS", "30SXC", "30STE", "29SMR", "29SPR", "28REN", "30SWC", "30RVT", "30SUA", "29RPQ", "29RNP", "30SXA", "29RLM", "29RNQ", "30SVD", "29RNM", "28RFM", "28QFM", "29RMM", "28RFN", "30STB", "30SVA", "29RPR", "30RXV", "30SVB", "30RUU"], "MG": ["38KQF", "39LTE", "38JMU", "38KNV", "38KPE", "38KLA", "38KMD", "38KME", "38KNA", "38KNF", "39KTC", "39LTG", "38KMB", "38KNG", "39LUF", "38KQA", "38KRE", "38KQC", "39LTD", "38KMA", "39LVC", "38LQH", "39KTV", "38KQV", "38KPF", "38JNT", "38KQG", "38KQB", "39LTC", "39KTT", "39LUE", "38KNB", "38KPC", "38KMG", "38KRG", "38KNC", "38KMF", "39KTU", "38KLB", "39LUD", "38KPA"], "ML": ["29PPN", "31QAA", "31QCT", "31PBT", "30QWL", "29QQD", "30PUV", "29QPE", "30QYL", "30QUK", "31QAC", "30PVA", "31QDD", "30QWF", "29PLQ", "31QDE", "30QXD", "30QTL", "29PQM", "29PQS", "30QXC", "30QTK", "29PPP", "30QUG", "30RWN", "30QXK", "30PUB", "31QEB", "30QXL", "30QVD", "30QYE", "29PNP", "31QDU", "31QCV", "29PLP", "30QWJ", "30QXJ", "30QVE", "31QDB", "31QFA", "30QTH", "29QQC", "30QUF", "31QCE", "31QCU", "30PTB", "29PMP", "29PQN", "30QUD", "29PPR", "31QEV", "31QFB", "29PKR", "30QYD", "29PNN", "30QYK", "30RUP", "31PDT", "30PST", "31QBU", "30QTM", "30QZJ", "30QVL", "31QBC", "30QWG", "29PMR", "30QUL", "30QVF", "30PVV", "31QCF", "31QEA", "29QPD", "29PQR", "30QUE", "29PKQ", "30QWC", "31QEC", "30QXG", "30QVH", "30RTN", "30PTV", "30QTD", "30PSU", "30QWH", "29PRQ", "30RWM", "31PET", "31PDS", "30QSH"], "MR": ["29QLA", "29QMC", "29QNB", "29QLG", "29QPC", "29QKB", "29QME", "29QNC", "29QLF", "28QGK", "29QPB", "29QLE", "29QNE", "29QND", "29QKG", "29QNG", "28QGJ", "29QPA", "28QGD", "29QQB", "29QLC", "29QNV", "28QEF", "28QEG", "30QTE", "28QEJ", "28QFM", "28QDJ", "28QFH", "29QLV", "28QEL", "28QEH", "28QFG", "28QCK", "28QEK", "28QFF", "28QGH", "28QFL", "29QMB", "29QQV", "29QKE", "28QGG"], "MU": ["40KEC"], "MW": ["37LBC", "36KYG", "36LWJ", "36LXN", "36LXL", "36KYH", "36KXF", "36LYH", "36LXK", "36KYF", "37KAC", "36LWQ"], "MZ": ["37LAD", "37LDD", "36KXB", "36KXV", "37KEB", 
"36KWD", "36KXA", "37LBD", "37LCC", "37KCB", "36KYU", "36KYA", "36KWA", "36KZE", "37LED", "36JVS", "37LDF", "37LDE", "36KYE", "36KWV", "36KVU", "37LEF", "36KXC", "36KWF", "37LCF", "37KCA", "36KXE", "37KDB", "37LBF", "36JWT", "37LFF", "37LDG", "37KCC", "36LWH", "36JVT", "37LCD", "36KWG", "36JXT", "36JVR", "37LFH", "37LBE", "37LFE", "37KBA", "36KXD"], "NA": ["34JCT", "34KDV", "34KCE", "33JYP", "34KDA", "33JXN", "33KYU", "33KWA", "33KTV", "33JYK", "33JWN", "34KDU", "33KYP", "33KYA", "33KXB", "33KZV", "34JCR", "33KZS", "33JXL", "34JCQ", "33KTB", "34KFG", "33KZB", "33KYR", "34JDS", "34KDF", "34KBC", "34KEF", "33KWT", "34JBQ", "33JVM", "34KAD", "33JXJ", "34JBT", "34KCU", "34KBG", "33KWS", "33KWV", "33KXT", "33KVS", "33JYL", "33KYQ", "33KWU", "33KVU", "33KVV", "34KBV", "33KYV", "33KTA", "34KDD", "33KUU", "33KXQ", "33JWK", "34KCD", "34KCC", "33KXA", "33KXS", "34KCG", "33KXV", "33JYN", "33KUV", "33KYS", "33JXM", "34JDT", "33KWQ", "33KXP", "33JWL", "34KDC", "34KCF", "34KBA", "33KUA"], "NE": ["32PKA", "31PES", "32QNG", "31PDR", "32QKH", "33QVB", "32QPF", "32QPH", "33QUA", "33QWU", "32QLJ", "31PER", "33QWB", "32QPC", "33PTQ", "32PLA", "32QMH", "32QMK", "31PGR", "33QTA", "33QVV", "32PMA", "32QJC", "32PKC", "31PEQ", "32QQE", "31PBS", "33PST", "32PKB", "32QRD", "32QNH", "31PBR", "33QSC", "32PNA", "32QMJ", "33PTR", "32QLF", "31QHV", "31QHA", "31PHR", "31PDS", "32PJA", "32PRA", "33PVS", "32QNJ", "31PCR", "32PMV", "33PUT", "32PMB", "33QSA", "32QQH", "31QFU", "32QPE", "33QUU"], "NG": ["33PUR", "31PHN", "32NKP", "32PLQ", "32PLP", "32PNS", "32NML", "33PTN", "31PFK", "32PNP", "32PLV", "31PEK", "32NQN", "32PKS", "32PNU", "33PUQ", "31PGQ", "32PKT", "32PPT", "31NEJ", "31PFL", "31NFJ", "31PEQ", "31NFH", "33PUP", "32PKQ", "32PQS", "32PPR", "33PTM", "31PGP", "32NLN", "32PPV", "33PTL", "31PHM", "31PFN", "32NLP", "32PPP", "32NNP", "32NMM", "32PNT", "32PLT", "32PQU", "31PHL", "32PPQ", "32PMT", "31NDH", "33PSQ", "32PMV", "33PTP", "32NKM", "31NGH", "32PLS", "33PUM", "31PGR", "32NQP", "32PQR", 
"32NJM", "32NJN", "31PFM", "31NGG", "32NLL", "31PFQ", "31PGL"], "OM": ["40QDH", "40QCK", "40QCL", "40QBF", "40QGL", "40QCH", "40QFJ", "40QBG", "40QBD", "40QGK", "40QEM", "40QFL", "40QEK", "40RCM", "40QBE", "40QDF", "39QZU", "40QFM", "39QYU", "40QEL", "40RDM", "40QDL", "40QAD", "40RCN", "40QEG", "40QCM", "40QAE", "40QFK", "40QDM", "40QCG", "40QDK", "40QBJ", "40QCJ", "40QBH"], "PK": ["42RVV", "43SEV", "43SEA", "42RUV", "43RBP", "42SYE", "42SXA", "42SYD", "42RUR", "42RXV", "43SCU", "43SDA", "43SDV", "42SYC", "43RDQ", "41RNK", "42RTN", "42RTS", "42RXN", "41RQM", "42RTP", "42RXR", "41RPN", "43RBR", "42RWS", "42SXB", "42RTT", "41RNM", "41RNJ", "41RPK", "42SYB", "41RMJ", "42RWT", "41RQN", "42RWN", "42RUP", "42RUU", "42RYA", "41RNH", "42RXM", "41RPL", "41RQL", "42RVR", "42RXA", "43RCP", "42RYS", "41RPM", "42RUS", "41RNN", "42RVT", "42SYA", "42RWM", "42RXS", "42RVM", "42RUQ", "41RQJ", "43SBV", "43RBQ", "43SEU"], "PS": ["36SYA"], "PT": ["28SCB"], "QA": ["39QVG", "39RWH"], "SA": ["39RTG", "38RQP", "38RNR", "38QNG", "38QNK", "37RCM", "38RPP", "38QKK", "38RPN", "38QLL", "39RUG", "38RQQ", "38QQJ", "37REH", "38QMK", "38RKV", "39RTH", "37RDQ", "37REM", "37REN", "38QMF", "39RTJ", "37RFP", "38RKU", "39QTD", "37RCR", "38QQK", "38QPG", "38QPF", "37REP", "37RCP", "37RCL", "38RNS", "37REQ", "37QGD", "37RCN", "38QLF", "37QGF", "38QNF", "37RBQ", "37QGE", "39QUD", "37QGB", "38QJE", "38QKE", "37QGC", "37QGA", "38QME", "38RMP", "37RCJ", "38RLQ", "37QEF", "38RMN", "37RDG", "37QFC", "37QED", "37QDG", "37RCK", "38RKR", "37RCH", "37QFD", "37RDK", "38RMQ", "37RGM", "38RQR", "38RLP", "39RUJ", "37QFB", "38QNH", "37RBK", "37RDH", "37QDF", "39RTL", "39RUK", "37QEC", "39RTK", "37REG", "38RKS", "37QEE", "38RPR", "37RDJ", "37RDP", "37RCG", "38RLR", "37RFL", "38QPL", "38QMG", "37RFH", "38RKN", "37RGJ", "38RKP", "38RLN", "38RJP", "37QEG", "38RKM", "37RFG", "38QKH", "37RFJ", "38RLM", "38QKG", "37RGH", "38QKL", "39QTE", "39QUG", "37SGR", "37RER", "38RPQ", "39QUE", "39QUF", "38RQN", "38QMJ"], "SD": ["36PSS", 
"35QQA", "36PVS", "35PJQ", "36PVT", "35QPU", "35PKQ", "34PHB", "36PYC", "34PFU", "35QNV", "36QTF", "35QQC", "37QBU", "37QCU", "36PUB", "35PKT", "35QQV", "36PWT", "35PMM", "34PGA", "36PZU", "35QMC", "35PLS", "36QUH", "36PYT", "35PQM", "35PLT", "36QXK", "36PYV", "35PPN", "35PNT", "36QXH", "35PKS", "36QYH", "35PKR", "35PNR", "34PHT", "35QMT", "35QND", "35QPA", "35QLC", "36PYB", "36QXG", "35QQD", "36QXD", "35QNE", "35QRV", "36QXF", "37QDV", "37QBD", "37QDA", "36QXE", "37QCA", "36QZK", "37QBA", "37QAD", "37QBE", "37QAE", "36QVF", "36QYL", "37QCB", "36QYF", "37QAA", "36QZJ", "36QYD", "36QZF", "36QUD", "36QVD", "36QWD", "36QUG", "36PWS", "35QPB", "35PKN", "36QVE", "36QYK", "36QUE", "36QVG", "37QAC", "36QWF", "36QYE", "36QWE", "35PPS", "36PXA", "35QQU", "35QQB", "35QLU", "36QYJ", "35QPV", "35PQQ", "35PRN", "35PLQ"], "SH": ["28HGD"], "SL": ["29PJL", "28PGR", "29PJK", "29NLJ", "29PLK", "29PKM", "28PGQ", "29PLL"], "SN": ["28PCB", "28PEV", "28PFA", "28QCD", "28PDV", "28PEA", "28PEB", "28QFD", "28PCU", "29PKP", "28PGA", "28PDB", "28QGC", "28QED", "28PGV"], "SO": ["37MGU", "39NSF", "38NLK", "39PVN", "38NQM", "38NLJ", "38NNJ", "38PMP", "38NKG", "38NMM", "38PQP", "38NMP", "37NGA", "38NLG", "38NLL", "39PWM", "38PNQ", "38PMQ", "38NLF", "37MGV", "38NKH", "39NTG", "38NNM", "38NKJ", "38NJG", "39PSK", "38PLR", "38NPP", "39NUJ", "38NKL", "38NNN", "38NPM", "39PUM", "39PWN", "38PLS", "39PTM", "39PTK", "38PKS", "39PVM", "38PPS", "38PMR", "39NTJ", "38PNS", "38NPN", "38PMS", "38NLM", "38NMN", "39PUK", "38NNP", "38NQN", "39PSM", "38PNR", "38PLP", "38NPJ", "38NLP", "39PTL", "38PLQ", "39PSL", "38NKN", "38NLH", "38NNK", "39NTH", "38NKM"], "SS": ["35NKJ", "36PVS", "35PLK", "36NTN", "35PNL", "36NXM", "35NNF", "35NQJ", "35PKL", "36NUN", "36NTK", "36NVL", "36NUP", "35NRG", "35NNJ", "35NRF", "35NQE", "36NTL", "36NWM", "36NTP", "35NQG", "35PNK", "35PML", "36NVM", "36PSP", "36PUQ", "36PVQ", "36NXL", "36PTQ", "35NMJ", "35PQM", "35NQF", "35PLL", "35NNH", "36NVP", "35NLJ", "35NLH", "35PQK", "36PVR", 
"36NUM"], "SY": ["37SBS", "36SYE", "37SGA", "37SDV", "37SDA", "37SBV", "37SDU", "37SDS", "37SCU", "37SEU", "37SDT", "37SCT", "37SCA", "37SBT", "38SKF", "37SCS", "36SYB", "37SFA"], "SZ": ["36JVR", "36JUR"], "TD": ["33QYU", "34QCJ", "33PXT", "34QCG", "33QYC", "34QDG", "34PDV", "34QBE", "34PCA", "34QBK", "34QBD", "34PBB", "33PVL", "33PWL", "33QYD", "34QBJ", "34QGE", "34QEG", "34QHD", "33QYV", "33PWT", "33QXV", "35QKA", "34QCD", "34QCM", "34QBL", "33QVT", "33QYG", "34QBC", "34PBQ", "34PES", "35QMV", "34PCR", "34QCK", "33PWN", "34PBR", "33PZN", "33PYM", "34QBH", "33PYN", "34PBS", "34QGH", "34QBG", "34QCE", "34PGB", "34QBF", "35QKB", "34QDF", "33QYF", "33QXF", "35QKU", "34QGG", "33QWC", "33NYJ", "34QGF", "34QDL", "34QDH", "33QXE", "35QLB", "33PWK", "34QDK", "34QHF", "34QFF", "33QXU", "34QDJ", "34QEE", "34QFE", "34RBP", "33PYJ", "34RCN", "34QFH", "33PXQ", "34PDC", "33QYB", "34PET", "33PYK", "34QEH", "33QXC", "35QMA", "33QXD", "33QWU", "33QWB", "34PDT", "33NWJ", "33QYA", "33QWD", "33QYE"], "TG": ["31NCJ", "31NBJ", "31PCJ"], "TJ": ["42SVG", "43SCB"], "TM": ["41SKB", "40SFG", "41SLB"], "TN": ["32SME", "32SMC", "32SNB", "32SPF", "32SNC", "32SNE", "32SND", "32SNF", "32SMD", "32SPB", "32SPE"], "TR": ["38SKG", "36SUF", "37SBA", "36SWF", "36SXF", "38SLG", "36STG", "36SXG", "37SEA", "37SGB", "35SNB", "35SQA", "35SQB", "37SFB"], "TZ": ["36LUS", "35MRM", "37LCL", "36MWT", "37MAQ", "36MTT", "37MBR", "36LYN", "36MXV", "36MZU", "36MVU", "36MTA", "36LWR", "37LDJ", "36MWS", "36MYV", "36MTC", "36MWV", "36MYS", "36MTV", "36MTU", "37LDK", "36MYB", "36MXA", "37LCK", "36MUC", "37MAR", "37MDM", "36MVA", "36LXR", "35MPQ", "36LXQ", "36MWU", "36MSA", "36MWA", "36MXB", "36LWQ", "36MXT", "36MSV", "37LEL", "37MCQ", "36MVB", "36LYQ", "37LBL", "37MDN", "37MCS", "36LYM", "36MYA", "37MCM", "37LAK", "37MBS", "36LUR", "36MWB", "36MUU", "36LZP", "36MVC", "36MYT", "37MCR", "37LEM", "37MEM", "37LAH", "37MCP", "36MTS", "37MEN", "36MVV", "36LZQ", "36LYP", "36MXS", "37LDM", "36MUS", "37MBP", "36MYU", "36MVT", 
"36MUA", "37LEH", "37LCH", "36MUV"], "UG": ["36NUH", "36NWK", "36NUF", "36NYH", "36NWG", "35NRA", "36NXK", "36MTE", "36NTH", "36NVG", "36NWH", "35MQV", "36NSF", "36NTJ", "36NUJ", "36NVJ", "36NTF", "36NVH", "36MSE", "36NZH", "36NXG", "36NXH", "36NXJ", "36NYK", "36MVE"], "YE": ["39QTA", "38QQE", "38QQG", "39QXA", "39PVS", "39PWS", "39QZB", "39QUC", "38PLA", "38QND", "39QXC", "38PMA", "39QUD", "39QVA", "39QVD", "39PWT", "39QYB", "39QVC", "39QWA", "38QQD", "38PNA", "39QWC", "39QYV", "39QUB", "39PVT", "38PMB", "38PNC", "38PLB", "38PKC", "38QMC", "38PMV", "38PMC", "38PKB", "39QYU", "38PNB", "38PRC", "38QKC", "38PNV", "39QTC", "39QXV", "38PPB", "38PPA", "38PPV", "39QYA"], "ZA": ["35JMG", "34HDH", "35JPE", "35JKE", "36JTP", "35JNL", "36JUQ", "35JLG", "36KVV", "35JNP", "35HKD", "34HCK", "36JVQ", "34HDJ", "34JGM", "36JUU", "36JTT", "34JFP", "36KTV", "35JPF", "33JYF", "35JNK", "36JUT", "35JQK", "34JCM", "34JGR", "35JQL", "35JML", "35JQH", "35JQJ", "34JBL", "35HNE", "35JPK", "36JTR", "35JME", "35JPM", "34HBH", "36KUV", "34JFM", "35JQG", "35JKH", "35JLK", "34JDL", "35JQN", "35JNF", "34JBM", "35JMH", "35KQQ", "34HBJ", "35JKL", "34HGJ", "35JLL", "35JKK", "35HMD", "35JLJ", "36JTS", "35JPL", "36KTU", "34JEL", "35HLD", "36JUP", "35JNJ", "34JEQ", "34JGN", "34JGK", "34JFK", "36JTM", "34JFQ", "36JUN", "35JNE", "35JKG", "34JBN", "36JTQ", "34JCK", "34JGP", "34HEJ", "35JMK", "34JEK"], "ZM": ["35LMF", "35LQJ", "35LKF", "35LND", "36LTQ", "36LTP", "36LVM", "36LUM", "35LMG", "35LKC", "36LUP", "35LNC", "34LEK", "35LQH", "36LVJ", "35LME", "34LGM", "36LTR", "35KKC", "35KPB", "34LGL", "36LTJ", "35KNA", "35LKE", "35LLD", "35LNK", "35LJE", "34LGH", "34LFH", "36LTN", "34LDJ", "35LJH", "35LLG", "36LVP", "35KMB", "35LPF", "35LMD", "35LQF", "35LPD", "35KNB", "34LDH", "35LNF", "35MQM", "34LDK", "35LJF", "36LVK", "35KKB", "36KSH", "36LUN", "35LPK", "34LEJ", "36LWN", "36LUJ", "35LLC", "35LMC", "35LQD", "35MPM", "35KMA", "35LJJ", "35LKG", "35LNG", "35LPE", "34LGJ", "36LSP", "35LQE"], "ZW": ["36KVA", 
"36KSC", "35KQS", "36KUE", "36KUA", "35KQB", "36KVE", "35KNV", "35KQA", "35KPT", "35KQT", "36KVC", "36KTF", "36KUD", "35KQC", "35LQC", "36KTG", "36KTB", "35KRU", "36KWC", "35KMV", "36LVH", "35KRA", "35KNT"]} diff --git a/instageo/apps/viz.py b/instageo/apps/viz.py new file mode 100644 index 0000000..53eda88 --- /dev/null +++ b/instageo/apps/viz.py @@ -0,0 +1,163 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. +# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""Utils for Raster Visualisation.""" + +import datashader as ds +import datashader.transfer_functions as tf +import matplotlib.cm +import plotly.graph_objects as go +import rasterio +import xarray as xr +from pyproj import CRS, Transformer + +epsg3857_to_epsg4326 = Transformer.from_crs(3857, 4326, always_xy=True) + + +def get_crs(filepath: str) -> CRS: + """Retrieves the CRS of a GeoTiff data. + + Args: + filepath: Path to a GeoTiff file. 
+ + Returns: + CRS of data stored in `filepath` + """ + src = rasterio.open(filepath) + return src.crs + + +def add_raster_to_plotly_figure( + xarr_dataset: xr.Dataset, + from_crs: CRS, + column_name: str = "band_data", + scale: float = 1.0, +) -> go.Figure: + """Add a raster plot on a Plotly graph object figure. + + This function overlays raster data from an xarray dataset onto a Plotly map figure. + The data is reprojected to EPSG:3857 CRS for compatibility with Mapbox's projection + system. + + Args: + xarr_dataset (xr.Dataset): xarray dataset containing the raster data. + from_crs (CRS): Coordinate Reference System of data stored in xarr_dataset. + column_name (str): Name of the column in `xarr_dataset` to be plotted. Defaults + to "band_data". + scale (float): Scale factor for adjusting the plot resolution. Defaults to 1.0. + + Returns: + Figure: The modified Plotly figure with the raster data overlaid. + """ + # Reproject to EPSG:3857 CRS + xarr_dataset = xarr_dataset.rio.write_crs(from_crs).rio.reproject("EPSG:3857") + xarr_dataset = xarr_dataset.where(xarr_dataset <= 1, 0) + xarr_dataset = xarr_dataset.where(xarr_dataset > 0.8, 0) + # Get Raster dimension and range + numpy_data = xarr_dataset[column_name].squeeze().to_numpy() + plot_height, plot_width = numpy_data.shape + + # Data aggregation + canvas = ds.Canvas( + plot_width=int(plot_width * scale), plot_height=int(plot_height * scale) + ) + agg = canvas.raster(xarr_dataset[column_name].squeeze(), interpolate="linear") + + coords_lat_min, coords_lat_max = ( + agg.coords["y"].values.min(), + agg.coords["y"].values.max(), + ) + coords_lon_min, coords_lon_max = ( + agg.coords["x"].values.min(), + agg.coords["x"].values.max(), + ) + # xarr_dataset CRS was converted to EPSG:3857 because when EPSG:4326 is used the + # overlaid image doesn't overlap properly resulting in misrepresentation. 
The actual + # cause of this behavior is that 'Mapbox supports the popular Web Mercator + # projection, and does not support any other projections.' + ( + coords_lon_min, + coords_lon_max, + ), ( + coords_lat_min, + coords_lat_max, + ) = epsg3857_to_epsg4326.transform( + [coords_lon_min, coords_lon_max], [coords_lat_min, coords_lat_max] + ) + # Corners of the image, which need to be passed to mapbox + coordinates = [ + [coords_lon_min, coords_lat_max], + [coords_lon_max, coords_lat_max], + [coords_lon_max, coords_lat_min], + [coords_lon_min, coords_lat_min], + ] + + # Apply color map + img = tf.shade( + agg, + cmap=matplotlib.colormaps["Reds"], + alpha=100, + how="linear", + )[::-1].to_pil() + return img, coordinates + + +def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]: + """Read a GeoTIFF file into an xarray Dataset. + + Args: + filepath (str): Path to the GeoTIFF file. + + Returns: + xr.Dataset: The loaded xarray dataset. + """ + return xr.open_dataset(filepath).sel(band=1), get_crs(filepath) + + +def create_map_with_geotiff_tiles(tiles_to_overlay: list[str]) -> go.Figure: + """Create a map with multiple GeoTIFF tiles overlaid. + + This function reads GeoTIFF files from a specified directory and overlays them on a + Plotly map. + + Args: + tiles_to_overlay (list[str]): Path to tiles to overlay on map. + + Returns: + Figure: A Plotly figure with overlaid GeoTIFF tiles. 
+ """ + fig = go.Figure(go.Scattermapbox()) + fig.update_layout( + mapbox_style="open-street-map", + mapbox=dict(center=go.layout.mapbox.Center(lat=0, lon=20), zoom=2.0), + ) + fig.update_layout(margin={"r": 0, "t": 40, "l": 0, "b": 0}) + mapbox_layers = [] + for tile in tiles_to_overlay: + if tile.endswith(".tif") or tile.endswith(".tiff"): + xarr_dataset, crs = read_geotiff_to_xarray(tile) + img, coordinates = add_raster_to_plotly_figure( + xarr_dataset, crs, "band_data", scale=0.1 + ) + mapbox_layers.append( + {"sourcetype": "image", "source": img, "coordinates": coordinates} + ) + # Overlay the resulting image + fig.update_layout(mapbox_layers=mapbox_layers) + return fig diff --git a/instageo/data/README.md b/instageo/data/README.md new file mode 100644 index 0000000..50d777b --- /dev/null +++ b/instageo/data/README.md @@ -0,0 +1,86 @@ +# InstaGeo - Data (Chip Creator) + +## Overview +The [Harmonized Landsat-8 Sentinel-2 (HLS)](https://hls.gsfc.nasa.gov/) satellite data product is an effort by [NASA](https://www.nasa.gov/) to provide a consistent and long-term data record of the Earth's land surface at high temporal (2-3 days) and spatial (30m) resolutions. + +As with every raw satellite data product, processing and manipulating it requires quite some knowledge and experience, which can be an impediment to the use of such data for social good. + +InstaGeo's Chip Creator aims to fill this gap to enable researchers to leverage HLS without having to worry about handling and processing raw satellite data. + +By providing geolocated observation records containing longitude, latitude, date and label, chip creator sources the appropriate HLS granules and creates valid chips with corresponding segmentation maps.
+ +## Workflow +- Input geolocated observation records in a CSV file with the following columns: + - x (longitude) + - y (latitude) + - date + - label +- Group records based on Military Grid Reference System (MGRS) +- For each record, create a corresponding temporal series of HLS granules +- Create a set of all HLS granules +- Download the required bands from each granule +- Create and save chips and segmentation maps + +Chip creator is particularly useful for geospatial image segmentation where large satellite images are segmented into smaller chips for training geospatial models (such as [Prithvi](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M)) using an image segmentation objective. + +## Requirements +- Python >= 3.10 +- Required Python libraries listed in `requirements.txt` + +## Installation +Ensure that Python 3.10 and all required libraries are installed. You can install the required libraries using pip: + +```bash +pip install -r requirements.txt +``` +## Authentication +HLS data is hosted on LP DAAC and requires authentication to access it. Create an [Earth Data Account](https://urs.earthdata.nasa.gov/) and save the following to your home directory at `~/.netrc` + + `machine urs.earthdata.nasa.gov login USERNAME password PASSWORD` + + Replace `USERNAME` and `PASSWORD` with those of your account. + +## Usage + +### Command-line Arguments +- `--dataframe_path`: Path to the DataFrame CSV file containing geolocated observation records. (required) +- `--chip_size`: Size of each chip. (default: 224) +- `--output_directory`: Directory where the chips and segmentation maps will be saved. (required) +- `--no_data_value`: Value to use for no data pixels in the segmentation maps. (default: -1) +- `--src_crs`: Coordinate Reference System (CRS) of the geographical coordinates in `dataframe_path`. (default: 4326) +- `--min_count`: Minimum number of records desired per granule. This is useful for controlling sparsity.
(default: 1) +- `--num_steps`: Number of temporal steps for creating data. For static data, this value should be set to 1 (default: 1) +- `--temporal_step`: Size of each temporal step in days (default: 30) +- `--temporal_tolerance`: Number of days to be tolerated when searching for the closest HLS granules. (default: 3) +- `--download_only`: Downloads only HLS tiles when set to True. (default: False) + + +### Running the Module +To run the module, use the following command in the terminal: + +```bash +python chip_creator.py --dataframe_path="<dataframe_path>" --chip_size=<chip_size> --output_directory="<output_directory>" --no_data_value=<no_data_value> --src_crs=<src_crs> +``` + +Replace `<dataframe_path>`, `<chip_size>`, `<output_directory>`, `<no_data_value>` and `<src_crs>` with appropriate values. + +### Example +```bash +python chip_creator.py --dataframe_path="/path/to/data.csv" --chip_size=224 --output_directory="/path/to/output" --no_data_value=-1 --src_crs 4326 +``` + +This command will process the HLS tile, segment it into chips of size 224x224, create segmentation maps based on the provided DataFrame, and save the output in the specified directory. + +Run this command to test using an example observation record. + +```bash +python -m "instageo.data.chip_creator" \ + --dataframe_path=tests/data/test_breeding_data.csv \ + --output_directory="." \ + --min_count=4 \ + --chip_size=512 \ + --no_data_value=-1 \ + --temporal_tolerance=3 \ + --temporal_step=30 \ + --num_steps=3 +``` diff --git a/instageo/data/__init__.py b/instageo/data/__init__.py new file mode 100644 index 0000000..90719f4 --- /dev/null +++ b/instageo/data/__init__.py @@ -0,0 +1,5 @@ +"""Data Handling for InstaGeo. + +This subpackage includes modules and utilities for loading, processing, +and transforming geospatial data for analysis.
+""" diff --git a/instageo/data/chip_creator.py b/instageo/data/chip_creator.py new file mode 100644 index 0000000..edb3844 --- /dev/null +++ b/instageo/data/chip_creator.py @@ -0,0 +1,383 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. 
+# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""InstaGeo Chip Creator Module.""" + +import bisect +import json +import os +from typing import Any + +import geopandas as gpd +import numpy as np +import pandas as pd +import rasterio +import xarray as xr +from absl import app, flags, logging +from shapely.geometry import Point +from tqdm import tqdm + +import instageo.data.hls_utils as hls_utils +from instageo.data.geo_utils import open_mf_tiff_dataset + +logging.set_verbosity(logging.INFO) + +FLAGS = flags.FLAGS + +flags.DEFINE_string("dataframe_path", None, "Path to the DataFrame CSV file.") +flags.DEFINE_integer("chip_size", 224, "Size of each chip.") +flags.DEFINE_integer("src_crs", 4326, "CRS of the geo-coordinates in `dataframe_path`") +flags.DEFINE_string( + "output_directory", + None, + "Directory where the chips and segmentation maps will be saved.", +) +flags.DEFINE_integer( + "no_data_value", -1, "Value to use for no data areas in the segmentation maps." +) +flags.DEFINE_integer("min_count", 100, "Minimum observation counts per tile") +flags.DEFINE_integer("num_steps", 3, "Number of temporal steps") +flags.DEFINE_integer("temporal_step", 30, "Temporal step size.") +flags.DEFINE_integer( + "temporal_tolerance", 5, "Tolerance used when searching for the closest tile" +) +flags.DEFINE_boolean( + "download_only", False, "Downloads HLS dataset without creating chips." 
+) +flags.DEFINE_boolean("mask_cloud", False, "Perform Cloud Masking") + + +def check_required_flags() -> None: + """Check if required flags are provided.""" + required_flags = ["dataframe_path", "output_directory"] + for flag_name in required_flags: + if not getattr(FLAGS, flag_name): + raise app.UsageError(f"Flag --{flag_name} is required.") + + +def create_segmentation_map( + chip: Any, + df: pd.DataFrame, + no_data_value: int, +) -> np.ndarray: + """Create a segmentation map for the chip using the DataFrame. + + Args: + chip (Any): The chip (subset of the original data) for which the segmentation + map is being created. + df (pd.DataFrame): DataFrame containing the data to be used in the segmentation + map. + no_data_value (int): Value to be used for pixels with no data. + + Returns: + np.ndarray: The created segmentation map as a NumPy array. + """ + seg_map = chip.isel(band=0).assign( + { + "band_data": ( + ("y", "x"), + no_data_value * np.ones((chip.sizes["x"], chip.sizes["y"])), + ) + } + ) + df = df[ + (chip["x"].min().item() <= df["geometry"].x) + & (df["geometry"].x <= chip["x"].max().item()) + & (chip["y"].min().item() <= df["geometry"].y) + & (df["geometry"].y <= chip["y"].max().item()) + ] + # Use a tolerance of 30 meters + for _, row in df.iterrows(): + nearest_index = seg_map.sel( + x=row["geometry"].x, y=row["geometry"].y, method="nearest", tolerance=30 + ) + seg_map.loc[ + dict(x=nearest_index["x"].values.item(), y=nearest_index["y"].values.item()) + ] = row["label"] + return seg_map.band_data.squeeze() + + +def get_chip_coords( + df: gpd.GeoDataFrame, tile: xr.DataArray, chip_size: int +) -> list[tuple[int, int]]: + """Get Chip Coordinates. + + Given a list of x,y coordinates tuples of a point and an xarray dataarray, this + function returns the corresponding x,y indices of the grid where each point will fall + when the DataArray is gridded such that each grid has size `chip_size` + indices where it will fall. 
+ + Args: + gdf (gpd.GeoDataFrame): GeoPandas dataframe containing the point. + tile (xr.DataArray): Tile DataArray. + chip_size (int): Size of each chip. + + Returns: + List of chip indices. + """ + coords = [] + for _, row in df.iterrows(): + x = bisect.bisect_left(tile["x"].values, row["geometry"].x) + y = bisect.bisect_left(tile["y"].values[::-1], row["geometry"].y) + y = tile.sizes["y"] - y - 1 + x = int(x // chip_size) + y = int(y // chip_size) + coords.append((x, y)) + return coords + + +def create_and_save_chips_with_seg_maps( + hls_tile_dict: dict[str, dict[str, str]], + df: pd.DataFrame, + chip_size: int, + output_directory: str, + no_data_value: int, + src_crs: int, + mask_cloud: bool, +) -> tuple[list[str], list[str | None]]: + """Chip Creator. + + Create chips and corresponding segmentation maps from a HLS tile and save them to + an output directory. + + Args: + hls_tile_dict (Dict): A dict mapping band names to HLS tile filepath. + df (pd.DataFrame): DataFrame containing the data for segmentation maps. + chip_size (int): Size of each chip. + output_directory (str): Directory where the chips and segmentation maps will be + saved. + no_data_value (int): Value to use for no data areas in the segmentation maps. + src_crs (int): CRS of points in `df` + mask_cloud (bool): Perform cloud masking if True. + + Returns: + A tuple containing the lists of created chips and segmentation maps. 
+ """ + ds, crs = open_mf_tiff_dataset(hls_tile_dict, mask_cloud) + df = gpd.GeoDataFrame(df, geometry=[Point(xy) for xy in zip(df.x, df.y)]) + df.set_crs(epsg=src_crs, inplace=True) + df = df.to_crs(crs=crs) + + df = df[ + (ds["x"].min().item() <= df["geometry"].x) + & (df["geometry"].x <= ds["x"].max().item()) + & (ds["y"].min().item() <= df["geometry"].y) + & (df["geometry"].y <= ds["y"].max().item()) + ] + os.makedirs(output_directory, exist_ok=True) + tile_name_splits = hls_tile_dict["tiles"]["B02_0"].split(".") + tile_id = f"{tile_name_splits[1]}_{tile_name_splits[2]}_{tile_name_splits[3]}" + date_id = df.iloc[0]["date"].strftime("%Y%m%d") + chips = [] + seg_maps: list[str | None] = [] + n_chips_x = ds.sizes["x"] // chip_size + n_chips_y = ds.sizes["y"] // chip_size + chip_coords = list(set(get_chip_coords(df, ds, chip_size))) + for x, y in chip_coords: + if (x >= n_chips_x) or (y >= n_chips_y): + continue + chip_id = f"{date_id}_{tile_id}_{x}_{y}" + chip_name = f"chip_{chip_id}.tif" + seg_map_name = f"seg_map_{chip_id}.tif" + + chip_filename = os.path.join(output_directory, "chips", chip_name) + seg_map_filename = os.path.join(output_directory, "seg_maps", seg_map_name) + if os.path.exists(chip_filename) or os.path.exists(seg_map_filename): + continue + + chip = ds.isel( + x=slice(x * chip_size, (x + 1) * chip_size), + y=slice(y * chip_size, (y + 1) * chip_size), + ) + if chip.count().values == 0: + continue + seg_map = create_segmentation_map(chip, df, no_data_value) + if seg_map.where(seg_map != -1).count().values == 0: + continue + seg_maps.append(seg_map_name) + seg_map.rio.to_raster(seg_map_filename) + chip = chip.fillna(no_data_value) + chips.append(chip_name) + chip.band_data.rio.to_raster(chip_filename) + return chips, seg_maps + + +def create_hls_dataset( + data_with_tiles: pd.DataFrame, outdir: str +) -> tuple[dict[str, dict[str, dict[str, str]]], set[str]]: + """Creates HLS Dataset. 
+ + A HLS dataset is a list of dictionary mapping band names to corresponding GeoTiff + filepath. It is required for creating chips. + + Args: + data_with_tiles (pd.DataFrame): A dataframe containing observations that fall + within a dense tile. It also has `hls_tiles` column that contains a temporal + series of HLS granules. + outdir (str): Output directory where tiles could be downloaded to. + + Returns: + A tuple containing HLS dataset and a list of tiles that needs to be downloaded. + """ + data_with_tiles = data_with_tiles.drop_duplicates(subset=["hls_tiles"]) + data_with_tiles = data_with_tiles[ + data_with_tiles["hls_tiles"].apply( + lambda granule_lst: all("HLS" in str(item) for item in granule_lst) + ) + ] + assert not data_with_tiles.empty, "No observation record with valid HLS tiles" + hls_dataset = {} + granules_to_download = [] + s30_bands = ["B02", "B03", "B04", "B8A", "B11", "B12", "Fmask"] + l30_bands = ["B02", "B03", "B04", "B05", "B06", "B07", "Fmask"] + for hls_tiles, obsv_date in zip( + data_with_tiles["hls_tiles"], data_with_tiles["date"] + ): + band_id, band_path = [], [] + mask_id, mask_path = [], [] + for idx, tile in enumerate(hls_tiles): + tile = tile.strip(".") + if "HLS.S30" in tile: + for band in s30_bands: + if band == "Fmask": + mask_id.append(f"{band}_{idx}") + mask_path.append( + os.path.join(outdir, "hls_tiles", f"{tile}.{band}.tif") + ) + else: + band_id.append(f"{band}_{idx}") + band_path.append( + os.path.join(outdir, "hls_tiles", f"{tile}.{band}.tif") + ) + granules_to_download.append( + f"https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSS30.020/{tile}/{tile}.{band}.tif" # noqa + ) + else: + for band in l30_bands: + if band == "Fmask": + mask_id.append(f"{band}_{idx}") + mask_path.append( + os.path.join(outdir, "hls_tiles", f"{tile}.{band}.tif") + ) + else: + band_id.append(f"{band}_{idx}") + band_path.append( + os.path.join(outdir, "hls_tiles", f"{tile}.{band}.tif") + ) + granules_to_download.append( + 
f"https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSL30.020/{tile}/{tile}.{band}.tif" # noqa + ) + + hls_dataset[f'{obsv_date.strftime("%Y-%m-%d")}_{tile.split(".")[2]}'] = { + "tiles": {k: v for k, v in zip(band_id, band_path)}, + "fmasks": {k: v for k, v in zip(mask_id, mask_path)}, + } + return hls_dataset, set(granules_to_download) + + +def main(argv: Any) -> None: + """CSV Chip Creator. + + Given a csv file containing geo-located point observations and labels, the Chip + Creator creates small chip from large HLS tiles which is suitable for training + segmentation models. + """ + del argv + data = pd.read_csv(FLAGS.dataframe_path) + data["date"] = pd.to_datetime(data["date"]) - pd.offsets.MonthBegin(1) + data["input_features_date"] = data["date"] - pd.DateOffset(months=1) + sub_data = hls_utils.get_hls_tiles(data, min_count=FLAGS.min_count) + + if not ( + os.path.exists(os.path.join(FLAGS.output_directory, "hls_dataset.json")) + and os.path.exists( + os.path.join(FLAGS.output_directory, "granules_to_download.csv") + ) + ): + logging.info("Creating HLS dataset JSON.") + logging.info("Retrieving HLS tile ID for each observation.") + sub_data_with_tiles = hls_utils.add_hls_granules( + sub_data, + num_steps=FLAGS.num_steps, + temporal_step=FLAGS.temporal_step, + temporal_tolerance=FLAGS.temporal_tolerance, + ) + logging.info("Retrieving HLS tiles that will be downloaded.") + hls_dataset, granules_to_download = create_hls_dataset( + sub_data_with_tiles, outdir=FLAGS.output_directory + ) + with open( + os.path.join(FLAGS.output_directory, "hls_dataset.json"), "w" + ) as json_file: + json.dump(hls_dataset, json_file, indent=4) + pd.DataFrame({"tiles": list(granules_to_download)}).to_csv( + os.path.join(FLAGS.output_directory, "granules_to_download.csv") + ) + else: + logging.info("HLS dataset JSON already created") + with open( + os.path.join(FLAGS.output_directory, "hls_dataset.json") + ) as json_file: + hls_dataset = json.load(json_file) + 
granules_to_download = pd.read_csv( + os.path.join(FLAGS.output_directory, "granules_to_download.csv") + )["tiles"].tolist() + os.makedirs(os.path.join(FLAGS.output_directory, "hls_tiles"), exist_ok=True) + logging.info("Downloading HLS Tiles") + hls_utils.parallel_download( + granules_to_download, + outdir=os.path.join(FLAGS.output_directory, "hls_tiles"), + ) + if FLAGS.download_only: + return + logging.info("Creating Chips and Segmentation Maps") + all_chips = [] + all_seg_maps = [] + os.makedirs(os.path.join(FLAGS.output_directory, "chips"), exist_ok=True) + os.makedirs(os.path.join(FLAGS.output_directory, "seg_maps"), exist_ok=True) + for key, hls_tile_dict in tqdm(hls_dataset.items(), desc="Processing HLS Dataset"): + obsv_date_str, tile_id = key.split("_") + obsv_data = sub_data[ + (sub_data["date"] == pd.to_datetime(obsv_date_str)) + & (sub_data["mgrs_tile_id"].str.contains(tile_id.strip("T"))) + ] + try: + chips, seg_maps = create_and_save_chips_with_seg_maps( + hls_tile_dict, + obsv_data, + chip_size=FLAGS.chip_size, + output_directory=FLAGS.output_directory, + no_data_value=FLAGS.no_data_value, + src_crs=FLAGS.src_crs, + mask_cloud=FLAGS.mask_cloud, + ) + all_chips.extend(chips) + all_seg_maps.extend(seg_maps) + except rasterio.errors.RasterioIOError as e: + logging.error(f"Error {e} when reading dataset containing: {hls_tile_dict}") + except IndexError as e: + logging.error(f"Error {e} when processing {key}") + logging.info("Saving dataframe of chips and segmentation maps.") + pd.DataFrame({"Input": all_chips, "Label": all_seg_maps}).to_csv( + os.path.join(FLAGS.output_directory, "hls_chips_dataset.csv") + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/instageo/data/geo_utils.py b/instageo/data/geo_utils.py new file mode 100644 index 0000000..c8c5562 --- /dev/null +++ b/instageo/data/geo_utils.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the 
Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. +# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""Geo Utils Module.""" + +import rasterio +import xarray as xr +from rasterio.crs import CRS + + +def open_mf_tiff_dataset( + band_files: dict[str, dict[str, str]], mask_cloud: bool +) -> tuple[xr.Dataset, CRS]: + """Open multiple TIFF files as an xarray Dataset. + + Args: + band_files (Dict[str, Dict[str, str]]): A dictionary mapping band names to file paths. + mask_cloud (bool): Perform cloud masking. 
+ + Returns: + (xr.Dataset, CRS): A tuple of xarray Dataset combining data from all the + provided TIFF files and its CRS + """ + band_paths = list(band_files["tiles"].values()) + bands_dataset = xr.open_mfdataset( + band_paths, + concat_dim="band", + combine="nested", + ) + mask_paths = list(band_files["fmasks"].values()) + mask_dataset = xr.open_mfdataset( + mask_paths, + concat_dim="band", + combine="nested", + ) + water_mask = decode_fmask_value(mask_dataset, 5) + water_mask = water_mask.band_data.values.any(axis=0).astype(int) + bands_dataset = bands_dataset.where(water_mask == 0) + if mask_cloud: + cloud_mask = decode_fmask_value(mask_dataset, 1) + cloud_mask = cloud_mask.band_data.values.any(axis=0).astype(int) + bands_dataset = bands_dataset.where(cloud_mask == 0) + with rasterio.open(band_paths[0]) as src: + crs = src.crs + return bands_dataset, crs + + +def decode_fmask_value(value: xr.Dataset, position: int) -> xr.Dataset: + """Decodes HLS v2.0 Fmask. + + The decoding strategy is described in Appendix A of the user manual + (https://lpdaac.usgs.gov/documents/1698/HLS_User_Guide_V2.pdf). + + Arguments: + value: Input xarray Dataset created from Fmask.tif + position: Bit position to decode. + + Returns: + Xarray dataset containing decoded bits. + """ + quotient = value // (2**position) + return quotient - ((quotient // 2) * 2) diff --git a/instageo/data/hls_utils.py b/instageo/data/hls_utils.py new file mode 100644 index 0000000..a8a1b09 --- /dev/null +++ b/instageo/data/hls_utils.py @@ -0,0 +1,294 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. 
+#
+# You are free to:
+# - Share: Copy and redistribute the material in any medium or format
+# - Adapt: Remix, transform, and build upon the material
+#
+# Under the following terms:
+# - Attribution: You must give appropriate credit, provide a link to the license,
+#   and indicate if changes were made. You may do so in any reasonable manner,
+#   but not in any way that suggests the licensor endorses you or your use.
+# - NonCommercial: You may not use the material for commercial purposes.
+# - ShareAlike: If you remix, transform, or build upon the material, you must
+#   distribute your contributions under the same license as the original.
+#
+# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/
+# ------------------------------------------------------------------------------
+
+"""Utility Functions for Reading and Processing Harmonized Landsat Sentinel-2 Dataset."""
+
+import os
+import re
+from datetime import datetime, timedelta
+from multiprocessing import cpu_count
+
+import earthaccess
+import mgrs
+import pandas as pd
+from absl import logging
+
+
+def parse_date_from_entry(hls_tile_name: str) -> datetime | None:
+    """Extracts the date from a HLS Tile Name.
+
+    Args:
+        hls_tile_name (str): Name of HLS tile.
+
+    Returns:
+        Parsed date or None.
+    """
+    match = re.search(r"\.(\d{7})T", hls_tile_name)
+    if match:
+        date_str = match.group(1)
+        return datetime.strptime(date_str, "%Y%j")
+    else:
+        return None
+
+
+def find_closest_tile(
+    tile_queries: dict[str, tuple[str, list[str]]],
+    tile_database: dict[str, list[str]],
+    temporal_tolerance: int = 5,
+) -> pd.DataFrame:
+    """Find Closest HLS Tile.
+
+    HLS dataset gets updated every 2 or 3 days and each tile is marked by the time of
+    observation. This makes it difficult to deterministically find tiles for a given
+    observation time. Rather we try to find a tile with observation time closest to our
+    desired time.
+
+    To do this, we create a database of tiles within a specific timeframe then we search
+    for our desired tile within the database.
+
+    Args:
+        tile_queries (dict[str, tuple[str, list[str]]]): A dict with tile_query as key
+            and a tuple of tile_id and a list of dates on which the tile needs to be
+            retrieved as value.
+        tile_database (dict[str, list[str]]): A database mapping HLS tile_id to a list of
+            available tiles within a pre-defined period of time
+        temporal_tolerance: Number of days that can be tolerated for matching a closest
+            tile in tile_database.
+
+    Returns:
+        DataFrame containing the tile queries to the tile found.
+    """
+    query_results = {}
+    for query_str, (tile_id, dates) in tile_queries.items():
+        result = []
+        if tile_id in tile_database:
+            for date_str in dates:
+                date = pd.to_datetime(date_str)
+                year, day_of_year = date.year, date.day_of_year
+                query_date = datetime(year, 1, 1) + timedelta(days=day_of_year - 1)
+                closest_entry = None
+                for entry in tile_database[tile_id]:
+                    entry_date = parse_date_from_entry(entry)
+                    if not entry_date:
+                        continue
+                    diff = abs((entry_date - query_date).days)
+                    if (diff <= temporal_tolerance) and (diff >= 0):
+                        closest_entry = entry
+                        break
+                result.append(closest_entry)
+        query_results[query_str] = result
+    query_results = pd.DataFrame(
+        {"tile_queries": query_results.keys(), "hls_tiles": query_results.values()}
+    )
+    return query_results
+
+
+def retrieve_hls_metadata(tile_info_df: pd.DataFrame) -> dict[str, list[str]]:
+    """Retrieve HLS Tiles Metadata.
+
+    Given a tile_id, start_date and end_date, this function fetches all the HLS granules
+    available for this tile_id in this time window.
+
+    Args:
+        tile_info_df (pd.DataFrame): A dataframe containing tile_id, start_date and
+            end_date in each row.
+
+    Returns:
+        A dictionary mapping tile_id to a list of available HLS granules.
+    """
+    granules_dict = {}
+    for _, (
+        tile_id,
+        start_date,
+        end_date,
+        lon_min,
+        lon_max,
+        lat_min,
+        lat_max,
+    ) in tile_info_df.iterrows():
+        results = earthaccess.search_data(
+            short_name=["HLSL30", "HLSS30"],
+            bounding_box=(lon_min, lat_min, lon_max, lat_max),
+            temporal=(f"{start_date}T00:00:00", f"{end_date}T23:59:59"),
+        )
+        granules = pd.json_normalize(results)
+        granules = granules[granules["meta.native-id"].str.contains(tile_id)]
+        granules = list(granules["meta.native-id"])
+        granules_dict[tile_id] = granules
+    return granules_dict
+
+
+def get_hls_tiles(data: pd.DataFrame, min_count: int = 100) -> pd.DataFrame:
+    """Get HLS Tile ID for Each Observation.
+
+    Locust observations are described by geolocation scattered across the globe. They are
+    dense as well as sparse in various locations. In order to optimize resource usage, we
+    subset the observations in dense locations.
+
+    We first add the HLS tile ID for each observation and count the number of
+    observations in each tile. Then we retain the tiles with `min_count` observations.
+
+    Args:
+        data: Dataframe containing locust observations
+        min_count: minimum count of locust observations per HLS tile.
+
+    Returns:
+        Subset of locust observations where there are `min_count` observations per tile
+
+    """
+    mgrs_object = mgrs.MGRS()
+    get_mgrs_tile_id = lambda row: mgrs_object.toMGRS(
+        row["y"], row["x"], MGRSPrecision=0
+    )
+    data["mgrs_tile_id"] = data.apply(get_mgrs_tile_id, axis=1)
+    tile_counts = data.groupby("mgrs_tile_id").size().sort_values(ascending=False)
+    data = pd.merge(
+        data, tile_counts.reset_index(name="counts"), how="left", on="mgrs_tile_id"
+    )
+    sub_data = data[data["counts"] >= min_count]
+    assert not sub_data.empty, "No observation records left"
+    return sub_data
+
+
+def get_hls_tile_info(
+    data: pd.DataFrame, num_steps: int = 3, temporal_step: int = 10
+) -> tuple[pd.DataFrame, list[tuple[str, list[str]]]]:
+    """Get HLS Tile Info.
+
+    Retrieves a summary of all tiles required for a given dataset. The summary contains
+    the desired start and end date for each HLS tile. Also retrieves a list of queries
+    that can be used to retrieve the tiles for each observation in `data`.
+
+    Args:
+        data (pd.DataFrame): A dataframe containing observation records.
+        num_steps (int): Number of temporal time steps
+        temporal_step (int): Size of each temporal step.
+
+    Returns:
+        A `tile_info` dataframe and a list of `tile_queries`
+    """
+    data = data[["mgrs_tile_id", "input_features_date", "x", "y"]].reset_index(
+        drop=True
+    )
+    tile_queries = []
+    tile_info = []
+    for _, (tile_id, date, lon, lat) in data.iterrows():
+        history = []
+        for i in range(num_steps):
+            curr_date = date - pd.Timedelta(days=temporal_step * i)
+            history.append(curr_date.strftime("%Y-%m-%d"))
+            tile_info.append([tile_id, curr_date.strftime("%Y-%m-%d"), lon, lat])
+        tile_queries.append((tile_id, history))
+    tile_info = (
+        pd.DataFrame(tile_info, columns=["tile_id", "date", "lon", "lat"])
+        .groupby("tile_id")
+        .agg(
+            min_date=("date", "min"),
+            max_date=("date", "max"),
+            lon_min=("lon", "min"),
+            lon_max=("lon", "max"),
+            lat_min=("lat", "min"),
+            lat_max=("lat", "max"),
+        )
+    ).reset_index()
+    return tile_info, tile_queries
+
+
+def add_hls_granules(
+    data: pd.DataFrame,
+    num_steps: int = 3,
+    temporal_step: int = 10,
+    temporal_tolerance: int = 5,
+) -> pd.DataFrame:
+    """Add HLS Granules.
+
+    Data contains tile_id and a series of date for which the tile is desired. This
+    function takes the tile_id and the dates and finds the HLS tiles closest to the
+    desired date with a tolerance of `temporal_tolerance`.
+
+    Args:
+        data (pd.DataFrame): A dataframe containing observations that fall within a
+            dense tile.
+        num_steps (int): Number of temporal steps into the past to fetch.
+        temporal_step (int): Step size (in days) for creating temporal steps.
+        temporal_tolerance (int): Tolerance (in days) for finding closest HLS tile.
+ + Returns: + A dataframe containing a list of HLS granules. Each granule is a directory + containing all the bands. + """ + tiles_info, tile_queries = get_hls_tile_info( + data, num_steps=num_steps, temporal_step=temporal_step + ) + tile_queries_str = [ + f"{tile_id}_{'_'.join(dates)}" for tile_id, dates in tile_queries + ] + data["tile_queries"] = tile_queries_str + tile_database = retrieve_hls_metadata(tiles_info) + tile_queries_dict = {k: v for k, v in zip(tile_queries_str, tile_queries)} + query_result = find_closest_tile( + tile_queries=tile_queries_dict, + tile_database=tile_database, + temporal_tolerance=temporal_tolerance, + ) + data = pd.merge(data, query_result, how="left", on="tile_queries") + return data + + +def parallel_download(urls: set[str], outdir: str, max_retries: int = 3) -> None: + """Parallel Download. + + Wraps `download_tile` with multiprocessing.Pool for downloading multiple tiles in + parallel. + + Args: + urls: Tile urls to download. + outdir: Directory to save downloaded tiles. + max_retries: Number of times to retry downloading all tiles. 
+ + Returns: + None + """ + num_cpus = cpu_count() + earthaccess.login(persist=True) + retries = 0 + complete = False + while retries <= max_retries: + temp_urls = [ + url + for url in urls + if not os.path.exists(os.path.join(outdir, url.split("/")[-1])) + ] + if not temp_urls: + complete = True + break + earthaccess.download(temp_urls, local_path=outdir, threads=num_cpus) + for filename in os.listdir(outdir): + file_path = os.path.join(outdir, filename) + if os.path.isfile(file_path): + file_size = os.path.getsize(file_path) + if file_size < 1024: + os.remove(file_path) + retries += 1 + if complete: + logging.info("Successfully downloaded all granules") + else: + logging.warning( + f"Couldn't download the following granules after {max_retries} retries:\n{urls}" # noqa + ) diff --git a/instageo/data/requirements.txt b/instageo/data/requirements.txt new file mode 100644 index 0000000..fd7760c --- /dev/null +++ b/instageo/data/requirements.txt @@ -0,0 +1,2 @@ +mgrs==1.4.6 +earthaccess==0.8.2 diff --git a/instageo/model/Prithvi.py b/instageo/model/Prithvi.py new file mode 100644 index 0000000..920334f --- /dev/null +++ b/instageo/model/Prithvi.py @@ -0,0 +1,323 @@ +# ------------------------------------------------------------------------------ +# This code is adapted from the original code in the following repository: +# +# Original Source: +# https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/geospatial_fm/geospatial_fm.py +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# Adapted by: Ibrahim Salihu Yusuf +# Date of adaptation: 27-10-2023 +# +# Changes made: +# - Removed all other modules and functions except TemporalViTEncoder (renamed as ViTEncoder) and +# its dependencies +# +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +"""Vision Transformer Module.""" + +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple +from timm.models.vision_transformer import Block + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: + """Generates a 1D sinusoidal position embedding from a list of positions. + + Args: + embed_dim (int): Output dimension for each position. + pos (np.ndarray): A list of positions to be encoded, size (M,). + + Returns: + np.ndarray: The sinusoidal position embedding (M, D). + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray: + """Generates a 2D sinusoidal position embedding from a grid of positions. + + Args: + embed_dim (int): Output dimension for each position. + grid (np.ndarray): A 2D grid of positions. + + Returns: + np.ndarray: The sinusoidal position embedding (H*W, D). 
+ """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim: int, grid_size: Tuple[int, int, int], cls_token: bool = False +) -> np.ndarray: + """Generates a 3D sinusoidal position embedding from a given grid size. + + Args: + embed_dim (int): Output dimension for each position. + grid_size (Tuple[int, int, int]): 3D tuple of grid size (t, h, w). + cls_token (bool): Whether to include a class token. + + Returns: + np.ndarray: The sinusoidal position embedding (L, D). + """ + assert embed_dim % 16 == 0 + + t_size, h_size, w_size = grid_size + + w_embed_dim = embed_dim // 16 * 6 + h_embed_dim = embed_dim // 16 * 6 + t_embed_dim = embed_dim // 16 * 4 + + w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) + h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) + t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) + + w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) + h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) + t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) + + pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) + + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +class PatchEmbed(nn.Module): + """Converts frames of 2D images to patch embedding. + + This is a 3D version of timm.models.vision_transformer.PatchEmbed. 
+ """ + + def __init__( + self, + img_size_int: int = 224, + patch_size_int: int = 16, + num_frames: int = 3, + tubelet_size: int = 1, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + ): + """Initializes the PatchEmbed module. + + Args: + img_size_int (int): Size of the image. + patch_size_int (int): Size of each patch. + num_frames (int): Number of frames. + tubelet_size (int): Size of the tubelet. + in_chans (int): Number of input channels. + embed_dim (int): Dimension of the embedding. + norm_layer (Optional[nn.Module]): Normalization layer. + flatten (bool): Whether to flatten the output. + bias (bool): Whether to use bias in the convolution layer. + """ + super().__init__() + img_size = to_2tuple(img_size_int) + patch_size = to_2tuple(patch_size_int) + self.img_size = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.grid_size = ( + num_frames // tubelet_size, + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.flatten = flatten + + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=(tubelet_size, patch_size[0], patch_size[1]), + stride=(tubelet_size, patch_size[0], patch_size[1]), + bias=bias, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward Pass.""" + B, C, T, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C + x = self.norm(x) + return x + + +class ViTEncoder(nn.Module): + """Masked Autoencoder with a Vision Transformer (ViT) backbone. + + This class represents an autoencoder architecture suitable for vision tasks, + particularly utilizing a Transformer-based approach for encoding and decoding. 
+ """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + num_frames: int = 3, + tubelet_size: int = 1, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4.0, + norm_layer: nn.Module = nn.LayerNorm, + norm_pix_loss: bool = False, + **kwargs: int, + ) -> None: + """Initializes the ViTEncoder model. + + Args: + img_size (int): Size of the input images. + patch_size (int): Size of the patches to be extracted from the input images. + num_frames (int): Number of frames to be considered for the input. + tubelet_size (int): Size of the tubelet. + in_chans (int): Number of input channels. + embed_dim (int): Dimensionality of the token embeddings. + depth (int): Number of layers in the Transformer encoder. + num_heads (int): Number of attention heads in the Transformer encoder. + decoder_embed_dim (int): Dimensionality of the token embeddings in the + decoder. + decoder_depth (int): Number of layers in the Transformer decoder. + decoder_num_heads (int): Number of attention heads in the Transformer + decoder. + mlp_ratio (float): Ratio of feed-forward layer size to Transformer block + size. + norm_layer (Optional[nn.Module]): Normalization layer to be used in the + Transformer. + norm_pix_loss (bool): Whether to normalize pixel loss. + + The class includes methods for patchifying images, applying random masking, + encoding with a Transformer-based architecture, and decoding with a separate + Transformer-based architecture. 
+ """ + del kwargs # unused kwargs + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed( + img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialization. + + Initialize (and freeze) pos_embed by sin-cos embedding. + """ + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) # noqa + torch.nn.init.normal_(self.cls_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + """Initialize weights. 
+ + Args: + m: nn.Module layer in model + """ + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encodes the input patches and applies random masking. + + Args: + x (torch.Tensor): The input tensor of patches. + mask_ratio (float): The proportion of the sequence to be masked. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The encoded tensor, the + binary mask, and the indices to restore the original order. + """ + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x diff --git a/instageo/model/README.md b/instageo/model/README.md new file mode 100644 index 0000000..1bb0eb8 --- /dev/null +++ b/instageo/model/README.md @@ -0,0 +1,225 @@ +# InstaGeo - Model + +## Overview + +The InstaGeo Model component is designed for training, validation and inference using custom deep learning models having [Prithvi_100M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M) foundational model as backbone. 
+
+## Requirements
+
+See `requirements.txt`
+
+## Features
+- Supports both classification and regression tasks
+- Accepts both temporal and non-temporal inputs
+- Custom models with [Prithvi_100M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M) backbone
+- Training, Validation and Inference runs
+- Sliding window inference for inference on expansive HLS tiles, which measure 3660 x 3660 pixels
+- Reproducible training pipeline
+- Command-line flags for easy configuration of training parameters.
+
+
+
+## Usage
+
+1. **Setting Up Run Arguments:** Configure the training parameters using command-line arguments.
+
+    - `root_dir`: Root directory of the dataset.
+    - `valid_filepath`: File path for validation data.
+    - `train_filepath`: File path for training data.
+    - `test_filepath`: File path for test data.
+    - `checkpoint_path`: File path for model checkpoint.
+    - `learning_rate`: Initial learning rate.
+    - `num_epochs`: Number of training epochs.
+    - `batch_size`: Batch size for training and validation.
+    - `mode`: Select one of training or evaluation mode.
+See `configs/config.yaml` for more.
+
+2. **Dataset Preparation:** Prepare your geospatial data using the InstaGeo Chip Creator or similar and place it in the specified `root_dir`. Ensure that the csv file for each dataset has `Input` and `Label` columns corresponding to the path of the image and label relative to the `root_dir`. Additionally, ensure the data is compatible with `InstaGeoDataset`
+
+3. **Training the Model:**
+
+Run training with the necessary flags:
+
+```bash
+python -m instageo.model.run \
+    root_dir=path/to/root valid_filepath=path/to/valdata \
+    train_filepath=path/to/traindata \
+    learning_rate=0.001 \
+    num_epochs=100 \
+    batch_size=4
+```
+
+4. **Prediction using Sliding Window Inference:** For training we create chips from HLS tiles, this is necessary because our model can only process an input of size 224 x 224.
For the purpose of inference we have a sliding window inference feature that inputs HLS tile and perform a sliding window inference on patches of size 224 x 224. This is useful because it skips the process of creating chips using the `instageo.data.chip_creator`, we only need to download HLS tiles and directly runs inference on them. We can run inference using the following command: + +```bash +python -m instageo.model.run --config-name=config.yaml \ + root_dir='path-to-root_dir-containing-hls_dataset.json' \ + test_filepath='hls_dataset.json' \ + train.batch_size=16 \ + test.stride=224 \ + checkpoint_path='path-to-checkpoint' \ + mode=predict +``` + +5. **Example (Flood Mapping):** +[Sen1Floods11](https://github.com/cloudtostreet/Sen1Floods11) is a geospatial dataset of 10m Sentinel-2 imagery for flood detection. +- Data: Download the Sen1Floods11 hand labelled Sentinel-2 chips as well as `train`, `validation` and `test` splits using the following command +```bash +mkdir sen1floods11 +mkdir sen1floods11/S2Hand +mkdir sen1floods11/LabelHand + +gsutil cp gs://instageo/data/sen1floods11/flood_train_data.csv sen1floods11 +gsutil cp gs://instageo/data/sen1floods11/flood_test_data.csv sen1floods11 +gsutil cp gs://instageo/data/sen1floods11/flood_valid_data.csv sen1floods11 + +gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/S2Hand sen1floods11/S2Hand +gsutil -m rsync -r gs://sen1floods11/v1.1/data/flood_events/HandLabeled/LabelHand sen1floods11/LabelHand +``` + +- Model: Fine-tune Prithvi on the dataset by running the following command +```bash +python -m instageo.model.run --config-name=sen1floods11 \ + root_dir=sen1floods11 \ + train_filepath=sen1floods11/flood_train_data.csv \ + valid_filepath=sen1floods11/flood_valid_data.csv \ + num_epochs=100 +``` +After training you are expected to have a checkpoint with mIoU of ~ 89% + +- Evaluate: Evaluate the fine-tuned model on test set using the following command. 
+Replace `path/to/checkpoint/checkpoint.ckpt` with the path to your model checkpoint. +```bash +python -m instageo.model.run --config-name=sen1floods11 \ + root_dir=sen1floods11 test_filepath=sen1floods11/flood_test_data.csv \ + checkpoint_path=path/to/checkpoint/checkpoint.ckpt \ + mode=eval +``` +When the saved checkpoint is evaluated on the test set, you should have results comparable to the following + +`Class based metrics:` +| Metric | Class 0 (No Water) | Class 1 (Flood/Water) | +|-------------------|-----------------------| -----------------------| +| Accuracy | 0.99 | 0.88 +| Intersection over Union (IoU) | 0.97 | 0.81 + +`Global metrics:` +| Metric | Value | +|-----------|-------| +| Overall Accuracy | 0.98 | +| Mean IoU | 0.89 | +| Cross Entropy Loss | 0.11 | + +6. **Example (Multi-Temporal Crop Classification):** +[Multi-Temporal Crop Classification](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification) contains Harmonized Landsat-Sentinel (HLS) imagery spanning various land cover and crop type classes throughout the Contiguous United States, captured during the year 2022. The classification labels used in this dataset are based on the Crop Data Layer (CDL) provided by the United States Department of Agriculture (USDA). + +- Data: Download the Multi-Temporal Crop Classification data splits using the following command + +```bash +gsutil -m cp -r gs://instageo/data/multi-temporal-crop-classification . 
+``` + +- Model: Fine-tune Prithvi on the dataset by running the following command + +```bash +python -m instageo.model.run --config-name=multitemporal_crop_classification \ + root_dir='multi-temporal-crop-classification' \ + train_filepath='multi-temporal-crop-classification/training_data.csv' \ + valid_filepath='multi-temporal-crop-classification/validation_data.csv' \ + train.batch_size=16 \ + train.num_epochs=100 \ + train.learning_rate=1e-4 +``` +After training you are expected to have a checkpoint with mIoU of ~ 45% + +- Evaluate: Evaluate the fine-tuned model on test set using the following command. +Replace `path/to/checkpoint/checkpoint.ckpt` with the path to your model checkpoint. + +```bash +python -m instageo.model.run --config-name=multitemporal_crop_classification \ + root_dir='multi-temporal-crop-classification' \ + test_filepath='multi-temporal-crop-classification/validation_data.csv' \ + train.batch_size=16 \ + checkpoint_path=`path/to/checkpoint/checkpoint.ckpt` \ + mode=eval +``` +When the saved checkpoint is evaluated on the test set, you should have results comparable to the following + +`Class based metrics:` +| Metric | Accuracy | Intersection over Union (IoU) | +|-------------------------------|-------|-------| +| Natural Vegetation | 0.44 | 0.50 | +| Forest | 0.53 | 0.76 | +| Corn | 0.62 | 0.73 | +| Soybeans | 0.60 | 0.75 | +| Wetlands | 0.45 | 0.59 | +| Developed/Barren | 0.42 | 0.66 | +| Open Water | 0.69 | 0.88 | +| Winter Wheat | 0.55 | 0.71 | +| Alfalfa | 0.37 | 0.67 | +| Fallow/Idle Cropland | 0.37 | 0.57 | +| Cotton | 0.37 | 0.67 | +| Sorghum | 0.36 | 0.65 | +| Other | 0.43 | 0.57 | + +`Global metrics:` +| Metric | Value | +|-----------|-------| +| Overall Accuracy | 0.67 | +| Mean IoU | 0.48 | +| Cross Entropy Loss | 0.93 | + +7. **Example (Desert Locust Breeding Ground Prediction):** +Desert Locusts Breeding Ground Prediction using HLS dataset. 
Observation records of breeding grounds are sourced from [UN-FAO Locust Hub](https://locust-hub-hqfao.hub.arcgis.com/) and used to download HLS tiles used for creating chips and segmentation maps. + +- Data: The resulting chips and segmentation maps created using `instageo.chip_creator` can be downloaded using the following command: +```bash +gsutil -m cp -r gs://instageo/data/locust_breeding . +``` + +- Model: Fine-tune Prithvi on the dataset by running the following command + +```bash +python -m instageo.model.run --config-name=locust \ + root_dir='locust_breeding' \ + train_filepath='locust_breeding/train.csv' \ + valid_filepath='locust_breeding/val.csv' +``` +After training you are expected to have a checkpoint with mIoU of 70% + +- Evaluate: Evaluate the fine-tuned model on test set using the following command. +Replace `path/to/checkpoint/checkpoint.ckpt` with the path to your model checkpoint. + +```bash +python -m instageo.model.run --config-name=locust \ + root_dir='locust_breeding' \ + test_filepath='locust_breeding/test.csv' \ + checkpoint_path=path/to/checkpoint/checkpoint.ckpt \ + mode=eval +``` +When the saved checkpoint is evaluated on the test set, you should have results comparable to the following + +`Class based metrics:` +| Metric | Class 0 (Non-Breeding) | Class 1 (Breeding) | +|--------------------------------|--------------------------| -----------------------| +| Accuracy | 0.85 | 0.85 | +| Intersection over Union (IoU) | 0.74 | 0.74 | + +`Global metrics:` +| Metric | Value | +|-----------|-------| +| Overall Accuracy | 0.85 | +| Mean IoU | 0.74 | +| Cross Entropy Loss | 0.40 | + +## Customization + +- Modify the `bands`, `mean`, and `std` lists in `configs/config.yaml` to match your dataset's characteristics. +- Implement additional data augmentation strategies in `process_and_augment`.
diff --git a/instageo/model/__init__.py b/instageo/model/__init__.py new file mode 100644 index 0000000..9961696 --- /dev/null +++ b/instageo/model/__init__.py @@ -0,0 +1,6 @@ +"""Modeling Component for InstaGeo. + +This subpackage contains modules and classes for building, training, and evaluating +geospatial models. It includes various algorithms and model architectures specialized +for handling and analyzing geospatial data. +""" diff --git a/instageo/model/configs/config.yaml b/instageo/model/configs/config.yaml new file mode 100644 index 0000000..9a2edc1 --- /dev/null +++ b/instageo/model/configs/config.yaml @@ -0,0 +1,36 @@ +# config.yaml +root_dir: null +valid_filepath: null +train_filepath: null +test_filepath: null +checkpoint_path: null +mode: train # train or eval + +train: + learning_rate: 0.0001 + num_epochs: 10 + batch_size: 8 + class_weights: [1, 1] + ignore_index: -100 + weight_decay: 0.01 + +model: + freeze_backbone: False + num_classes: 2 + +dataloader: + bands: [1, 2, 3, 8, 11, 12] # Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2 + mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503] + std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205] + img_size: 224 + temporal_dim: 1 + replace_label: [-1, 2] + reduce_to_zero: False + no_data_value: -9999 + constant_multiplier: 1.0 + +test: + img_size: 224 + crop_size: 224 + stride: 224 + mask_cloud: False diff --git a/instageo/model/configs/locust.yaml b/instageo/model/configs/locust.yaml new file mode 100644 index 0000000..cabd5cc --- /dev/null +++ b/instageo/model/configs/locust.yaml @@ -0,0 +1,37 @@ +# config.yaml +root_dir: null +valid_filepath: null +train_filepath: null +test_filepath: null +checkpoint_path: null +mode: train # train or eval + +train: + learning_rate: 0.0001 + num_epochs: 20 + batch_size: 8 + class_weights: [1, 1] + ignore_index: -1 + weight_decay: 0.1 + +model: + freeze_backbone: False + num_classes: 2 + +dataloader: + # 3*(Blue, Green, 
Red, Narrow NIR, SWIR 1, SWIR 2) + bands: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503] + std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205] + img_size: 224 + temporal_dim: 3 + replace_label: null + reduce_to_zero: False + no_data_value: -1 + constant_multiplier: 1.0 + +test: + img_size: 224 + crop_size: 224 + stride: 224 + mask_cloud: False diff --git a/instageo/model/configs/multitemporal_crop_classification.yaml b/instageo/model/configs/multitemporal_crop_classification.yaml new file mode 100644 index 0000000..43d4305 --- /dev/null +++ b/instageo/model/configs/multitemporal_crop_classification.yaml @@ -0,0 +1,53 @@ +# config.yaml +root_dir: null +valid_filepath: null +train_filepath: null +test_filepath: null +checkpoint_path: null +mode: train # train or eval + +train: + learning_rate: 0.0001 + num_epochs: 10 + batch_size: 8 + class_weights: + [ + 0.386375, + 0.661126, + 0.548184, + 0.640482, + 0.876862, + 0.925186, + 3.249462, + 1.542289, + 2.175141, + 2.272419, + 3.062762, + 3.626097, + 1.198702, + ] + ignore_index: -100 + weight_decay: 0.01 + +model: + freeze_backbone: False + num_classes: 13 + +dataloader: + # 3*(Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) + bands: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] + mean: + [494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917] + std: [284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808] + img_size: 224 + temporal_dim: 3 + replace_label: null + reduce_to_zero: True + no_data_value: null + constant_multiplier: 1.0 + +test: + img_size: 224 + crop_size: 224 + stride: 224 + mask_cloud: False diff --git a/instageo/model/configs/sen1floods11.yaml b/instageo/model/configs/sen1floods11.yaml new file mode 100644 index 0000000..827bd2c --- /dev/null +++ b/instageo/model/configs/sen1floods11.yaml @@ -0,0 +1,36 @@ +# config.yaml +root_dir: null 
+valid_filepath: null +train_filepath: null +test_filepath: null +checkpoint_path: null +mode: train # train or eval + +train: + learning_rate: 0.0001 + num_epochs: 10 + batch_size: 16 + class_weights: [1, 3] + ignore_index: 2 + weight_decay: 0.01 + +model: + freeze_backbone: False + num_classes: 2 + +dataloader: + bands: [1, 2, 3, 8, 11, 12] # Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2 + mean: [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503] + std: [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205] + img_size: 224 + temporal_dim: 1 + replace_label: [-1, 2] + reduce_to_zero: False + no_data_value: -9999 + constant_multiplier: 0.0001 + +test: + img_size: 512 + crop_size: 224 + stride: 224 + mask_cloud: False diff --git a/instageo/model/dataloader.py b/instageo/model/dataloader.py new file mode 100644 index 0000000..a0f2d06 --- /dev/null +++ b/instageo/model/dataloader.py @@ -0,0 +1,390 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. 
+# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""Dataloader Module.""" + +import os +import random +from functools import partial +from typing import Callable, List, Tuple + +import numpy as np +import pandas as pd +import rasterio +import torch +from absl import logging +from PIL import Image +from torchvision import transforms + +from instageo.data.geo_utils import open_mf_tiff_dataset + + +def random_crop_and_flip( + ims: List[Image.Image], label: Image.Image, im_size: int +) -> Tuple[List[Image.Image], Image.Image]: + """Apply random cropping and flipping transformations to the given images and label. + + Args: + ims (List[Image.Image]): List of PIL Image objects representing the images. + label (Image.Image): A PIL Image object representing the label. + + Returns: + Tuple[List[Image.Image], Image.Image]: A tuple containing the transformed list of + images and label. + """ + i, j, h, w = transforms.RandomCrop.get_params(ims[0], (im_size, im_size)) + + ims = [transforms.functional.crop(im, i, j, h, w) for im in ims] + label = transforms.functional.crop(label, i, j, h, w) + + if random.random() > 0.5: + ims = [transforms.functional.hflip(im) for im in ims] + label = transforms.functional.hflip(label) + + if random.random() > 0.5: + ims = [transforms.functional.vflip(im) for im in ims] + label = transforms.functional.vflip(label) + + return ims, label + + +def normalize_and_convert_to_tensor( + ims: List[Image.Image], + label: Image.Image | None, + mean: List[float], + std: List[float], + temporal_size: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Normalize the images and label and convert them to PyTorch tensors. + + Args: + ims (List[Image.Image]): List of PIL Image objects representing the images. + label (Image.Image | None): A PIL Image object representing the label. 
+ mean (List[float]): The mean of each channel in the image + std (List[float]): The standard deviation of each channel in the image + temporal_size: The number of temporal steps + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of tensors representing the normalized + images and label. + """ + norm = transforms.Normalize(mean, std) + ims_tensor = torch.stack([transforms.ToTensor()(im).squeeze() for im in ims]) + _, h, w = ims_tensor.shape + ims_tensor = ims_tensor.reshape([temporal_size, -1, h, w]) # T*C,H,W -> T,C,H,W + ims_tensor = torch.stack([norm(im) for im in ims_tensor]).permute( + [1, 0, 2, 3] + ) # T,C,H,W -> C,T,H,W + if label: + label = torch.from_numpy(np.array(label)).squeeze() + return ims_tensor, label + + +def process_and_augment( + x: np.ndarray, + y: np.ndarray | None, + mean: List[float], + std: List[float], + temporal_size: int = 1, + im_size: int = 224, + train: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Process and augment the given images and labels. + + Args: + x (np.ndarray): Numpy array representing the images. + y (np.ndarray): Numpy array representing the label. + mean (List[float]): The mean of each channel in the image + std (List[float]): The standard deviation of each channel in the image + temporal_size: The number of temporal steps + train: To identify training mode. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of tensors representing the processed + and augmented images and label. + """ + ims = x.copy() + label = None + # convert to PIL for easier transforms + ims = [Image.fromarray(im) for im in ims] + if y is not None: + label = Image.fromarray(y.copy().squeeze()) + if train: + ims, label = random_crop_and_flip(ims, label, im_size) + ims, label = normalize_and_convert_to_tensor(ims, label, mean, std, temporal_size) + return ims, label + + +def crop_array( + arr: np.ndarray, left: int, top: int, right: int, bottom: int +) -> np.ndarray: + """Crop Numpy Image. 
+ + Crop a given array (image) using specified left, top, right, and bottom indices. + + This function supports cropping both grayscale (2D) and color (3D) images. + + Args: + arr (np.ndarray): The input array (image) to be cropped. + left (int): The left boundary index for cropping. + top (int): The top boundary index for cropping. + right (int): The right boundary index for cropping. + bottom (int): The bottom boundary index for cropping. + + Returns: + np.ndarray: The cropped portion of the input array (image). + + Raises: + ValueError: If the input array is not 2D or 3D. + """ + if len(arr.shape) == 2: # Grayscale image (2D array) + return arr[top:bottom, left:right] + elif len(arr.shape) == 3: # Color image (3D array) + return arr[:, top:bottom, left:right] + elif len(arr.shape) == 4: # Color image (3D array) + return arr[:, :, top:bottom, left:right] + else: + raise ValueError("Input array must be a 2D, 3D or 4D array") + + +def process_test( + x: np.ndarray, + y: np.ndarray, + mean: List[float], + std: List[float], + temporal_size: int = 1, + img_size: int = 512, + crop_size: int = 224, + stride: int = 224, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Process and augment test data. + + Args: + x (np.ndarray): Input image array. + y (np.ndarray): Corresponding mask array. + mean (List[float]): Mean values for normalization. + std: (List[float]): Standard deviation values for normalization. + temporal_size (int, optional): Temporal dimension size. Defaults to 1. + img_size (int, optional): Size of the input images. Defaults to + 512. + crop_size (int, optional): Size of the crops to be extracted from the + images. Defaults to 224. + stride (int, optional): Stride for cropping. Defaults to 224. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors containing the processed + images and masks. 
+ """ + preprocess_func = partial( + process_and_augment, + mean=mean, + std=std, + temporal_size=temporal_size, + train=False, + ) + + img_crops, mask_crops = [], [] + width, height = img_size, img_size + + for top in range(0, height - crop_size + 1, stride): + for left in range(0, width - crop_size + 1, stride): + bottom = top + crop_size + right = left + crop_size + + img_crops.append(crop_array(x, left, top, right, bottom)) + mask_crops.append(crop_array(y, left, top, right, bottom)) + + samples = [preprocess_func(x, y) for x, y in zip(img_crops, mask_crops)] + imgs = torch.stack([sample[0] for sample in samples]) + labels = torch.stack([sample[1] for sample in samples]) + return imgs, labels + + +def get_raster_data( + fname: str | dict[str, dict[str, str]], + is_label: bool = True, + bands: List[int] | None = None, + no_data_value: int | None = -9999, + mask_cloud: bool = True, +) -> np.ndarray: + """Load and process raster data from a file. + + Args: + fname (str): Filename to load data from. + is_label (bool): Whether the file is a label file. + bands (List[int]): Index of bands to select from array. + no_data_value (int | None): NODATA value in image raster. + mask_cloud (bool): Perform cloud masking. + + Returns: + np.ndarray: Numpy array representing the processed data. + """ + if isinstance(fname, dict): + data, _ = open_mf_tiff_dataset(fname, mask_cloud) + data = data.fillna(no_data_value) + data = data.band_data.values + else: + with rasterio.open(fname) as src: + data = src.read() + if (not is_label) and bands: + data = data[bands, ...] + # For some reasons, some few HLS tiles are not scaled in v2.0. 
+ # In the following lines, we find and scale them + bands = [] + for band in data: + if band.max() > 10: + band *= 0.0001 + bands.append(band) + data = np.stack(bands, axis=0) + return data + + +def process_data( + im_fname: str, + mask_fname: str | None = None, + no_data_value: int | None = -9999, + reduce_to_zero: bool = False, + replace_label: Tuple | None = None, + bands: List[int] | None = None, + constant_multiplier: float = 1.0, + mask_cloud: bool = False, +) -> Tuple[np.ndarray, np.ndarray]: + """Process image and mask data from filenames. + + Args: + im_fname (str): Filename for the image data. + mask_fname (str | None): Filename for the mask data. + bands (List[int]): Indices of bands to select from array. + no_data_value (int | None): NODATA value in image raster. + reduce_to_zero (bool): Reduces the label index to start from Zero. + replace_label (Tuple): Tuple of value to replace and the replacement value. + constant_multiplier (float): Constant multiplier for image. + mask_cloud (bool): Perform cloud masking. + + Returns: + Tuple[np.ndarray, np.ndarray]: A tuple of numpy arrays representing the processed + image and mask data. + """ + arr_x = get_raster_data( + im_fname, + is_label=False, + bands=bands, + no_data_value=no_data_value, + mask_cloud=mask_cloud, + ) + arr_x = arr_x * constant_multiplier + if mask_fname: + arr_y = get_raster_data(mask_fname) + if replace_label: + arr_y = np.where(arr_y == replace_label[0], replace_label[1], arr_y) + if reduce_to_zero: + arr_y -= 1 + else: + arr_y = None + return arr_x, arr_y + + +def load_data_from_csv(fname: str, input_root: str) -> List[Tuple[str, str]]: + """Load data file paths from a CSV file. + + Args: + fname (str): Filename of the CSV file. + input_root (str): Root directory for input images and labels. + + Returns: + List[Tuple[str, str]]: A list of tuples, each containing file paths for input + image and label image. 
+ """ + file_paths = [] + data = pd.read_csv(fname) + for _, row in data.iterrows(): + im_path = os.path.join(input_root, row["Input"]) + mask_path = os.path.join(input_root, row["Label"]) + if os.path.exists(im_path): + try: + with rasterio.open(im_path) as src: + _ = src.crs + file_paths.append((im_path, mask_path)) + except Exception as e: + logging.error(e) + continue + return file_paths + + +class InstaGeoDataset(torch.utils.data.Dataset): + """InstaGeo PyTorch Dataset for Loading and Handling HLS Data.""" + + def __init__( + self, + filename: str, + input_root: str, + preprocess_func: Callable, + no_data_value: int | None, + replace_label: Tuple, + reduce_to_zero: bool, + constant_multiplier: float, + bands: List[int] | None = None, + ): + """Dataset Class for loading and preprocessing the dataset. + + Args: + filename (str): Filename of the CSV file containing data paths. + input_root (str): Root directory for input images and labels. + preprocess_func (Callable): Function to preprocess the data. + bands (List[int]): Indices of bands to select from array. + no_data_value (int | None): NODATA value in image raster. + reduce_to_zero (bool): Reduces the label index to start from Zero. + replace_label (Tuple): Tuple of value to replace and the replacement value. + constant_multiplier (float): Constant multiplier for image. + + """ + self.input_root = input_root + self.preprocess_func = preprocess_func + self.bands = bands + self.file_paths = load_data_from_csv(filename, input_root) + self.no_data_value = no_data_value + self.replace_label = replace_label + self.reduce_to_zero = reduce_to_zero + self.constant_multiplier = constant_multiplier + + def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Retrieves a sample from dataset. + + Args: + i (int): Sample index to retrieve. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of tensors representing the + processed images and label. 
+ """ + im_fname, mask_fname = self.file_paths[i] + arr_x, arr_y = process_data( + im_fname, + mask_fname, + no_data_value=self.no_data_value, + replace_label=self.replace_label, + reduce_to_zero=self.reduce_to_zero, + bands=self.bands, + constant_multiplier=self.constant_multiplier, + ) + return self.preprocess_func(arr_x, arr_y) + + def __len__(self) -> int: + """Return length of dataset.""" + return len(self.file_paths) diff --git a/instageo/model/infer_utils.py b/instageo/model/infer_utils.py new file mode 100644 index 0000000..6673e07 --- /dev/null +++ b/instageo/model/infer_utils.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. +# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. 
+# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""Utils for Running Inference.""" + +import numpy as np +import pytorch_lightning as pl +import torch + +from instageo.model.dataloader import crop_array + + +def sliding_window_inference( + hls_tile: torch.Tensor, + model: pl.LightningModule, + window_size: tuple[int, int] = (224, 224), + stride: int = 224, + batch_size: int = 32, + device: str = "gpu", +) -> np.ndarray: + """Sliding Window Inference. + + Performs sliding window inference on large inputs using a given model with batching, + and reassemble the output to match the original image size. + + Args: + image_path: Path to the large image. + model: Trained model for inference. + window_size: Size of the window (default is 224x224). + stride: Step size for sliding the window (default is 224). + batch_size: Number of patches to process in one batch. + device: Device used for training. + + Returns: + Final prediction image of the same size as the original image. 
+ """ + device = "cuda" if device == "gpu" else device + _, _, width, height = hls_tile.shape + + final_prediction = np.zeros((height, width), dtype=np.float32) + + patch_coords = [] + current_batch = [] + + for y in range(0, height - window_size[1] + 1, stride): + for x in range(0, width - window_size[0] + 1, stride): + patch_array = crop_array( + hls_tile, x, y, x + window_size[0], y + window_size[1] + ) + patch_coords.append((x, y)) + current_batch.append(patch_array) + + if len(current_batch) == batch_size: + batch_array = torch.stack(current_batch, dim=0) + batch_results = ( + model.predict_step(batch_array.to(device)).detach().cpu().numpy() + ) + + for i, (px, py) in enumerate(patch_coords): + final_prediction[ + py : py + window_size[1], px : px + window_size[0] + ] = batch_results[i] + + current_batch = [] + patch_coords = [] + + if current_batch: + batch_array = torch.stack(current_batch, dim=0) + batch_results = ( + model.predict_step(batch_array.to(device)).detach().cpu().numpy() + ) + + for i, (px, py) in enumerate(patch_coords): + final_prediction[ + py : py + window_size[1], px : px + window_size[0] + ] = batch_results[i] + + return final_prediction diff --git a/instageo/model/model.py b/instageo/model/model.py new file mode 100644 index 0000000..473dc7e --- /dev/null +++ b/instageo/model/model.py @@ -0,0 +1,228 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. +# +# You are free to: +# - Share: Copy and redistribute the material in any medium or format +# - Adapt: Remix, transform, and build upon the material +# +# Under the following terms: +# - Attribution: You must give appropriate credit, provide a link to the license, +# and indicate if changes were made. You may do so in any reasonable manner, +# but not in any way that suggests the licensor endorses you or your use. 
+# - NonCommercial: You may not use the material for commercial purposes. +# - ShareAlike: If you remix, transform, or build upon the material, you must +# distribute your contributions under the same license as the original. +# +# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/ +# ------------------------------------------------------------------------------ + +"""Model Module.""" + +import os +import time +from pathlib import Path + +import numpy as np +import requests # type: ignore +import torch +import torch.nn as nn +import yaml # type: ignore +from absl import logging + +from instageo.model.Prithvi import ViTEncoder + + +def download_file(url: str, filename: str | Path, retries: int = 3) -> None: + """Downloads a file from the given URL and saves it to a local file. + + Args: + url (str): The URL from which to download the file. + filename (str): The local path where the file will be saved. + retries (int, optional): The number of times to retry the download + in case of failure. Defaults to 3. + + Raises: + Exception: If the download fails after the specified number of retries. + + Returns: + None + """ + if os.path.exists(filename): + logging.info(f"File '{filename}' already exists. Skipping download.") + return + + for attempt in range(retries): + try: + response = requests.get(url) + if response.status_code == 200: + with open(filename, "wb") as f: + f.write(response.content) + logging.info(f"Download successful on attempt {attempt + 1}") + break + else: + logging.warning( + f"Attempt {attempt + 1} failed with status code {response.status_code}" # noqa + ) + except requests.RequestException as e: + logging.warning(f"Attempt {attempt + 1} failed with error: {e}") + + if attempt < retries - 1: + time.sleep(2) + + else: + raise Exception("Failed to download the file after several attempts.") + + +class Norm2D(nn.Module): + """A normalization layer for 2D inputs.
+ + This class implements a 2D normalization layer using Layer Normalization. + It is designed to normalize 2D inputs (e.g., images or feature maps in a + convolutional neural network). + + Attributes: + ln (nn.LayerNorm): The layer normalization component. + + Args: + embed_dim (int): The number of features of the input tensor (i.e., the number of + channels in the case of images). + + Methods: + forward: Applies normalization to the input tensor. + """ + + def __init__(self, embed_dim: int): + """Initializes the Norm2D module. + + Args: + embed_dim (int): The number of features of the input tensor. + """ + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=1e-6) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the normalization process to the input tensor. + + Args: + x (torch.Tensor): A 4D input tensor with shape + (batch_size, channels, height, width). + + Returns: + torch.Tensor: The normalized tensor, having the same shape as the input. + """ + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +class PrithviSeg(nn.Module): + """Prithvi Segmentation Model.""" + + def __init__( + self, temporal_step: int = 1, num_classes: int = 2, freeze_backbone: bool = True + ) -> None: + """Initialize the PrithviSeg model. + + This model is designed for image segmentation tasks on remote sensing data. + It loads Prithvi configuration and weights and sets up a ViTEncoder backbone + along with a segmentation head. + + Args: + temporal_step (int): Size of temporal dimension. + num_classes (int): Number of target classes. + freeze_backbone (bool): Flag to freeze ViT transformer backbone weights. 
+ """ + super().__init__() + weights_dir = Path.home() / ".instageo" / "prithvi" + weights_dir.mkdir(parents=True, exist_ok=True) + weights_path = weights_dir / "Prithvi_100M.pt" + cfg_path = weights_dir / "Prithvi_100M_config.yaml" + download_file( + "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt?download=true", # noqa + weights_path, + ) + download_file( + "https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/raw/main/Prithvi_100M_config.yaml", # noqa + cfg_path, + ) + checkpoint = torch.load(weights_path, map_location="cpu") + with open(cfg_path) as f: + model_config = yaml.safe_load(f) + + model_args = model_config["model_args"] + + model_args["num_frames"] = temporal_step + self.model_args = model_args + # instantiate model + model = ViTEncoder(**model_args) + if freeze_backbone: + for param in model.parameters(): + param.requires_grad = False + del checkpoint["pos_embed"] + _ = model.load_state_dict(checkpoint, strict=False) + + self.prithvi_100M_backbone = model + + def upscaling_block(in_channels: int, out_channels: int) -> nn.Module: + """Upscaling block. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + + Returns: + An upscaling block configured to upscale spatially. 
+ """ + return nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + embed_dims = [ + (model_args["embed_dim"] * model_args["num_frames"]) // (2**i) + for i in range(5) + ] + self.segmentation_head = nn.Sequential( + *[upscaling_block(embed_dims[i], embed_dims[i + 1]) for i in range(4)], + nn.Conv2d( + kernel_size=1, in_channels=embed_dims[-1], out_channels=num_classes + ), + ) + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """Define the forward pass of the model. + + Args: + img (torch.Tensor): The input tensor representing the image. + + Returns: + torch.Tensor: Output tensor after image segmentation. + """ + features = self.prithvi_100M_backbone(img) + # drop cls token + reshaped_features = features[:, 1:, :] + feature_img_side_length = int( + np.sqrt(reshaped_features.shape[1] // self.model_args["num_frames"]) + ) + reshaped_features = reshaped_features.permute(0, 2, 1).reshape( + features.shape[0], -1, feature_img_side_length, feature_img_side_length + ) + + out = self.segmentation_head(reshaped_features) + return out diff --git a/instageo/model/requirements.txt b/instageo/model/requirements.txt new file mode 100644 index 0000000..55a82ab --- /dev/null +++ b/instageo/model/requirements.txt @@ -0,0 +1,8 @@ +pytorch_lightning +torch +timm==0.4.12 +einops +tensorboard +omegaconf +hydra-core +torchvision diff --git a/instageo/model/run.py b/instageo/model/run.py new file mode 100644 index 0000000..839f5be --- /dev/null +++ b/instageo/model/run.py @@ -0,0 +1,630 @@ +# ------------------------------------------------------------------------------ +# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0 +# International (CC BY-NC-SA 4.0) License. 
#
# You are free to:
# - Share: Copy and redistribute the material in any medium or format
# - Adapt: Remix, transform, and build upon the material
#
# Under the following terms:
# - Attribution: You must give appropriate credit, provide a link to the license,
#   and indicate if changes were made. You may do so in any reasonable manner,
#   but not in any way that suggests the licensor endorses you or your use.
# - NonCommercial: You may not use the material for commercial purposes.
# - ShareAlike: If you remix, transform, or build upon the material, you must
#   distribute your contributions under the same license as the original.
#
# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/
# ------------------------------------------------------------------------------

"""Run Module Containing Training, Evaluation and Inference Logic."""

import json
import logging
import os
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import hydra
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from instageo.model.dataloader import (
    InstaGeoDataset,
    process_and_augment,
    process_data,
    process_test,
)
from instageo.model.infer_utils import sliding_window_inference
from instageo.model.model import PrithviSeg

# Fix all RNG seeds and force deterministic cuDNN kernels so runs are reproducible.
pl.seed_everything(seed=1042, workers=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)


def check_required_flags(required_flags: List[str], config: DictConfig) -> None:
    """Check if required flags are provided.

    Args:
        required_flags: A list of required command line arguments.
        config: Hydra config holding the parsed command line arguments.

    Raises:
        RuntimeError: If at least one of the required arguments is not set.
    """
    for flag_name in required_flags:
        # Unset hydra flags arrive as the literal string "None", not Python None.
        if getattr(config, flag_name) == "None":
            raise RuntimeError(f"Flag --{flag_name} is required.")


def get_device() -> str:
    """Select the best available accelerator string for the PL Trainer.

    Returns:
        One of "tpu", "gpu" or "cpu", in that order of preference.
    """
    try:
        # Importing torch_xla succeeds only in a TPU-enabled environment.
        import torch_xla.core.xla_model as xm  # noqa: F401

        device = "tpu"
        log.info("TPU is available. Using TPU...")
    except ImportError:
        if torch.cuda.is_available():
            device = "gpu"
            log.info("GPU is available. Using GPU...")
        else:
            device = "cpu"
            log.info("Neither GPU nor TPU is available. Using CPU...")
    return device


def custom_collate_fn(
    batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Test DataLoader collate function.

    Each dataset item already contains a stack of sliding-window crops, so the
    per-item tensors are concatenated (not stacked) along the first dimension.

    Args:
        batch: A list of (features, labels) tensor tuples.

    Returns:
        Tuple of (x, y) concatenated into separate tensors.
    """
    data = torch.cat([a[0] for a in batch], 0)
    labels = torch.cat([a[1] for a in batch], 0)
    return data, labels


def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = False,
    num_workers: int = 1,
    collate_fn: Optional[Callable] = None,
    pin_memory: bool = True,
) -> DataLoader:
    """Create a DataLoader for the given dataset.

    This function is a convenient wrapper around the PyTorch DataLoader class,
    allowing for easy setup of various DataLoader parameters.

    Args:
        dataset (Dataset): The dataset to load data from.
        batch_size (int): How many samples per batch to load.
        shuffle (bool): Set to True to have the data reshuffled at every epoch.
        num_workers (int): How many subprocesses to use for data loading.
        collate_fn (Optional[Callable]): Merges a list of samples to form a mini-batch.
        pin_memory (bool): If True, the data loader will copy tensors into CUDA pinned
            memory.

    Returns:
        DataLoader: An instance of the PyTorch DataLoader.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
    )


class PrithviSegmentationModule(pl.LightningModule):
    """Prithvi Segmentation PyTorch Lightning Module."""

    def __init__(
        self,
        learning_rate: float = 1e-4,
        freeze_backbone: bool = True,
        num_classes: int = 2,
        temporal_step: int = 1,
        class_weights: Optional[List[float]] = None,
        ignore_index: int = -100,
        weight_decay: float = 1e-2,
    ) -> None:
        """Initialization.

        Initialize the PrithviSegmentationModule, a PyTorch Lightning module for image
        segmentation.

        Args:
            num_classes (int): Number of classes for segmentation.
            temporal_step (int): Number of temporal steps for multi-temporal input.
            learning_rate (float): Learning rate for the optimizer.
            freeze_backbone (bool): Flag to freeze ViT transformer backbone weights.
            class_weights (Optional[List[float]]): Class weights for mitigating class
                imbalance. Defaults to [1, 2] when not provided.
            ignore_index (int): Class index to ignore during loss computation.
            weight_decay (float): Weight decay for L2 regularization.
        """
        super().__init__()
        # Avoid a mutable default argument; None stands in for the historical
        # default of [1, 2].
        if class_weights is None:
            class_weights = [1, 2]
        self.net = PrithviSeg(
            num_classes=num_classes,
            temporal_step=temporal_step,
            freeze_backbone=freeze_backbone,
        )
        weight_tensor = torch.tensor(class_weights).float() if class_weights else None
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=ignore_index, weight=weight_tensor
        )
        self.learning_rate = learning_rate
        self.ignore_index = ignore_index
        self.weight_decay = weight_decay

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Define the forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor for the model.

        Returns:
            torch.Tensor: Output tensor from the model.
        """
        return self.net(x)

    def _shared_step(self, batch: Any, stage: str) -> torch.Tensor:
        """Run the forward pass, loss and metric logging common to all stages.

        Args:
            batch (Any): Input batch data as a (inputs, labels) pair.
            stage (str): One of "train", "val" or "test"; used as a metric prefix.

        Returns:
            torch.Tensor: The loss value for the batch.
        """
        inputs, labels = batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, labels.long())
        self.log_metrics(outputs, labels, stage, loss)
        return loss

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Perform a training step.

        Args:
            batch (Any): Input batch data.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: The loss value for the batch.
        """
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Perform a validation step.

        Args:
            batch (Any): Input batch data.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: The loss value for the batch.
        """
        return self._shared_step(batch, "val")

    def test_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Perform a test step.

        Args:
            batch (Any): Input batch data.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: The loss value for the batch.
        """
        return self._shared_step(batch, "test")

    def predict_step(
        self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> torch.Tensor:
        """Perform a prediction step.

        Lightning invokes this hook as predict_step(batch, batch_idx, dataloader_idx),
        so both index parameters must be accepted (they are unused here).

        Args:
            batch (Any): Input batch data.
            batch_idx (int): Index of the batch (unused).
            dataloader_idx (int): Index of the dataloader (unused).

        Returns:
            torch.Tensor: Per-pixel probability of the positive class (index 1).
        """
        prediction = self.forward(batch)
        probabilities = torch.nn.functional.softmax(prediction, dim=1)[:, 1, :, :]
        return probabilities

    def configure_optimizers(
        self,
    ) -> Tuple[
        List[torch.optim.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]
    ]:
        """Configure the model's optimizers and learning rate schedulers.

        Returns:
            Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler]]:
                A tuple containing the list of optimizers and the list of LR schedulers.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=0
        )
        return [optimizer], [scheduler]

    def log_metrics(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        stage: str,
        loss: torch.Tensor,
    ) -> None:
        """Log all metrics for any stage.

        Args:
            predictions(torch.Tensor): Prediction tensor from the model.
            labels(torch.Tensor): Label mask.
            stage (str): One of train, val and test stages.
            loss (torch.Tensor): Loss value.

        Returns:
            None.
        """
        out = self.compute_metrics(predictions, labels)
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_loss", loss, **log_kwargs)
        self.log(f"{stage}_aAcc", out["acc"], **log_kwargs)
        self.log(f"{stage}_mIoU", out["iou"], **log_kwargs)
        # Per-class metrics keep the historical log-key order: IoU, Acc,
        # Precision, Recall.
        per_class_metrics = {
            "IoU": out["iou_per_class"],
            "Acc": out["acc_per_class"],
            "Precision": out["precision_per_class"],
            "Recall": out["recall_per_class"],
        }
        for metric_name, values in per_class_metrics.items():
            for idx, value in enumerate(values):
                self.log(f"{stage}_{metric_name}_{idx}", value, **log_kwargs)

    def compute_metrics(
        self, pred_mask: torch.Tensor, gt_mask: torch.Tensor
    ) -> dict[str, List[float]]:
        """Calculate the Intersection over Union (IoU), Accuracy, Precision and Recall metrics.

        Args:
            pred_mask (torch.Tensor): Predicted segmentation logits of shape
                (batch, num_classes, H, W).
            gt_mask (torch.Tensor): Ground truth segmentation mask.

        Returns:
            dict: A dictionary containing 'iou', 'acc' (overall accuracy), and
                'acc_per_class', 'iou_per_class', 'precision_per_class' and
                'recall_per_class'.
        """
        pred_mask = torch.argmax(pred_mask, dim=1)
        # Drop pixels labelled with the ignore index before computing metrics.
        no_ignore = gt_mask.ne(self.ignore_index).to(self.device)
        pred_mask = pred_mask.masked_select(no_ignore).cpu().numpy()
        gt_mask = gt_mask.masked_select(no_ignore).cpu().numpy()
        # Only classes present in either mask are evaluated, so per-class lists
        # are indexed by order of appearance, not by class id.
        classes = np.unique(np.concatenate((gt_mask, pred_mask)))

        iou_per_class = []
        accuracy_per_class = []
        precision_per_class = []
        recall_per_class = []

        for clas in classes:
            pred_cls = pred_mask == clas
            gt_cls = gt_mask == clas

            intersection = np.logical_and(pred_cls, gt_cls)
            union = np.logical_or(pred_cls, gt_cls)
            true_positive = np.sum(intersection)
            false_positive = np.sum(pred_cls) - true_positive
            false_negative = np.sum(gt_cls) - true_positive

            # Each class in `classes` occurs in at least one mask, so the union
            # is non-empty; the guard is kept defensively.
            if np.any(union):
                iou = np.sum(intersection) / np.sum(union)
                iou_per_class.append(iou)

            accuracy = true_positive / np.sum(gt_cls) if np.sum(gt_cls) > 0 else 0
            accuracy_per_class.append(accuracy)

            precision = (
                true_positive / (true_positive + false_positive)
                if (true_positive + false_positive) > 0
                else 0
            )
            precision_per_class.append(precision)

            recall = (
                true_positive / (true_positive + false_negative)
                if (true_positive + false_negative) > 0
                else 0
            )
            recall_per_class.append(recall)

        # Overall IoU and accuracy; guard against a fully-ignored batch to
        # avoid a zero division.
        mean_iou = np.mean(iou_per_class) if iou_per_class else 0.0
        overall_accuracy = (
            np.sum(pred_mask == gt_mask) / gt_mask.size if gt_mask.size > 0 else 0.0
        )

        return {
            "iou": mean_iou,
            "acc": overall_accuracy,
            "acc_per_class": accuracy_per_class,
            "iou_per_class": iou_per_class,
            "precision_per_class": precision_per_class,
            "recall_per_class": recall_per_class,
        }


@hydra.main(config_path="configs", version_base=None, config_name="config")
def main(cfg: DictConfig) -> None:
    """Runner Entry Point.

    Performs training, evaluation or inference/prediction depending on the selected mode.

    Arguments:
        cfg (DictConfig): Dict-like object containing necessary values used to configure runner.

    Returns:
        None.
    """
    log.info(f"Script: {__file__}")
    log.info(f"Imported hydra config:\n{OmegaConf.to_yaml(cfg)}")

    BANDS = cfg.dataloader.bands
    MEAN = cfg.dataloader.mean
    STD = cfg.dataloader.std
    IM_SIZE = cfg.dataloader.img_size
    TEMPORAL_SIZE = cfg.dataloader.temporal_dim

    batch_size = cfg.train.batch_size
    root_dir = cfg.root_dir
    valid_filepath = cfg.valid_filepath
    train_filepath = cfg.train_filepath
    test_filepath = cfg.test_filepath
    checkpoint_path = cfg.checkpoint_path

    if cfg.mode == "train":
        check_required_flags(["root_dir", "train_filepath", "valid_filepath"], cfg)
        train_dataset = InstaGeoDataset(
            filename=train_filepath,
            input_root=root_dir,
            preprocess_func=partial(
                process_and_augment,
                mean=MEAN,
                std=STD,
                temporal_size=TEMPORAL_SIZE,
                im_size=IM_SIZE,
            ),
            bands=BANDS,
            replace_label=cfg.dataloader.replace_label,
            reduce_to_zero=cfg.dataloader.reduce_to_zero,
            no_data_value=cfg.dataloader.no_data_value,
            constant_multiplier=cfg.dataloader.constant_multiplier,
        )

        valid_dataset = InstaGeoDataset(
            filename=valid_filepath,
            input_root=root_dir,
            preprocess_func=partial(
                process_and_augment,
                mean=MEAN,
                std=STD,
                temporal_size=TEMPORAL_SIZE,
                im_size=IM_SIZE,
            ),
            bands=BANDS,
            replace_label=cfg.dataloader.replace_label,
            reduce_to_zero=cfg.dataloader.reduce_to_zero,
            no_data_value=cfg.dataloader.no_data_value,
            constant_multiplier=cfg.dataloader.constant_multiplier,
        )
        train_loader = create_dataloader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=1
        )
        valid_loader = create_dataloader(
            valid_dataset, batch_size=batch_size, shuffle=False, num_workers=1
        )
        model = PrithviSegmentationModule(
            learning_rate=cfg.train.learning_rate,
            freeze_backbone=cfg.model.freeze_backbone,
            num_classes=cfg.model.num_classes,
            temporal_step=cfg.dataloader.temporal_dim,
            class_weights=cfg.train.class_weights,
            ignore_index=cfg.train.ignore_index,
            weight_decay=cfg.train.weight_decay,
        )
        hydra_out_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
        # Keep the 3 checkpoints with the best validation mIoU.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_mIoU",
            dirpath=hydra_out_dir,
            filename="instageo_epoch-{epoch:02d}-val_iou-{val_mIoU:.2f}",
            auto_insert_metric_name=False,
            mode="max",
            save_top_k=3,
        )

        logger = TensorBoardLogger(hydra_out_dir, name="instageo")

        trainer = pl.Trainer(
            accelerator=get_device(),
            max_epochs=cfg.train.num_epochs,
            callbacks=[checkpoint_callback],
            logger=logger,
        )

        # run training and validation
        trainer.fit(model, train_loader, valid_loader)

    elif cfg.mode == "eval":
        check_required_flags(["root_dir", "test_filepath", "checkpoint_path"], cfg)
        test_dataset = InstaGeoDataset(
            filename=test_filepath,
            input_root=root_dir,
            preprocess_func=partial(
                process_test,
                mean=MEAN,
                std=STD,
                temporal_size=TEMPORAL_SIZE,
                img_size=cfg.test.img_size,
                crop_size=cfg.test.crop_size,
                stride=cfg.test.stride,
            ),
            bands=BANDS,
            replace_label=cfg.dataloader.replace_label,
            reduce_to_zero=cfg.dataloader.reduce_to_zero,
            no_data_value=cfg.dataloader.no_data_value,
            constant_multiplier=cfg.dataloader.constant_multiplier,
        )
        # Each test item yields a stack of crops, so a concatenating collate
        # function is required.
        test_loader = create_dataloader(
            test_dataset, batch_size=batch_size, collate_fn=custom_collate_fn
        )
        model = PrithviSegmentationModule.load_from_checkpoint(
            checkpoint_path,
            learning_rate=cfg.train.learning_rate,
            freeze_backbone=cfg.model.freeze_backbone,
            num_classes=cfg.model.num_classes,
            temporal_step=cfg.dataloader.temporal_dim,
            class_weights=cfg.train.class_weights,
            ignore_index=cfg.train.ignore_index,
            weight_decay=cfg.train.weight_decay,
        )
        trainer = pl.Trainer(accelerator=get_device())
        result = trainer.test(model, dataloaders=test_loader)
        log.info(f"Evaluation results:\n{result}")
    elif cfg.mode == "predict":
        model = PrithviSegmentationModule.load_from_checkpoint(
            cfg.checkpoint_path,
            learning_rate=cfg.train.learning_rate,
            freeze_backbone=cfg.model.freeze_backbone,
            num_classes=cfg.model.num_classes,
            temporal_step=cfg.dataloader.temporal_dim,
            class_weights=cfg.train.class_weights,
            ignore_index=cfg.train.ignore_index,
            weight_decay=cfg.train.weight_decay,
        )
        model.eval()
        infer_filepath = os.path.join(root_dir, cfg.test_filepath)
        assert (
            os.path.splitext(infer_filepath)[-1] == ".json"
        ), f"Test file path expects a json file but got {infer_filepath}"
        output_dir = os.path.join(root_dir, "predictions")
        os.makedirs(output_dir, exist_ok=True)
        with open(infer_filepath) as json_file:
            hls_dataset = json.load(json_file)
        for key, hls_tile_path in tqdm(
            hls_dataset.items(), desc="Processing HLS Dataset"
        ):
            try:
                hls_tile, _ = process_data(
                    hls_tile_path,
                    None,
                    bands=cfg.dataloader.bands,
                    no_data_value=cfg.dataloader.no_data_value,
                    constant_multiplier=cfg.dataloader.constant_multiplier,
                    mask_cloud=cfg.test.mask_cloud,
                    replace_label=cfg.dataloader.replace_label,
                    reduce_to_zero=cfg.dataloader.reduce_to_zero,
                )
            except rasterio.RasterioIOError:
                # Skip tiles whose rasters are missing or unreadable.
                continue
            # Pixels that are no-data in any band are masked out of the output.
            nan_mask = hls_tile == cfg.dataloader.no_data_value
            nan_mask = np.any(nan_mask, axis=0).astype(int)
            hls_tile, _ = process_and_augment(
                hls_tile,
                None,
                mean=cfg.dataloader.mean,
                std=cfg.dataloader.std,
                temporal_size=cfg.dataloader.temporal_dim,
                train=False,
            )
            prediction = sliding_window_inference(
                hls_tile,
                model,
                window_size=(cfg.test.img_size, cfg.test.img_size),
                stride=cfg.test.stride,
                batch_size=cfg.train.batch_size,
                device=get_device(),
            )
            prediction = np.where(nan_mask == 1, np.nan, prediction)
            prediction_filename = os.path.join(output_dir, f"{key}_prediction.tif")
            # Reuse the CRS and transform of one source band so the prediction
            # GeoTiff is georeferenced like the input tile.
            with rasterio.open(hls_tile_path["tiles"]["B02_0"]) as src:
                crs = src.crs
                transform = src.transform
            with rasterio.open(
                prediction_filename,
                "w",
                driver="GTiff",
                height=prediction.shape[0],
                width=prediction.shape[1],
                count=1,
                dtype=str(prediction.dtype),
                crs=crs,
                transform=transform,
            ) as dst:
                dst.write(prediction, 1)


if __name__ == "__main__":
    main()
+ "source": [ + "## EarthData Login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b7a5f61", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8b7a5f61", + "outputId": "1d9437c6-01f1-4680-caa6-f21302bc02f0" + }, + "outputs": [], + "source": [ + "# Create an EarthData user account on https://urs.earthdata.nasa.gov/\n", + "USERNAME = getpass('Enter your EarthDAata username: ')\n", + "PASSWORD = getpass('Enter your EarthDAata password: ')\n", + "\n", + "content = f\"\"\"machine urs.earthdata.nasa.gov login {USERNAME} password {PASSWORD}\"\"\"\n", + "\n", + "with open(os.path.expanduser('~/.netrc'), 'w') as file:\n", + " file.write(content)" + ] + }, + { + "cell_type": "markdown", + "id": "488a15d3-f22e-4b2c-85c0-435c832e708c", + "metadata": {}, + "source": [ + "## InstaGeo - Data\n", + "\n", + "- Searches and retrieves HLS metadata. \n", + "- Downloads the required bands from HLS granules\n", + "- Creates Chips and Target" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c11bfeb8-df6b-4b8c-8c8a-059a5a28b8ea", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c11bfeb8-df6b-4b8c-8c8a-059a5a28b8ea", + "outputId": "2636c9ce-912a-4c36-f9c3-6fb7d15f8beb", + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash\n", + "mkdir demo\n", + "python -m \"instageo.data.chip_creator\" \\\n", + " --dataframe_path=\"InstaGeo-E2E-Geospatial-ML/tests/data/test_breeding_data.csv\" \\\n", + " --output_directory=\"demo\" \\\n", + " --min_count=4 \\\n", + " --chip_size=224 \\\n", + " --no_data_value=-1 \\\n", + " --temporal_tolerance=3 \\\n", + " --temporal_step=30 \\\n", + " --mask_cloud=False \\\n", + " --num_steps=3" + ] + }, + { + "cell_type": "markdown", + "id": "bcb132b0-04dd-4583-8e30-e6332446a0e6", + "metadata": {}, + "source": [ + "## InstaGeo - Model" + ] + }, + { + "cell_type": "markdown", + "id": "fd313044-829c-482d-934e-9ae662f132fc", + "metadata": {}, + 
"source": [ + "Create a dataframe for created chips and target that will be used with model component" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c64fee29-dca3-4611-a43a-4412a6033751", + "metadata": { + "id": "c64fee29-dca3-4611-a43a-4412a6033751" + }, + "outputs": [], + "source": [ + "import os\n", + "import os\n", + "import pandas as pd\n", + "import numpy as np\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b11cb3a-1940-4056-a6a7-33baa375bddb", + "metadata": { + "id": "3b11cb3a-1940-4056-a6a7-33baa375bddb" + }, + "outputs": [], + "source": [ + "root_dir = Path.cwd()\n", + "chips_orig = os.listdir(os.path.join(root_dir, \"demo/chips\"))\n", + "chips = [chip.replace(\"chip\", \"chips/chip\") for chip in chips_orig]\n", + "seg_maps = [chip.replace(\"chip\", \"seg_maps/seg_map\") for chip in chips_orig]\n", + "\n", + "df = pd.DataFrame({\"Input\": chips, \"Label\": seg_maps})\n", + "df.to_csv(os.path.join(root_dir, \"demo/demo_train.csv\"))\n", + "df.to_csv(os.path.join(root_dir, \"demo/demo_val.csv\"))" + ] + }, + { + "cell_type": "markdown", + "id": "6778e394-43ef-42ee-ae24-e00122e86d10", + "metadata": {}, + "source": [ + "Launch training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dc76c1-26b9-47b9-8066-5c1df499b40f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "47dc76c1-26b9-47b9-8066-5c1df499b40f", + "jupyter": { + "outputs_hidden": true + }, + "outputId": "72e8f6e4-0263-431a-c313-47df5b0cbded", + "tags": [] + }, + "outputs": [], + "source": [ + "# %%bash\n", + "!python -m instageo.model.run --config-name=locust \\\n", + " root_dir='demo' \\\n", + " train.batch_size=8 \\\n", + " train.num_epochs=3 \\\n", + " train_filepath=\"demo/demo_train.csv\" \\\n", + " valid_filepath=\"demo/demo_val.csv\"" + ] + }, + { + "cell_type": "markdown", + "id": "a8a98a64-0962-40e5-a52e-1e8687198862", + 
"metadata": { + "tags": [] + }, + "source": [ + "Run Model Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8230588-2842-460c-b2e7-4f2f4175dc3e", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "f8230588-2842-460c-b2e7-4f2f4175dc3e", + "jupyter": { + "outputs_hidden": true + }, + "outputId": "94515cc4-972a-4e15-d3fe-1dc28200f848", + "tags": [] + }, + "outputs": [], + "source": [ + "# %%bash\n", + "!python -m instageo.model.run --config-name=locust \\\n", + " root_dir='demo' \\\n", + " test_filepath=\"demo/demo_val.csv\" \\\n", + " train.batch_size=8 \\\n", + " checkpoint_path='outputs/2024-09-12/12-17-46/instageo_epoch-02-val_iou-0.00.ckpt' \\\n", + " mode=eval" + ] + }, + { + "cell_type": "markdown", + "id": "b9a73d79-f69f-4fb8-aa63-768ee3ee47ea", + "metadata": {}, + "source": [ + "**Run Inference**\n", + "\n", + "- Runs inference on the Africa continent (plus some part of Asia)\n", + "- Downloads and run inference on 2,120 HLS tiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "YgxyC5GnvjJU", + "metadata": { + "collapsed": true, + "id": "YgxyC5GnvjJU", + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!gsutil -m cp -r gs://instageo/checkpoints/locust_breeding_ckpt .\n", + "!gsutil cp gs://instageo/utils/africa_prediction_template.csv .\n", + "!mkdir -p inference/2023-01" + ] + }, + { + "cell_type": "markdown", + "id": "9c30afd0-73ec-4eb1-b9c1-a2d36475eae7", + "metadata": {}, + "source": [ + "**Create Inference Data**\n", + "\n", + "During inference we only download HLS tiles and irectly runinference on them using the sliding window inference feature." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f76b3ea-8ef7-473b-8f2e-72d493fdba03", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "8f76b3ea-8ef7-473b-8f2e-72d493fdba03", + "jupyter": { + "outputs_hidden": true + }, + "outputId": "f623f2a0-77b2-46c8-aa5a-586452c84c9f", + "tags": [] + }, + "outputs": [], + "source": [ + "!python -m \"instageo.data.chip_creator\" \\\n", + " --dataframe_path=\"africa_prediction_template.csv\" \\\n", + " --output_directory=\"inference/2023-01\" \\\n", + " --min_count=1 \\\n", + " --no_data_value=-1 \\\n", + " --temporal_tolerance=3 \\\n", + " --temporal_step=30 \\\n", + " --num_steps=3 \\\n", + " --download_only" + ] + }, + { + "cell_type": "markdown", + "id": "65e86341-0717-458f-b061-3e957ce8536d", + "metadata": {}, + "source": [ + "**Run Inference**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "JJXq8oWNAr1w", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JJXq8oWNAr1w", + "outputId": "74f43e30-f642-4eb2-ca6f-64f9343a275c", + "tags": [] + }, + "outputs": [], + "source": [ + "!python -m instageo.model.run --config-name=locust \\\n", + " root_dir='inference/2023-01' \\\n", + " test_filepath='hls_dataset.json' \\\n", + " train.batch_size=16 \\\n", + " test.mask_cloud=True \\\n", + " checkpoint_path='locust_breeding_ckpt/instageo_epoch-02-val_iou-0.69.ckpt' \\\n", + " mode=predict" + ] + }, + { + "cell_type": "markdown", + "id": "8d79e412-8176-411a-8a73-6068e22a0d25", + "metadata": {}, + "source": [ + "## InstaGeo - Apps\n", + "\n", + "Visualize predictions my moving HLS prediction GeoTiff files to an appropriate directory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a904882-cde9-40c8-affb-f988892e5183", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!mkdir -p predictions/2023/1\n", + "!mv inference/2023-01/predictions/* content/predictions/2023/1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tQGnk67MY6Cd", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tQGnk67MY6Cd", + "outputId": "8b7a5935-7817-40f3-97d6-94a0ea2f69ef", + "tags": [] + }, + "outputs": [], + "source": [ + "!npm install localtunnel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "BeJsQzkeBNm7", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "BeJsQzkeBNm7", + "jupyter": { + "outputs_hidden": true + }, + "outputId": "d4488061-d390-4c9e-c10c-bdd4c0222f93", + "tags": [] + }, + "outputs": [], + "source": [ + "!nohup streamlit run InstaGeo/instageo/apps/app.py --server.address=localhost &" + ] + }, + { + "cell_type": "markdown", + "id": "48459844-9277-473b-8cd8-e47691c59c0a", + "metadata": {}, + "source": [ + "Retrieve your IP address which is the password of the localtunnel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "S5rCS-lVZiWe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S5rCS-lVZiWe", + "outputId": "cbd68bed-b165-4fd3-8c87-14a5dcad4dbc", + "tags": [] + }, + "outputs": [], + "source": [ + "import urllib\n", + "print(\"Password/Enpoint IP for localtunnel is:\",urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bB61RRBNY-48", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bB61RRBNY-48", + "outputId": "28bb4e8b-56c9-454c-89f2-47f576a5723b", + "tags": [] + }, + "outputs": [], + "source": [ + "!npx localtunnel --port 8501" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c44bc28-bc5d-43d8-aad8-1e69bace9b0f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "environment": { + "kernel": "python3", + "name": "tf2-cpu.2-11.m115", + "type": "gcloud", + "uri": "gcr.io/deeplearning-platform-release/tf2-cpu.2-11:m115" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..943def2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,34 @@ +[tool.isort] +profile = "black" + +[tool.mypy] +python_version = 3.10 +namespace_packages = true +incremental = false +cache_dir = "" +warn_redundant_casts = true +warn_return_any = true +warn_unused_configs = true +warn_unused_ignores = false +allow_redefinition = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = false +strict_optional = true +strict_equality = true +explicit_package_bases = true +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = [ + "numpy.*", + "pytest.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +markers = [ + "auth: Requires NASA EarthData Authentication" +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d4008cc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +geopandas==0.14.1 +shapely +pandas +xarray[complete] +rasterio +earthpy +cftime +h5pyd +Bottleneck +absl-py +black +flake8 +isort +jupyter +mypy 
+neptune-client +numpy +pre-commit +pytest +setuptools +streamlit +tqdm +pytest-cov +rioxarray diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..a4ea2f7 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,74 @@ +[flake8] +# Selected error codes for reporting +select = A,B,C,D,E,F,G,I,N,T,W + +# Directories and files to be excluded from checks +exclude = + .tox, + .git, + __pycache__, + build, + dist, + proto/*, + *.pyc, + *.egg-info, + .cache, + .eggs + +# Maximum allowed line length in the code +max-line-length = 100 + +# Maximum allowed cognitive complexity for functions +max-cognitive-complexity = 10 + +# Enable syntax checks in docstrings +doctests = True + +# Import order style to follow +import-order-style = google + +# Convention for docstrings +docstring-convention = google + +# File-specific ignore rules +per-file-ignores = + __init__.py:F401 + +# Specific error codes to ignore +ignore = + A002 + A003 + D107 + E266 + E731 + W503 + +[mypy] +# Mypy configuration settings +python_version = 3.10 +namespace_packages = True +incremental = True +cache_dir = nul +warn_redundant_casts = True +warn_return_any = False +warn_unused_configs = False +warn_unused_ignores = False +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = False +strict_optional = True +strict_equality = True +explicit_package_bases = True +ignore_missing_imports = True + +[tool:pytest] +# PyTest configuration settings +minversion = 6.0 +addopts = -ra -q +testpaths = + tests +python_files = test_*.py +python_classes = *Tests +python_functions = test_* diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..aa6f67b --- /dev/null +++ b/setup.py @@ -0,0 +1,61 @@ +"""InstaGeo Package Setup.""" + +from setuptools import find_packages, setup + +data_dependencies = [ + # Add dependencies specific to the data component + "geopandas==0.14.1", + "shapely", + "cftime", + "h5pyd", + 
"Bottleneck", + "absl-py", + "mgrs==1.4.6", + "earthaccess==0.8.2", +] +model_dependencies = [ + # Add dependencies specific to the model component + "pytorch_lightning", + "torch", + "timm==0.4.12", + "einops", + "tensorboard", + "hydra-core", + "omegaconf", +] +apps_dependencies = [ + # Add dependencies specific to the apps component + "plotly", + "datashader", + "matplotlib", + "streamlit==1.31.1", +] +setup( + name="instageo", + version="0.0.1", + packages=find_packages(), + install_requires=[ + # Add common dependencies here + "pandas", + "numpy", + "xarray[complete]", + "rasterio", + "rioxarray", + ], + extras_require={ + "data": data_dependencies, + "model": model_dependencies, + "apps": apps_dependencies, + "all": data_dependencies + model_dependencies + apps_dependencies, + }, + python_requires=">=3.10", + tests_require=["pytest", "pytest-cov"], + setup_requires=["pytest-runner"], + test_suite="tests", + author="Ibrahim Salihu Yusuf", + author_email="i.yusuf@instadeep.com", + description="""A modular Python package for geospatial data processing, modeling, + and applications using Harmonized Landsat Sentinel (HLS) Data""", + keywords="geospatial, machine learning, data, model, applications", + url="https://github.com/instadeepai/InstaGeo-E2E-Geospatial-ML", +) diff --git a/tests/apps_tests/test_viz.py b/tests/apps_tests/test_viz.py new file mode 100644 index 0000000..9c7f17c --- /dev/null +++ b/tests/apps_tests/test_viz.py @@ -0,0 +1,42 @@ +import numpy as np + +# import plotly.graph_objects as go +import xarray as xr +from PIL import Image + +from instageo.apps.viz import ( # create_map_with_geotiff_tiles, + add_raster_to_plotly_figure, + get_crs, + read_geotiff_to_xarray, +) + + +def test_read_geotiff_to_xarray(): + data, crs = read_geotiff_to_xarray("tests/data/sample.tif") + assert isinstance(data, xr.Dataset) + assert crs == 32613 + + +def test_get_crs(): + assert get_crs("tests/data/sample.tif") == 32613 + + +def test_get_image_and_coordinates(): + 
ds, crs = read_geotiff_to_xarray("tests/data/sample.tif") + img, coords = add_raster_to_plotly_figure(ds, crs) + assert isinstance(img, Image.Image) + np.testing.assert_almost_equal( + coords, + [ + [-107.16932831572083, 41.38046859355541], + [-107.08744863572721, 41.38046859355541], + [-107.08744863572721, 41.31873255032166], + [-107.16932831572083, 41.31873255032166], + ], + ) + + +# def test_create_map_with_geotiff_tiles(): +# fig = create_map_with_geotiff_tiles(["tests/data/sample.tif"]) +# assert isinstance(fig, go.Figure) +# assert len(fig.layout.mapbox.layers) == 1 diff --git a/tests/data/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif b/tests/data/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif new file mode 100644 index 0000000..923a248 Binary files /dev/null and b/tests/data/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif differ diff --git a/tests/data/fmask.tif b/tests/data/fmask.tif new file mode 100644 index 0000000..a666c26 Binary files /dev/null and b/tests/data/fmask.tif differ diff --git a/tests/data/sample.csv b/tests/data/sample.csv new file mode 100644 index 0000000..87bdfe9 --- /dev/null +++ b/tests/data/sample.csv @@ -0,0 +1,6 @@ +x,y,date +10,1,2023-01-01 +20,2,2023-01-02 +30,3,2023-01-03 +40,4,2023-01-04 +50,5,2023-01-05 diff --git a/tests/data/sample.tif b/tests/data/sample.tif new file mode 100644 index 0000000..923a248 Binary files /dev/null and b/tests/data/sample.tif differ diff --git a/tests/data/sample_4326.csv b/tests/data/sample_4326.csv new file mode 100644 index 0000000..7e0a71d --- /dev/null +++ b/tests/data/sample_4326.csv @@ -0,0 +1,11 @@ +,x,y,label +0,-107.11335902745832,41.37894599955397,1 +1,-107.14567939343958,41.32538721089502,1 +2,-107.10102772173182,41.341342943275286,1 +3,-107.15303320631186,41.320115889120856,1 +4,-107.12079158922131,41.37583693617166,1 +5,-107.09514728746764,41.33685665334261,1 +6,-107.11232871594486,41.32492353540072,1 +7,-107.15812818069698,41.32245229497166,1 +8,-107.10489272525344,41.32776190077167,1 
+9,-107.13281252845832,41.33751558131009,1 diff --git a/tests/data/test_breeding_data.csv b/tests/data/test_breeding_data.csv new file mode 100644 index 0000000..1607418 --- /dev/null +++ b/tests/data/test_breeding_data.csv @@ -0,0 +1,51 @@ +,index,date,x,y,label,CAT,generated,year +0,1424,2023-03-08,43.03915,15.315846,0,Ecology,0,2023 +1,1976,2023-04-05,73.42101,27.547216,0,Ecology,0,2023 +2,4211,2023-05-18,44.42283,17.62352,0,Ecology,0,2023 +3,2905,2023-03-04,-0.171886,28.246159,0,Ecology,0,2023 +4,72,2022-06-14,33.708817,17.1869,0,Ecology,0,2022 +5,4061,2023-05-11,43.026635,25.717982,0,Ecology,0,2023 +6,6955,2022-10-26,45.326618,9.726063,0,Ecology,0,2022 +7,8381,2022-10-21,43.281983,14.984089,0,Ecology,0,2022 +8,8276,2022-10-18,20.78685,15.217167,0,Ecology,0,2022 +9,4331,2023-04-24,41.32555,27.22835,0,Ecology,0,2023 +10,6885,2022-11-07,48.774927,15.969577,0,Ecology,0,2022 +11,8923,2023-05-10,42.759551,24.169729,1,Adult,0,2023 +12,2145,2023-01-09,37.190179,19.94838,0,Ecology,0,2023 +13,6755,2022-11-03,-0.108,27.565933,0,Ecology,0,2022 +14,6677,2022-08-09,72.14875,27.760517,0,Ecology,0,2022 +15,4151,2023-05-22,40.122413,24.432497,0,Ecology,0,2023 +16,7424,2022-12-03,49.111667,30.953056,0,Ecology,0,2022 +17,5700,2022-07-30,30.918502,17.974223,0,Ecology,0,2022 +18,4858,2022-08-19,51.94742,17.153655,0,Ecology,0,2022 +19,5518,2022-08-29,51.341308,16.71492,0,Ecology,0,2022 +20,5756,2022-08-04,44.665585,16.228966,0,Ecology,0,2022 +21,6841,2022-11-11,43.057165,15.715822,0,Ecology,0,2022 +22,4096,2023-05-13,43.145633,25.725376,0,Ecology,0,2023 +23,3127,2023-02-06,39.751026,21.710444,0,Ecology,0,2023 +24,4652,2022-08-24,71.96893,27.938168,0,Ecology,0,2022 +25,553,2022-07-19,69.6824,23.4292,0,Ecology,0,2022 +26,6654,2022-08-09,-12.302033,18.584383,0,Ecology,0,2022 +27,9309,2023-04-13,-7.846384,29.317383,1,Hopper,0,2023 +28,6425,2022-08-15,50.585011,17.431202,0,Ecology,0,2022 +29,4755,2022-08-26,50.539011,16.219579,0,Ecology,0,2022 
+30,8937,2023-02-19,37.326446,19.068729,1,Adult,0,2023 +31,2152,2023-01-09,43.248547,14.873787,0,Ecology,0,2023 +32,5714,2022-07-30,36.649834,14.774283,0,Ecology,0,2022 +33,7184,2022-12-18,43.17153,14.601784,0,Ecology,0,2022 +34,9083,2023-05-30,41.952972,25.364161,1,Hopper,0,2023 +35,2420,2022-12-29,45.344326,13.264151,0,Ecology,0,2022 +36,403,2022-05-12,28.615,22.693533,0,Ecology,0,2022 +37,1095,2023-03-17,43.185587,15.171837,0,Ecology,0,2023 +38,6706,2022-09-10,44.684444,15.346389,0,Ecology,0,2022 +39,3176,2023-02-12,34.72385,23.87805,0,Ecology,0,2023 +40,7766,2022-09-15,31.924797,16.258003,0,Ecology,0,2022 +41,6520,2022-08-06,44.722084,16.205366,0,Ecology,0,2022 +42,2460,2022-12-26,41.433401,18.602769,0,Ecology,0,2022 +43,7619,2022-12-12,-0.035867,28.249033,0,Ecology,0,2022 +44,9427,2023-03-27,38.739928,23.760671,1,Hopper,0,2023 +45,1314,2023-03-14,43.637681,26.091387,0,Ecology,0,2023 +46,826,2022-07-12,69.827385,23.400734,0,Ecology,0,2022 +47,8582,2022-09-23,44.77788,16.089483,0,Ecology,0,2022 +48,3006,2023-01-28,38.565735,17.998667,0,Ecology,0,2023 +49,8421,2022-10-21,43.2613,14.287534,0,Ecology,0,2022 diff --git a/tests/data_tests/test_chip_creator.py b/tests/data_tests/test_chip_creator.py new file mode 100644 index 0000000..235910d --- /dev/null +++ b/tests/data_tests/test_chip_creator.py @@ -0,0 +1,159 @@ +import os +import pathlib +import shutil + +import geopandas as gpd +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from absl import flags +from shapely.geometry import Point + +from instageo.data import chip_creator +from instageo.data.chip_creator import app, check_required_flags, get_chip_coords + +FLAGS = flags.FLAGS + +test_root = pathlib.Path(__file__).parent.resolve() + + +@pytest.fixture +def setup_and_teardown_output_dir(): + output_dir = "/tmp/csv_chip_creator" + os.makedirs(output_dir, exist_ok=True) + yield + try: + shutil.rmtree(os.path.join(output_dir, "chips")) + shutil.rmtree(os.path.join(output_dir, 
"seg_maps")) + os.remove(os.path.join(output_dir, "granules_to_download.csv")) + os.remove(os.path.join(output_dir, "hls_dataset.json")) + os.remove(os.path.join(output_dir, "hls_chips_dataset.csv")) + except FileNotFoundError: + pass + + +def test_get_chip_coords(): + df = pd.read_csv("tests/data/sample_4326.csv") + df = gpd.GeoDataFrame(df, geometry=[Point(xy) for xy in zip(df.x, df.y)]) + df.set_crs(epsg=4326, inplace=True) + df = df.to_crs(crs=32613) + + ds = xr.open_dataset("tests/data/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif") + chip_coords = get_chip_coords(df, ds, 64) + assert chip_coords == [ + (2, 0), + (0, 3), + (2, 2), + (0, 3), + (2, 0), + (3, 2), + (2, 3), + (0, 3), + (2, 3), + (1, 2), + ] + + +@pytest.mark.auth +def test_chip_creator(setup_and_teardown_output_dir): + output_directory = "/tmp/csv_chip_creator" + argv = [ + "chip_creator", + "--dataframe_path", + os.path.join(os.path.dirname(test_root), "data/test_breeding_data.csv"), + "--output_directory", + output_directory, + "--min_count", + "4", + "--chip_size", + "512", + "--no_data_value", + "-1", + "--temporal_tolerance", + "1", + "--temporal_step", + "30", + "--num_steps", + "1", + ] + FLAGS(argv) + chip_creator.main("None") + chips = os.listdir(os.path.join(output_directory, "chips")) + seg_maps = os.listdir(os.path.join(output_directory, "seg_maps")) + assert len(chips) == len(seg_maps) + assert len(chips) == 3 + chip_path = os.path.join(output_directory, "chips", chips[0]) + seg_map_path = os.path.join(output_directory, "seg_maps", seg_maps[0]) + chip = xr.open_dataset(chip_path) + seg_map = xr.open_dataset(seg_map_path) + assert chip.band_data.shape == (6, 512, 512) + assert np.unique(chip.band_data).size > 1 + assert seg_map.band_data.shape == (1, 512, 512) + assert np.unique(seg_map.band_data).size > 1 + assert ( + len( + pd.read_csv(os.path.join(output_directory, "granules_to_download.csv"))[ + "tiles" + ].tolist() + ) + == 21 + ) + + +@pytest.mark.auth +def 
test_chip_creator_download_only(setup_and_teardown_output_dir): + output_directory = "/tmp/csv_chip_creator" + argv = [ + "chip_creator", + "--dataframe_path", + os.path.join(os.path.dirname(test_root), "data/test_breeding_data.csv"), + "--output_directory", + output_directory, + "--min_count", + "4", + "--chip_size", + "512", + "--no_data_value", + "-1", + "--temporal_tolerance", + "1", + "--temporal_step", + "30", + "--num_steps", + "1", + "--download_only", + ] + FLAGS(argv) + chip_creator.main("None") + assert os.path.exists(os.path.join(output_directory, "hls_dataset.json")) + assert os.path.exists(os.path.join(output_directory, "granules_to_download.csv")) + assert not os.path.exists(os.path.join(output_directory, "chips")) + assert not os.path.exists(os.path.join(output_directory, "seg_maps")) + + +def test_missing_flags_raises_error(): + """Test missing flags.""" + FLAGS([__file__]) + with pytest.raises(app.UsageError) as excinfo: + check_required_flags() + assert "Flag --dataframe_path is required" in str( + excinfo.value + ) or "Flag --output_directory is required" in str( + excinfo.value + ), "Expected UsageError with a message about missing required flags" + + +def test_no_missing_flags(): + """Test correct flags.""" + FLAGS( + [ + __file__, + "--dataframe_path=/path/to/dataframe", + "--output_directory=/path/to/output", + ] + ) + try: + check_required_flags() + except app.UsageError: + pytest.fail("UsageError was raised even though no flags were missing.") diff --git a/tests/data_tests/test_create_chips.py b/tests/data_tests/test_create_chips.py new file mode 100644 index 0000000..ea0ff7c --- /dev/null +++ b/tests/data_tests/test_create_chips.py @@ -0,0 +1,60 @@ +import os +import shutil + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from instageo.data.chip_creator import create_and_save_chips_with_seg_maps + + +@pytest.fixture +def setup_and_teardown_output_dir(): + output_dir = "/tmp/output" + 
os.makedirs(output_dir, exist_ok=True) + yield + shutil.rmtree(output_dir) + + +def test_create_chips(setup_and_teardown_output_dir): + geotiff_path = "tests/data/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif" + fmask_path = "tests/data/fmask.tif" + chip_size = 64 + output_directory = "/tmp/output" + no_data_value = -1 + df = pd.read_csv("tests/data/sample_4326.csv") + df["date"] = pd.to_datetime("2020-01-01") + os.makedirs(os.path.join(output_directory, "chips"), exist_ok=True) + os.makedirs(os.path.join(output_directory, "seg_maps"), exist_ok=True) + chips, labels = create_and_save_chips_with_seg_maps( + { + "tiles": {"B02_0": geotiff_path, "B04_0": geotiff_path}, + "fmasks": {"Fmask_0": fmask_path}, + }, + df, + chip_size, + output_directory, + no_data_value, + src_crs=4326, + mask_cloud=False, + ) + num_chips = len(chips) + + assert num_chips == 3 + for i in range(num_chips): + chip_path = os.path.join( + output_directory, "chips", "chip_20200101_S30_T38PMB_2022145T072619_1_2.tif" + ) + seg_map_path = os.path.join( + output_directory, + "seg_maps", + "seg_map_20200101_S30_T38PMB_2022145T072619_1_2.tif", + ) + chip = xr.open_dataset(chip_path) + seg_map = xr.open_dataset(seg_map_path) + + assert chip.band_data.shape == (2, 64, 64) + assert np.unique(chip.band_data).size > 1 + assert seg_map.band_data.shape == (1, 64, 64) + assert np.unique(seg_map.band_data).size > 1 diff --git a/tests/data_tests/test_geo_utils.py b/tests/data_tests/test_geo_utils.py new file mode 100644 index 0000000..c79509d --- /dev/null +++ b/tests/data_tests/test_geo_utils.py @@ -0,0 +1,67 @@ +import pytest +import xarray as xr +from rasterio.crs import CRS + +from instageo.data.geo_utils import decode_fmask_value, open_mf_tiff_dataset + + +def test_open_mf_tiff_dataset(): + band_files = { + "tiles": { + "band1": "tests/data/sample.tif", + "band2": "tests/data/sample.tif", + }, + "fmasks": { + "band1": "tests/data/fmask.tif", + "band2": "tests/data/fmask.tif", + }, + } + + result, crs = 
open_mf_tiff_dataset(band_files, mask_cloud=False) + assert isinstance(result, xr.Dataset) + assert isinstance(crs, CRS) + assert crs == 32613 + assert result["band_data"].shape == (2, 224, 224) + + +def test_open_mf_tiff_dataset_cloud_mask(): + band_files = { + "tiles": { + "band1": "tests/data/sample.tif", + "band2": "tests/data/sample.tif", + }, + "fmasks": { + "band1": "tests/data/fmask.tif", + "band2": "tests/data/fmask.tif", + }, + } + result_no_mask, crs = open_mf_tiff_dataset(band_files, mask_cloud=False) + num_points = result_no_mask.band_data.count().values.item() + result_with_mask, crs = open_mf_tiff_dataset(band_files, mask_cloud=True) + fmask = xr.open_dataset("tests/data/fmask.tif") + cloud_mask = decode_fmask_value(fmask, 1) + num_clouds = cloud_mask.where(cloud_mask == 1).band_data.count().values.item() + assert ( + result_with_mask.band_data.count().values.item() == num_points - 2 * num_clouds + ) + assert isinstance(result_with_mask, xr.Dataset) + assert isinstance(crs, CRS) + assert crs == 32613 + assert result_with_mask["band_data"].shape == (2, 224, 224) + + +@pytest.mark.parametrize( + "value, position, result", + [ + (100, 0, 0), + (100, 1, 0), + (100, 2, 1), + (100, 3, 0), + (100, 4, 0), + (100, 5, 1), + (100, 6, 1), + (100, 7, 0), + ], +) +def test_decode_fmask_value(value, position, result): + assert decode_fmask_value(value, position) == result diff --git a/tests/data_tests/test_hls_utils.py b/tests/data_tests/test_hls_utils.py new file mode 100644 index 0000000..5d867fb --- /dev/null +++ b/tests/data_tests/test_hls_utils.py @@ -0,0 +1,401 @@ +import datetime +import os +import shutil + +import pandas as pd +import pytest +import rasterio +from rasterio.crs import CRS + +from instageo.data.chip_creator import create_hls_dataset +from instageo.data.hls_utils import ( + add_hls_granules, + find_closest_tile, + get_hls_tile_info, + get_hls_tiles, + parallel_download, + parse_date_from_entry, + retrieve_hls_metadata, +) + + +@pytest.fixture 
+def setup_and_teardown_output_dir(): + output_dir = "/tmp/test_hls" + os.makedirs(output_dir, exist_ok=True) + yield + shutil.rmtree(output_dir) + + +@pytest.fixture +def observation_data(): + data = pd.DataFrame( + { + "date": { + 0: "2022-06-08", + 1: "2022-06-08", + 2: "2022-06-08", + 3: "2022-06-08", + 4: "2022-06-09", + 5: "2022-06-09", + 6: "2022-06-09", + 7: "2022-06-08", + 8: "2022-06-09", + 9: "2022-06-09", + }, + "x": { + 0: 44.48, + 1: 44.48865, + 2: 46.437787, + 3: 49.095545, + 4: -0.1305, + 5: 44.6216, + 6: 49.398908, + 7: 44.451435, + 8: 49.435228, + 9: 44.744167, + }, + "y": { + 0: 15.115617, + 1: 15.099767, + 2: 14.714659, + 3: 16.066929, + 4: 28.028967, + 5: 16.16195, + 6: 16.139727, + 7: 15.209633, + 8: 16.151837, + 9: 15.287778, + }, + "year": { + 0: 2022, + 1: 2022, + 2: 2022, + 3: 2022, + 4: 2022, + 5: 2022, + 6: 2022, + 7: 2022, + 8: 2022, + 9: 2022, + }, + } + ) + data["date"] = pd.to_datetime(data["date"]) + data["input_features_date"] = data["date"] + return data + + +def test_get_hls_tiles(observation_data): + hls_tiles = get_hls_tiles(data=observation_data, min_count=1) + assert list(hls_tiles["mgrs_tile_id"]) == [ + "38PMB", + "38PMB", + "38PPB", + "39QTT", + "30RYS", + "38QMC", + "39QUT", + "38PMB", + "39QUT", + "38PMB", + ] + + +def test_get_hls_tile_info(observation_data): + hls_tiles = get_hls_tiles(observation_data, min_count=3) + tiles_info, tile_queries = get_hls_tile_info( + hls_tiles, num_steps=3, temporal_step=5 + ) + pd.testing.assert_frame_equal( + tiles_info, + pd.DataFrame( + { + "tile_id": ["38PMB"], + "min_date": ["2022-05-29"], + "max_date": ["2022-06-09"], + "lon_min": [44.451435], + "lon_max": [44.744167], + "lat_min": [15.099767], + "lat_max": [15.287778], + } + ), + check_like=True, + ) + assert tile_queries == [ + ("38PMB", ["2022-06-08", "2022-06-03", "2022-05-29"]), + ("38PMB", ["2022-06-08", "2022-06-03", "2022-05-29"]), + ("38PMB", ["2022-06-08", "2022-06-03", "2022-05-29"]), + ("38PMB", ["2022-06-09", 
"2022-06-04", "2022-05-30"]), + ] + + +def test_retrieve_hls_metadata(): + tile_info = pd.DataFrame( + { + "tile_id": ["38PMB"], + "min_date": ["2022-05-29"], + "max_date": ["2022-06-09"], + "lon_min": [44.451435], + "lon_max": [44.744167], + "lat_min": [15.099767], + "lat_max": [15.287778], + } + ) + tile_database = retrieve_hls_metadata(tile_info) + assert tile_database == { + "38PMB": [ + "HLS.S30.T38PMB.2022150T072621.v2.0", + "HLS.S30.T38PMB.2022152T071619.v2.0", + "HLS.L30.T38PMB.2022154T072604.v2.0", + "HLS.L30.T38PMB.2022155T071923.v2.0", + "HLS.S30.T38PMB.2022155T072619.v2.0", + "HLS.S30.T38PMB.2022157T071631.v2.0", + "HLS.S30.T38PMB.2022160T072621.v2.0", + ] + } + + +def test_add_hls_granules(observation_data): + data = get_hls_tiles(observation_data, min_count=3) + result = add_hls_granules(data) + assert list(result["hls_tiles"]) == [ + [ + "HLS.L30.T38PMB.2022154T072604.v2.0", + "HLS.S30.T38PMB.2022145T072619.v2.0", + "HLS.L30.T38PMB.2022139T071922.v2.0", + ], + [ + "HLS.L30.T38PMB.2022154T072604.v2.0", + "HLS.S30.T38PMB.2022145T072619.v2.0", + "HLS.L30.T38PMB.2022139T071922.v2.0", + ], + [ + "HLS.L30.T38PMB.2022154T072604.v2.0", + "HLS.S30.T38PMB.2022145T072619.v2.0", + "HLS.L30.T38PMB.2022139T071922.v2.0", + ], + [ + "HLS.L30.T38PMB.2022155T071923.v2.0", + "HLS.S30.T38PMB.2022145T072619.v2.0", + "HLS.L30.T38PMB.2022139T071922.v2.0", + ], + ] + + +def test_find_closest_tile(): + tile_database = { + "30RYS": [ + "HLS.L30.T30RYS.2022140T102748.v2.0.", + "HLS.S30.T30RYS.2022142T103629.v2.0.", + "HLS.S30.T30RYS.2022144T102611.v2.0.", + "HLS.L30.T30RYS.2022147T103357.v2.0.", + "HLS.S30.T30RYS.2022147T103631.v2.0.", + "HLS.L30.T30RYS.2022148T102722.v2.0.", + "HLS.S30.T30RYS.2022149T102559.v2.0.", + "HLS.S30.T30RYS.2022152T103629.v2.0.", + "HLS.S30.T30RYS.2022154T102611.v2.0.", + "HLS.L30.T30RYS.2022155T103334.v2.0.", + "HLS.L30.T30RYS.2022156T102755.v2.0.", + "HLS.S30.T30RYS.2022157T103631.v2.0.", + "HLS.S30.T30RYS.2022159T102559.v2.0.", + ], + "38PMB": [ + 
"HLS.L30.T38PMB.2022139T071922.v2.0.", + "HLS.S30.T38PMB.2022140T072621.v2.0.", + "HLS.S30.T38PMB.2022142T071619.v2.0.", + "HLS.S30.T38PMB.2022145T072619.v2.0.", + "HLS.L30.T38PMB.2022146T072532.v2.0.", + "HLS.L30.T38PMB.2022147T071947.v2.0.", + "HLS.S30.T38PMB.2022147T071621.v2.0.", + "HLS.S30.T38PMB.2022150T072621.v2.0.", + "HLS.S30.T38PMB.2022152T071619.v2.0.", + "HLS.L30.T38PMB.2022154T072604.v2.0.", + "HLS.L30.T38PMB.2022155T071923.v2.0.", + "HLS.S30.T38PMB.2022155T072619.v2.0.", + "HLS.S30.T38PMB.2022157T071631.v2.0.", + "HLS.S30.T38PMB.2022160T072621.v2.0.", + "HLS.L30.T38PMB.2022162T072536.v2.0.", + "HLS.S30.T38PMB.2022162T071619.v2.0.", + ], + "38PNB": [ + "HLS.S30.T38PNB.2022142T071619.v2.0.", + "HLS.L30.T38PNB.2022147T071947.v2.0.", + "HLS.S30.T38PNB.2022147T071621.v2.0.", + "HLS.L30.T38PNB.2022148T071311.v2.0.", + "HLS.S30.T38PNB.2022152T071619.v2.0.", + "HLS.L30.T38PNB.2022155T071923.v2.0.", + "HLS.L30.T38PNB.2022156T071344.v2.0.", + "HLS.S30.T38PNB.2022157T071631.v2.0.", + ], + "38PPB": [ + "HLS.L30.T38PPB.2022139T071922.v2.0.", + "HLS.L30.T38PPB.2022140T071337.v2.0.", + "HLS.S30.T38PPB.2022142T071619.v2.0.", + "HLS.L30.T38PPB.2022147T071947.v2.0.", + "HLS.S30.T38PPB.2022147T071621.v2.0.", + "HLS.L30.T38PPB.2022148T071311.v2.0.", + "HLS.S30.T38PPB.2022152T071619.v2.0.", + "HLS.L30.T38PPB.2022155T071923.v2.0.", + "HLS.L30.T38PPB.2022156T071344.v2.0.", + "HLS.S30.T38PPB.2022157T071631.v2.0.", + ], + "38QMC": [], + "38QMD": [ + "HLS.S30.T38QMD.2022142T071619.v2.0.", + "HLS.S30.T38QMD.2022145T072619.v2.0.", + "HLS.L30.T38QMD.2022146T072508.v2.0.", + "HLS.L30.T38QMD.2022147T071923.v2.0.", + "HLS.S30.T38QMD.2022147T071621.v2.0.", + "HLS.S30.T38QMD.2022150T072621.v2.0.", + "HLS.S30.T38QMD.2022152T071619.v2.0.", + "HLS.L30.T38QMD.2022154T072540.v2.0.", + "HLS.L30.T38QMD.2022155T071859.v2.0.", + "HLS.S30.T38QMD.2022155T072619.v2.0.", + "HLS.S30.T38QMD.2022157T071631.v2.0.", + "HLS.S30.T38QMD.2022160T072621.v2.0.", + "HLS.L30.T38QMD.2022162T072512.v2.0.", + 
"HLS.S30.T38QMD.2022162T071619.v2.0.", + ], + "39QTT": [], + "39QUT": [], + } + tile_queries = [ + ("38PMB", ["2022-06-08", "2022-05-29", "2022-05-19"]), + ("38PPB", ["2022-06-08", "2022-05-29", "2022-05-19"]), + ("39QTT", ["2022-06-08", "2022-05-29", "2022-05-19"]), + ("30RYS", ["2022-06-09", "2022-05-30", "2022-05-20"]), + ("38QMC", ["2022-06-09", "2022-05-30", "2022-05-20"]), + ("39QUT", ["2022-06-09", "2022-05-30", "2022-05-20"]), + ("38PMB", ["2022-06-09", "2022-05-30", "2022-05-20"]), + ("38PNB", ["2022-06-10", "2022-05-31", "2022-05-21"]), + ("39QTT", ["2022-06-10", "2022-05-31", "2022-05-21"]), + ("38QMD", ["2022-06-11", "2022-06-01", "2022-05-22"]), + ("38PMB", ["2022-06-11", "2022-06-01", "2022-05-22"]), + ] + + tile_queries_str = [ + f"{tile_id}_{'_'.join(dates)}" for tile_id, dates in tile_queries + ] + tile_queries = {k: v for k, v in zip(tile_queries_str, tile_queries)} + query_result = find_closest_tile(tile_queries, tile_database) + assert list(query_result["hls_tiles"]) == [ + [ + "HLS.L30.T38PMB.2022154T072604.v2.0.", + "HLS.S30.T38PMB.2022145T072619.v2.0.", + "HLS.L30.T38PMB.2022139T071922.v2.0.", + ], + [ + "HLS.L30.T38PPB.2022155T071923.v2.0.", + "HLS.L30.T38PPB.2022147T071947.v2.0.", + "HLS.L30.T38PPB.2022139T071922.v2.0.", + ], + [None, None, None], + [ + "HLS.L30.T30RYS.2022155T103334.v2.0.", + "HLS.L30.T30RYS.2022147T103357.v2.0.", + "HLS.L30.T30RYS.2022140T102748.v2.0.", + ], + [None, None, None], + [None, None, None], + [ + "HLS.L30.T38PMB.2022155T071923.v2.0.", + "HLS.S30.T38PMB.2022145T072619.v2.0.", + "HLS.L30.T38PMB.2022139T071922.v2.0.", + ], + [ + "HLS.L30.T38PNB.2022156T071344.v2.0.", + "HLS.L30.T38PNB.2022147T071947.v2.0.", + "HLS.S30.T38PNB.2022142T071619.v2.0.", + ], + [None, None, None], + [ + "HLS.S30.T38QMD.2022157T071631.v2.0.", + "HLS.L30.T38QMD.2022147T071923.v2.0.", + "HLS.S30.T38QMD.2022142T071619.v2.0.", + ], + [ + "HLS.S30.T38PMB.2022157T071631.v2.0.", + "HLS.L30.T38PMB.2022147T071947.v2.0.", + 
"HLS.L30.T38PMB.2022139T071922.v2.0.", + ], + ] + + +def test_create_hls_dataset(observation_data): + data = get_hls_tiles(observation_data, min_count=3) + data_with_tiles = add_hls_granules( + data, num_steps=3, temporal_step=10, temporal_tolerance=5 + ) + hls_dataset, tiles_to_download = create_hls_dataset(data_with_tiles, outdir="") + assert len(tiles_to_download) == 28 + assert len(hls_dataset) == 2 + assert len(hls_dataset["2022-06-08_T38PMB"]["tiles"]) == 18 + assert len(hls_dataset["2022-06-08_T38PMB"]["fmasks"]) == 3 + assert ( + hls_dataset["2022-06-08_T38PMB"]["tiles"]["B02_0"] + == "hls_tiles/HLS.L30.T38PMB.2022154T072604.v2.0.B02.tif" + ) + assert ( + hls_dataset["2022-06-08_T38PMB"]["fmasks"]["Fmask_0"] + == "hls_tiles/HLS.L30.T38PMB.2022154T072604.v2.0.Fmask.tif" + ) + assert ( + hls_dataset["2022-06-08_T38PMB"]["tiles"]["B02_1"] + == "hls_tiles/HLS.S30.T38PMB.2022145T072619.v2.0.B02.tif" + ) + assert ( + hls_dataset["2022-06-08_T38PMB"]["fmasks"]["Fmask_1"] + == "hls_tiles/HLS.S30.T38PMB.2022145T072619.v2.0.Fmask.tif" + ) + assert ( + hls_dataset["2022-06-08_T38PMB"]["tiles"]["B02_2"] + == "hls_tiles/HLS.L30.T38PMB.2022139T071922.v2.0.B02.tif" + ) + assert ( + hls_dataset["2022-06-08_T38PMB"]["fmasks"]["Fmask_2"] + == "hls_tiles/HLS.L30.T38PMB.2022139T071922.v2.0.Fmask.tif" + ) + + +@pytest.mark.auth +def test_download_hls_tile(): + urls = [ + "https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSL30.020/HLS.L30.T38PMB.2022139T071922.v2.0/HLS.L30.T38PMB.2022139T071922.v2.0.B01.tif" # noqa + ] + parallel_download(urls, outdir="/tmp") + out_filename = "/tmp/HLS.L30.T38PMB.2022139T071922.v2.0.B01.tif" # noqa + assert os.path.exists(out_filename) + src = rasterio.open(out_filename) + assert isinstance(src.crs, CRS) + + +@pytest.mark.auth +def test_download_hls_tile_with_retry(setup_and_teardown_output_dir): + outdir = "/tmp/test_hls" + open( + os.path.join(outdir, "HLS.L30.T38PMB.2022139T071922.v2.0.B02.tif"), "w" + ).close() + urls = { + 
"https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSL30.020/HLS.L30.T38PMB.2022139T071922.v2.0/HLS.L30.T38PMB.2022139T071922.v2.0.B03.tif", # noqa + "https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/HLSL30.020/HLS.L30.T38PMB.2022139T071922.v2.0/HLS.L30.T38PMB.2022139T071922.v2.0.B02.tif", # noqa + } + parallel_download(urls, outdir=outdir) + out_filename = os.path.join(outdir, "HLS.L30.T38PMB.2022139T071922.v2.0.B02.tif") + assert os.path.exists(out_filename) + src = rasterio.open(out_filename) + assert isinstance(src.crs, CRS) + + +@pytest.mark.parametrize( + "tile_id, result", + [ + ("HLS.L30.T38PMB.2022139T071922.v2.0.", "2022-05-19"), + ("HLS.L30.T38PMB.202213T071922.v2.0.", None), + ], +) +def test_parse_date_from_entry(tile_id, result): + parsed_date = parse_date_from_entry(tile_id) + if isinstance(parsed_date, datetime.datetime): + parsed_date = parsed_date.strftime("%Y-%m-%d") + assert parsed_date == result diff --git a/tests/model_tests/test_dataloader.py b/tests/model_tests/test_dataloader.py new file mode 100644 index 0000000..6a63eb4 --- /dev/null +++ b/tests/model_tests/test_dataloader.py @@ -0,0 +1,211 @@ +from functools import partial + +import numpy as np +import pandas as pd +import pytest +import torch +from PIL import Image + +from instageo.model.dataloader import ( + InstaGeoDataset, + crop_array, + get_raster_data, + load_data_from_csv, + normalize_and_convert_to_tensor, + process_and_augment, + process_data, + process_test, + random_crop_and_flip, +) + + +def test_random_crop_and_flip(): + # Create dummy images and label + ims = [Image.new("L", (256, 256)) for _ in range(3)] + label = Image.new("L", (256, 256)) + + # Apply function + transformed_ims, transformed_label = random_crop_and_flip(ims, label, im_size=224) + + # Check output types and dimensions + assert isinstance(transformed_ims, list) + assert isinstance(transformed_label, Image.Image) + assert all(isinstance(im, Image.Image) for im in transformed_ims) + 
assert all(im.size == (224, 224) for im in transformed_ims) + assert transformed_label.size == (224, 224) + + +def test_normalize_and_convert_to_tensor(): + # Create dummy images and label + ims = [Image.new("L", (224, 224)) for _ in range(3)] + label = Image.new("L", (224, 224)) + + tensor_ims, tensor_label = normalize_and_convert_to_tensor( + ims, label, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0] + ) + + assert isinstance(tensor_ims, torch.Tensor) + assert isinstance(tensor_label, torch.Tensor) + assert tensor_ims.shape == torch.Size([3, 1, 224, 224]) + assert tensor_label.shape == torch.Size([224, 224]) + + +def test_process_and_augment(): + x = np.random.rand(6, 256, 256) + y = np.random.rand(256, 256) + processed_ims, processed_label = process_and_augment( + x, y, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], temporal_size=2, im_size=224 + ) + + assert processed_ims.shape == torch.Size([3, 2, 224, 224]) + assert processed_label.shape == torch.Size([224, 224]) + + +def test_get_raster_data_with_str(): + test_fname = "tests/data/sample.tif" + result = get_raster_data(test_fname, is_label=False, bands=[0]) + assert isinstance(result, np.ndarray) + + +def test_get_raster_data_with_dict(): + band_files = { + "tiles": { + "band1": "tests/data/sample.tif", + "band2": "tests/data/sample.tif", + }, + "fmasks": { + "band1": "tests/data/fmask.tif", + "band2": "tests/data/fmask.tif", + }, + } + result = get_raster_data(band_files, mask_cloud=True) + assert isinstance(result, np.ndarray) + + +def test_process_data(): + im_test_fname = "tests/data/sample.tif" + mask_test_fname = "tests/data/sample.tif" + + arr_x, arr_y = process_data( + im_test_fname, + mask_test_fname, + no_data_value=-1, + replace_label=(-1, 0), + reduce_to_zero=True, + ) + + assert isinstance(arr_x, np.ndarray) + assert isinstance(arr_y, np.ndarray) + assert arr_x.shape[:-2] == arr_y.shape[:-2] + + +def test_process_data_without_label(): + im_test_fname = "tests/data/sample.tif" + + arr_x, arr_y = process_data( + 
im_test_fname, no_data_value=-1, replace_label=(-1, 0), reduce_to_zero=True + ) + + assert isinstance(arr_x, np.ndarray) + assert arr_y is None + + +def test_crop_2d_array(): + arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + cropped = crop_array(arr, 1, 1, 3, 3) + expected = np.array([[5, 6], [8, 9]]) + assert np.array_equal(cropped, expected) + + +def test_crop_3d_array(): + arr = np.zeros((3, 3, 3)) # Example 3D array + arr[:, 1, 1] = 1 # Modify the array for testing + cropped = crop_array(arr, 1, 1, 3, 3) + expected = np.zeros((3, 2, 2)) + expected[:, 0, 0] = 1 + assert np.array_equal(cropped, expected) + + +def test_invalid_dimensions(): + arr = np.array([1, 2, 3]) # 1D array, not valid + with pytest.raises(ValueError): + crop_array(arr, 0, 0, 1, 1) + + +def test_boundary_conditions(): + arr = np.array([[1, 2], [3, 4]]) + cropped = crop_array(arr, 0, 0, 2, 2) + assert np.array_equal(cropped, arr) + + +def test_non_integer_indices(): + arr = np.array([[1, 2], [3, 4]]) + with pytest.raises(TypeError): + crop_array(arr, 0.5, 0.5, 1.5, 1.5) + + +def test_output_types_and_shapes(): + x = np.random.rand(3, 512, 512) + y = np.random.rand(512, 512) + mean = [0.5, 0.5, 0.5] + std = [0.1, 0.1, 0.1] + imgs, labels = process_test(x, y, mean, std) + assert isinstance(imgs, torch.Tensor) + assert isinstance(labels, torch.Tensor) + assert imgs.shape == torch.Size([4, 3, 1, 224, 224]) + assert labels.shape == torch.Size([4, 224, 224]) + + +def test_invalid_inputs(): + x = np.random.rand(3, 512, 512) + y = np.random.rand(512, 512) + mean = [0.5, 0.5, "invalid"] + std = [0.1, 0.1, 0.1] + with pytest.raises(Exception): + process_test(x, y, mean, std) + + +def test_load_data_from_csv(): + sample_filename = "/tmp/sample_data.csv" + pd.DataFrame( + { + "Input": ["sample.tif", "sample.tif"], + "Label": ["sample.tif", "sample.tif"], + } + ).to_csv(sample_filename) + data = load_data_from_csv(sample_filename, input_root="tests/data") + assert data == [ + ("tests/data/sample.tif", 
"tests/data/sample.tif"), + ("tests/data/sample.tif", "tests/data/sample.tif"), + ] + + +def test_instageo_dataset(): + sample_filename = "/tmp/sample_data.csv" + pd.DataFrame( + { + "Input": ["sample.tif", "sample.tif"], + "Label": ["sample.tif", "sample.tif"], + } + ).to_csv(sample_filename) + dataset = InstaGeoDataset( + filename=sample_filename, + input_root="tests/data", + preprocess_func=partial( + process_and_augment, + mean=[0.0], + std=[1.0], + temporal_size=1, + im_size=224, + ), + bands=[0], + replace_label=None, + reduce_to_zero=False, + no_data_value=-1, + constant_multiplier=0.001, + ) + im, label = next(iter(dataset)) + assert isinstance(im, torch.Tensor) + assert isinstance(label, torch.Tensor) + assert im.shape == torch.Size([1, 1, 224, 224]) + assert label.shape == torch.Size([224, 224]) diff --git a/tests/model_tests/test_model.py b/tests/model_tests/test_model.py new file mode 100644 index 0000000..2d29a0d --- /dev/null +++ b/tests/model_tests/test_model.py @@ -0,0 +1,24 @@ +from unittest.mock import patch + +from instageo.model.model import download_file + + +def test_download_file_success(tmp_path): + test_filename = tmp_path / "test_file.txt" + + with patch("instageo.model.model.requests.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.content = b"dummy content" + download_file("http://example.com", str(test_filename)) + + assert test_filename.exists() + + +def test_download_file_already_exists(tmp_path): + test_filename = tmp_path / "test_file.txt" + + test_filename.write_text("dummy content") + + with patch("instageo.model.model.requests.get") as mock_get: + download_file("http://example.com", str(test_filename)) + mock_get.assert_not_called() diff --git a/tests/model_tests/test_sliding_window_inference.py b/tests/model_tests/test_sliding_window_inference.py new file mode 100644 index 0000000..22e43ec --- /dev/null +++ b/tests/model_tests/test_sliding_window_inference.py @@ -0,0 +1,26 @@ +import numpy as 
np +import pytest +import torch + +from instageo.model.infer_utils import sliding_window_inference + + +@pytest.fixture +def model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def predict_step(self, input): + h, w = input.shape[-2:] + return torch.ones([h, w]) + + return Model() + + +def test_sliding_window_inference(model): + hls_tile = torch.zeros((2, 18, 672, 672)) + prediction = sliding_window_inference( + hls_tile, model, window_size=(224, 224), stride=224, batch_size=2, device="cpu" + ) + assert np.unique(prediction) == 1