From 5a614a1a2ceb162a2abc9e5a83101dc93e0e6cbc Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sat, 27 Jul 2024 12:33:04 -0400 Subject: [PATCH] Add image segmentation module --- .github/dependabot.yml | 19 + .github/workflows/docs-build.yml | 6 +- .github/workflows/docs.yml | 8 +- .github/workflows/macos.yml | 42 +-- .github/workflows/pypi.yml | 4 +- .github/workflows/ubuntu.yml | 5 +- .github/workflows/windows.yml | 2 +- .pre-commit-config.yaml | 29 ++ docs/examples/dataviz/lidar_viz.ipynb | 53 ++- docs/examples/dataviz/raster_viz.ipynb | 66 +++- docs/examples/dataviz/vector_viz.ipynb | 160 ++++++--- .../rastervision/semantic_segmentation.ipynb | 116 +++--- docs/examples/samgeo/arcgis.ipynb | 6 +- .../samgeo/automatic_mask_generator_hq.ipynb | 12 +- docs/examples/samgeo/box_prompts.ipynb | 7 +- docs/examples/samgeo/fast_sam.ipynb | 4 +- docs/examples/samgeo/input_prompts.ipynb | 7 +- docs/examples/samgeo/input_prompts_hq.ipynb | 3 +- docs/examples/samgeo/maxar_open_data.ipynb | 22 +- docs/examples/samgeo/swimming_pools.ipynb | 19 +- docs/examples/samgeo/text_prompts.ipynb | 19 +- docs/examples/samgeo/text_prompts_batch.ipynb | 11 +- docs/geoai.md | 2 +- geoai/__init__.py | 4 +- geoai/common.py | 6 +- geoai/segmentation.py | 330 ++++++++++++++++++ requirements.txt | 6 +- requirements_dev.txt | 12 +- requirements_docs.txt | 22 +- setup.py | 42 +-- 30 files changed, 779 insertions(+), 265 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 .pre-commit-config.yaml create mode 100644 geoai/segmentation.py diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..0372e10 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,19 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "docker" + directory: "/" + schedule: + interval: "weekly" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 6d24027..af43f7c 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -8,10 +8,10 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install GDAL @@ -27,8 +27,6 @@ jobs: pip install --no-cache-dir Cython pip install -r requirements.txt -r requirements_dev.txt -r requirements_docs.txt pip install . - - name: Discover typos with codespell - run: codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,*.pdf,./.git" --ignore-words-list="aci,acount,acounts,fallow,hart,hist,nd,ned,ois,wqs" - name: PKG-TEST run: | python -m unittest discover tests/ diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7262e05..46824ca 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -7,8 +7,8 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: 3.9 - name: Install dependencies @@ -17,10 +17,6 @@ jobs: pip install --user --no-cache-dir Cython pip install --user -r requirements.txt pip install . 
- - name: Discover typos with codespell - run: | - pip install codespell - codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,./.git" --ignore-words-list="aci,acount,acounts,fallow,hart,hist,nd,ned,ois,wqs,watermask" - name: PKG-TEST run: | python -m unittest discover tests/ diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 3bb390c..f904bcc 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -14,25 +14,25 @@ jobs: strategy: fail-fast: false matrix: - os: ["macOS-latest"] - python-version: ["3.10"] + os: ["macOS-latest"] + python-version: ["3.10"] steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version}} - - name: Install GDAL - run: | - brew install gdal - - name: Test GDAL installation - run: | - gdalinfo --version - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --no-cache-dir Cython - pip install -r requirements.txt - pip install . + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version}} + - name: Install GDAL + run: | + brew install gdal + - name: Test GDAL installation + run: | + gdalinfo --version + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install --no-cache-dir Cython + pip install -r requirements.txt + pip install . 
diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 2bbb0c0..ae45065 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -12,9 +12,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install dependencies diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index dd78cdd..2282a3a 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -21,9 +21,9 @@ jobs: - { os: ubuntu-latest, py: "3.11" } steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.config.py }} - name: Install GDAL @@ -43,4 +43,3 @@ jobs: - name: PKG-TEST run: | python -m unittest discover tests/ - diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 84c0b02..c8c11ea 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -11,7 +11,7 @@ jobs: test-windows: runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install miniconda uses: conda-incubator/setup-miniconda@v2 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..5be88b7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + types: [python] + - id: trailing-whitespace + - id: requirements-txt-fixer + - id: check-added-large-files + args: ["--maxkb=500"] + + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black-jupyter + language_version: python3.11 + + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 
+ hooks: + - id: codespell + args: ["--ignore-words-list=gis,timeseries", "--skip=*.json,*.csv"] + + - repo: https://github.com/kynan/nbstripout + rev: 0.7.1 + hooks: + - id: nbstripout diff --git a/docs/examples/dataviz/lidar_viz.ipynb b/docs/examples/dataviz/lidar_viz.ipynb index 58a8f80..6ef85d6 100644 --- a/docs/examples/dataviz/lidar_viz.ipynb +++ b/docs/examples/dataviz/lidar_viz.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "0", "metadata": {}, "source": [ "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/geoai/blob/main/docs/examples/dataviz/lidar_viz.ipynb)\n", @@ -20,6 +21,7 @@ { "cell_type": "code", "execution_count": null, + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -28,6 +30,7 @@ }, { "cell_type": "markdown", + "id": "2", "metadata": {}, "source": [ "## Import libraries" @@ -35,7 +38,8 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -44,6 +48,7 @@ }, { "cell_type": "markdown", + "id": "4", "metadata": {}, "source": [ "## Download data" @@ -51,6 +56,7 @@ }, { "cell_type": "markdown", + "id": "5", "metadata": {}, "source": [ "Download a [sample LiDAR dataset](https://drive.google.com/file/d/1H_X1190vL63BoFYa_cVBDxtIa8rG-Usb/view?usp=sharing) from Google Drive. The zip file is 52.1 MB and the uncompressed LAS file is 109 MB." 
@@ -58,25 +64,28 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "id": "6", "metadata": {}, "outputs": [], "source": [ - "url = 'https://open.gishub.org/data/lidar/madison.zip'\n", - "filename = 'madison.las'" + "url = \"https://open.gishub.org/data/lidar/madison.zip\"\n", + "filename = \"madison.las\"" ] }, { "cell_type": "code", "execution_count": null, + "id": "7", "metadata": {}, "outputs": [], "source": [ - "leafmap.download_file(url, 'madison.zip', unzip=True)" + "leafmap.download_file(url, \"madison.zip\", unzip=True)" ] }, { "cell_type": "markdown", + "id": "8", "metadata": {}, "source": [ "## Metadata" @@ -84,6 +93,7 @@ }, { "cell_type": "markdown", + "id": "9", "metadata": {}, "source": [ "Read the LiDAR data" @@ -92,6 +102,7 @@ { "cell_type": "code", "execution_count": null, + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -100,6 +111,7 @@ }, { "cell_type": "markdown", + "id": "11", "metadata": {}, "source": [ "The LAS header." @@ -108,6 +120,7 @@ { "cell_type": "code", "execution_count": null, + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -116,6 +129,7 @@ }, { "cell_type": "markdown", + "id": "13", "metadata": {}, "source": [ "The number of points." @@ -124,6 +138,7 @@ { "cell_type": "code", "execution_count": null, + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -132,6 +147,7 @@ }, { "cell_type": "markdown", + "id": "15", "metadata": {}, "source": [ "The list of features." @@ -140,6 +156,7 @@ { "cell_type": "code", "execution_count": null, + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -148,6 +165,7 @@ }, { "cell_type": "markdown", + "id": "17", "metadata": {}, "source": [ "## Read data" @@ -155,6 +173,7 @@ }, { "cell_type": "markdown", + "id": "18", "metadata": {}, "source": [ "Inspect data." 
@@ -163,6 +182,7 @@ { "cell_type": "code", "execution_count": null, + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -172,6 +192,7 @@ { "cell_type": "code", "execution_count": null, + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -181,6 +202,7 @@ { "cell_type": "code", "execution_count": null, + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -190,6 +212,7 @@ { "cell_type": "code", "execution_count": null, + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -198,6 +221,7 @@ }, { "cell_type": "markdown", + "id": "23", "metadata": {}, "source": [ "## PyVista\n", @@ -208,14 +232,16 @@ { "cell_type": "code", "execution_count": null, + "id": "24", "metadata": {}, "outputs": [], "source": [ - "leafmap.view_lidar(filename, cmap='terrain', backend='pyvista')" + "leafmap.view_lidar(filename, cmap=\"terrain\", backend=\"pyvista\")" ] }, { "cell_type": "markdown", + "id": "25", "metadata": {}, "source": [ "![](https://i.imgur.com/xezcgMP.gif)" @@ -223,6 +249,7 @@ }, { "cell_type": "markdown", + "id": "26", "metadata": {}, "source": [ "## ipygany\n", @@ -233,14 +260,16 @@ { "cell_type": "code", "execution_count": null, + "id": "27", "metadata": {}, "outputs": [], "source": [ - "leafmap.view_lidar(filename, backend='ipygany', background='white')" + "leafmap.view_lidar(filename, backend=\"ipygany\", background=\"white\")" ] }, { "cell_type": "markdown", + "id": "28", "metadata": {}, "source": [ "![](https://i.imgur.com/MyMWW4I.gif)" @@ -248,6 +277,7 @@ }, { "cell_type": "markdown", + "id": "29", "metadata": {}, "source": [ "## Panel\n", @@ -258,14 +288,16 @@ { "cell_type": "code", "execution_count": null, + "id": "30", "metadata": {}, "outputs": [], "source": [ - "leafmap.view_lidar(filename, cmap='terrain', backend='panel', background='white')" + "leafmap.view_lidar(filename, cmap=\"terrain\", backend=\"panel\", background=\"white\")" ] }, { "cell_type": "markdown", + "id": "31", "metadata": {}, "source": [ 
"![](https://i.imgur.com/XQGWbJk.gif)" @@ -273,6 +305,7 @@ }, { "cell_type": "markdown", + "id": "32", "metadata": {}, "source": [ "## Open3D\n", @@ -283,14 +316,16 @@ { "cell_type": "code", "execution_count": null, + "id": "33", "metadata": {}, "outputs": [], "source": [ - "leafmap.view_lidar(filename, backend='open3d')" + "leafmap.view_lidar(filename, backend=\"open3d\")" ] }, { "cell_type": "markdown", + "id": "34", "metadata": {}, "source": [ "![](https://i.imgur.com/rL85fbl.gif)" diff --git a/docs/examples/dataviz/raster_viz.ipynb b/docs/examples/dataviz/raster_viz.ipynb index cf9f89a..d2b5338 100644 --- a/docs/examples/dataviz/raster_viz.ipynb +++ b/docs/examples/dataviz/raster_viz.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "0", "metadata": {}, "source": [ "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/geoai/blob/main/docs/examples/dataviz/raster_viz.ipynb)\n", @@ -20,6 +21,7 @@ { "cell_type": "code", "execution_count": null, + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -28,6 +30,7 @@ }, { "cell_type": "markdown", + "id": "2", "metadata": {}, "source": [ "## Import packages" @@ -35,7 +38,8 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -44,6 +48,7 @@ }, { "cell_type": "markdown", + "id": "4", "metadata": {}, "source": [ "## COG\n", @@ -56,17 +61,19 @@ { "cell_type": "code", "execution_count": null, + "id": "5", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "url = 'https://opendata.digitalglobe.com/events/california-fire-2020/pre-event/2018-02-16/pine-gulch-fire20/1030010076004E00.tif'\n", + "url = \"https://opendata.digitalglobe.com/events/california-fire-2020/pre-event/2018-02-16/pine-gulch-fire20/1030010076004E00.tif\"\n", "m.add_cog_layer(url, name=\"Fire (pre-event)\")\n", "m" ] }, { "cell_type": "markdown", + "id": "6", "metadata": {}, 
"source": [ "You can add multiple COGs to the map. Let's add another COG to the map." @@ -75,16 +82,18 @@ { "cell_type": "code", "execution_count": null, + "id": "7", "metadata": {}, "outputs": [], "source": [ - "url2 = 'https://opendata.digitalglobe.com/events/california-fire-2020/post-event/2020-08-14/pine-gulch-fire20/10300100AAC8DD00.tif'\n", + "url2 = \"https://opendata.digitalglobe.com/events/california-fire-2020/post-event/2020-08-14/pine-gulch-fire20/10300100AAC8DD00.tif\"\n", "m.add_cog_layer(url2, name=\"Fire (post-event)\")\n", "m" ] }, { "cell_type": "markdown", + "id": "8", "metadata": {}, "source": [ "![](https://i.imgur.com/J9fLbYh.png)" @@ -92,6 +101,7 @@ }, { "cell_type": "markdown", + "id": "9", "metadata": {}, "source": [ "Create a split map for comparing two COGs." @@ -100,6 +110,7 @@ { "cell_type": "code", "execution_count": null, + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -110,6 +121,7 @@ }, { "cell_type": "markdown", + "id": "11", "metadata": {}, "source": [ "![](https://i.imgur.com/FJa0Yta.png)" @@ -117,6 +129,7 @@ }, { "cell_type": "markdown", + "id": "12", "metadata": {}, "source": [ "## Local Raster\n", @@ -127,15 +140,17 @@ { "cell_type": "code", "execution_count": null, + "id": "13", "metadata": {}, "outputs": [], "source": [ - "dem_url = 'https://open.gishub.org/data/raster/srtm90.tif'\n", + "dem_url = \"https://open.gishub.org/data/raster/srtm90.tif\"\n", "leafmap.download_file(dem_url, unzip=False)" ] }, { "cell_type": "markdown", + "id": "14", "metadata": {}, "source": [ "Visualize a single-band raster." 
@@ -144,16 +159,18 @@ { "cell_type": "code", "execution_count": null, + "id": "15", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "m.add_raster('srtm90.tif', cmap='terrain', layer_name='DEM')\n", + "m.add_raster(\"srtm90.tif\", cmap=\"terrain\", layer_name=\"DEM\")\n", "m" ] }, { "cell_type": "markdown", + "id": "16", "metadata": {}, "source": [ "![](https://i.imgur.com/7g9huvY.png)" @@ -162,15 +179,17 @@ { "cell_type": "code", "execution_count": null, + "id": "17", "metadata": {}, "outputs": [], "source": [ - "landsat_url = 'https://open.gishub.org/data/raster/cog.tif'\n", + "landsat_url = \"https://open.gishub.org/data/raster/cog.tif\"\n", "leafmap.download_file(landsat_url)" ] }, { "cell_type": "markdown", + "id": "18", "metadata": {}, "source": [ "Visualize a multi-band raster." @@ -179,16 +198,18 @@ { "cell_type": "code", "execution_count": null, + "id": "19", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "m.add_raster('cog.tif', bands=[4, 3, 2], layer_name='Landsat')\n", + "m.add_raster(\"cog.tif\", bands=[4, 3, 2], layer_name=\"Landsat\")\n", "m" ] }, { "cell_type": "markdown", + "id": "20", "metadata": {}, "source": [ "![](https://i.imgur.com/euhkajs.png)" @@ -196,6 +217,7 @@ }, { "cell_type": "markdown", + "id": "21", "metadata": {}, "source": [ "## STAC\n", @@ -205,6 +227,7 @@ }, { "cell_type": "markdown", + "id": "22", "metadata": {}, "source": [ "Create an interactive map." 
@@ -213,15 +236,17 @@ { "cell_type": "code", "execution_count": null, + "id": "23", "metadata": {}, "outputs": [], "source": [ - "url = 'https://canada-spot-ortho.s3.amazonaws.com/canada_spot_orthoimages/canada_spot5_orthoimages/S5_2007/S5_11055_6057_20070622/S5_11055_6057_20070622.json'\n", + "url = \"https://canada-spot-ortho.s3.amazonaws.com/canada_spot_orthoimages/canada_spot5_orthoimages/S5_2007/S5_11055_6057_20070622/S5_11055_6057_20070622.json\"\n", "leafmap.stac_bands(url)" ] }, { "cell_type": "markdown", + "id": "24", "metadata": {}, "source": [ "Add STAC layers to the map." @@ -230,17 +255,19 @@ { "cell_type": "code", "execution_count": null, + "id": "25", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "m.add_stac_layer(url, bands=['pan'], name='Panchromatic')\n", - "m.add_stac_layer(url, bands=['B3', 'B2', 'B1'], name='False color')\n", + "m.add_stac_layer(url, bands=[\"pan\"], name=\"Panchromatic\")\n", + "m.add_stac_layer(url, bands=[\"B3\", \"B2\", \"B1\"], name=\"False color\")\n", "m" ] }, { "cell_type": "markdown", + "id": "26", "metadata": {}, "source": [ "![](https://i.imgur.com/IlqsJXK.png)" @@ -248,6 +275,7 @@ }, { "cell_type": "markdown", + "id": "27", "metadata": {}, "source": [ "## Custom STAC Catalog" @@ -255,6 +283,7 @@ }, { "cell_type": "markdown", + "id": "28", "metadata": {}, "source": [ "Provide custom STAC API endpoints as a dictionary in the format of `{\"name\": \"url\"}`. The name will show up in the dropdown menu, while the url is the STAC API endpoint that will be used to search for items." 
@@ -263,6 +292,7 @@ { "cell_type": "code", "execution_count": null, + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -275,6 +305,7 @@ { "cell_type": "code", "execution_count": null, + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -286,6 +317,7 @@ }, { "cell_type": "markdown", + "id": "31", "metadata": {}, "source": [ "Once the catalog panel is open, you can search for items from the custom STAC API endpoints. Simply draw a bounding box on the map or zoom to a location of interest. Click on the **Collections** button to retrieve the collections from the custom STAC API endpoints. Next, select a collection from the dropdown menu. Then, click on the **Items** button to retrieve the items from the selected collection. Finally, click on the **Display** button to add the selected item to the map." @@ -293,6 +325,7 @@ }, { "cell_type": "markdown", + "id": "32", "metadata": {}, "source": [ "![](https://i.imgur.com/M8IbRsM.png)" @@ -301,6 +334,7 @@ { "cell_type": "code", "execution_count": null, + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -310,6 +344,7 @@ { "cell_type": "code", "execution_count": null, + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -319,6 +354,7 @@ { "cell_type": "code", "execution_count": null, + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -327,6 +363,7 @@ }, { "cell_type": "markdown", + "id": "36", "metadata": {}, "source": [ "## AWS S3" @@ -334,6 +371,7 @@ }, { "cell_type": "markdown", + "id": "37", "metadata": {}, "source": [ "To Be able to run this notebook you'll need to have AWS credential available as environment variables. Uncomment the following lines to set the environment variables." 
@@ -342,6 +380,7 @@ { "cell_type": "code", "execution_count": null, + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -351,6 +390,7 @@ }, { "cell_type": "markdown", + "id": "39", "metadata": {}, "source": [ "In this example, we will use datasets from the [Maxar Open Data Program on AWS](https://registry.opendata.aws/maxar-open-data/)." @@ -359,6 +399,7 @@ { "cell_type": "code", "execution_count": null, + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -368,6 +409,7 @@ }, { "cell_type": "markdown", + "id": "41", "metadata": {}, "source": [ "List all the datasets in the bucket. Specify a file extension to filter the results if needed." @@ -376,6 +418,7 @@ { "cell_type": "code", "execution_count": null, + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -385,6 +428,7 @@ }, { "cell_type": "markdown", + "id": "43", "metadata": {}, "source": [ "Visualize raster datasets from the bucket." @@ -393,6 +437,7 @@ { "cell_type": "code", "execution_count": null, + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -403,6 +448,7 @@ }, { "cell_type": "markdown", + "id": "45", "metadata": {}, "source": [ "![](https://i.imgur.com/NkTZ6Lj.png)" diff --git a/docs/examples/dataviz/vector_viz.ipynb b/docs/examples/dataviz/vector_viz.ipynb index 99011fb..2c2eb0c 100644 --- a/docs/examples/dataviz/vector_viz.ipynb +++ b/docs/examples/dataviz/vector_viz.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "0", "metadata": {}, "source": [ "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/geoai/blob/main/docs/examples/dataviz/vector_viz.ipynb)\n", @@ -18,6 +19,7 @@ { "cell_type": "code", "execution_count": null, + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -26,6 +28,7 @@ }, { "cell_type": "markdown", + "id": "2", "metadata": {}, "source": [ "## Import libraries" @@ -34,6 +37,7 @@ { "cell_type": "code", "execution_count": null, + "id": "3", "metadata": {}, 
"outputs": [], "source": [ @@ -42,6 +46,7 @@ }, { "cell_type": "markdown", + "id": "4", "metadata": {}, "source": [ "## Visualize vector data\n", @@ -52,17 +57,19 @@ { "cell_type": "code", "execution_count": null, + "id": "5", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map(center=[0, 0], zoom=2)\n", - "data = 'https://open.gishub.org/data/vector/cables.geojson'\n", + "data = \"https://open.gishub.org/data/vector/cables.geojson\"\n", "m.add_vector(data, layer_name=\"Cable lines\", info_mode=\"on_hover\")\n", "m" ] }, { "cell_type": "markdown", + "id": "6", "metadata": {}, "source": [ "![](https://i.imgur.com/NIcnLWs.png)" @@ -70,6 +77,7 @@ }, { "cell_type": "markdown", + "id": "7", "metadata": {}, "source": [ "You can style the vector with custom style callback functions. " @@ -78,12 +86,13 @@ { "cell_type": "code", "execution_count": null, + "id": "8", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map(center=[20, 0], zoom=2)\n", "m.add_basemap(\"CartoDB.DarkMatter\")\n", - "data = 'https://open.gishub.org/data/vector/cables.geojson'\n", + "data = \"https://open.gishub.org/data/vector/cables.geojson\"\n", "callback = lambda feat: {\"color\": feat[\"properties\"][\"color\"], \"weight\": 1}\n", "m.add_vector(data, layer_name=\"Cable lines\", style_callback=callback)\n", "m" @@ -91,6 +100,7 @@ }, { "cell_type": "markdown", + "id": "9", "metadata": {}, "source": [ "![](https://i.imgur.com/mQnV53U.png)" @@ -98,6 +108,7 @@ }, { "cell_type": "markdown", + "id": "10", "metadata": {}, "source": [ "## Choropleth map\n", @@ -108,23 +119,21 @@ { "cell_type": "code", "execution_count": null, + "id": "11", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "data = 'https://raw.githubusercontent.com/opengeos/leafmap/master/docs/examples/data/countries.geojson'\n", + "data = \"https://raw.githubusercontent.com/opengeos/leafmap/master/docs/examples/data/countries.geojson\"\n", "m.add_data(\n", - " data, \n", - " column='POP_EST', \n", - " 
scheme='Quantiles', \n", - " cmap='Blues', \n", - " legend_title='Population'\n", + " data, column=\"POP_EST\", scheme=\"Quantiles\", cmap=\"Blues\", legend_title=\"Population\"\n", ")\n", "m" ] }, { "cell_type": "markdown", + "id": "12", "metadata": {}, "source": [ "![](https://i.imgur.com/4FV6f72.png)" @@ -133,22 +142,24 @@ { "cell_type": "code", "execution_count": null, + "id": "13", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", "m.add_data(\n", " data,\n", - " column='POP_EST',\n", - " scheme='EqualInterval',\n", - " cmap='Blues',\n", - " legend_title='Population',\n", + " column=\"POP_EST\",\n", + " scheme=\"EqualInterval\",\n", + " cmap=\"Blues\",\n", + " legend_title=\"Population\",\n", ")\n", "m" ] }, { "cell_type": "markdown", + "id": "14", "metadata": {}, "source": [ "![](https://i.imgur.com/KEY6zEj.png)" @@ -156,6 +167,7 @@ }, { "cell_type": "markdown", + "id": "15", "metadata": {}, "source": [ "## GeoParquet\n", @@ -166,16 +178,18 @@ { "cell_type": "code", "execution_count": null, + "id": "16", "metadata": {}, "outputs": [], "source": [ - "url = 'https://open.gishub.org/data/duckdb/cities.parquet'\n", - "gdf = leafmap.read_parquet(url, return_type='gdf', src_crs='EPSG:4326')\n", + "url = \"https://open.gishub.org/data/duckdb/cities.parquet\"\n", + "gdf = leafmap.read_parquet(url, return_type=\"gdf\", src_crs=\"EPSG:4326\")\n", "gdf.head()" ] }, { "cell_type": "markdown", + "id": "17", "metadata": {}, "source": [ "Visualize point data." 
@@ -184,20 +198,22 @@ { "cell_type": "code", "execution_count": null, + "id": "18", "metadata": {}, "outputs": [], "source": [ "leafmap.view_vector(\n", - " gdf, \n", - " get_radius=20000, \n", - " get_fill_color='blue', \n", - " zoom_to_layer=False, \n", - " map_args={'center': (40, -100), 'zoom': 3, 'height': 500}\n", - " )" + " gdf,\n", + " get_radius=20000,\n", + " get_fill_color=\"blue\",\n", + " zoom_to_layer=False,\n", + " map_args={\"center\": (40, -100), \"zoom\": 3, \"height\": 500},\n", + ")" ] }, { "cell_type": "markdown", + "id": "19", "metadata": {}, "source": [ "![](https://i.imgur.com/BBsLjvx.png)" @@ -205,6 +221,7 @@ }, { "cell_type": "markdown", + "id": "20", "metadata": {}, "source": [ "Visualizing polygon data." @@ -213,17 +230,21 @@ { "cell_type": "code", "execution_count": null, + "id": "21", "metadata": {}, "outputs": [], "source": [ - "url = 'https://data.source.coop/giswqs/nwi/wetlands/DC_Wetlands.parquet'\n", - "gdf = leafmap.read_parquet(url, return_type='gdf', src_crs='EPSG:5070', dst_crs='EPSG:4326')\n", + "url = \"https://data.source.coop/giswqs/nwi/wetlands/DC_Wetlands.parquet\"\n", + "gdf = leafmap.read_parquet(\n", + " url, return_type=\"gdf\", src_crs=\"EPSG:5070\", dst_crs=\"EPSG:4326\"\n", + ")\n", "gdf.head()" ] }, { "cell_type": "code", "execution_count": null, + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -232,6 +253,7 @@ }, { "cell_type": "markdown", + "id": "23", "metadata": {}, "source": [ "![vector](https://i.imgur.com/HRtpiVd.png)" @@ -239,6 +261,7 @@ }, { "cell_type": "markdown", + "id": "24", "metadata": {}, "source": [ "Alternatively, you can specify a color map to visualize the data." 
@@ -247,32 +270,35 @@ { "cell_type": "code", "execution_count": null, + "id": "25", "metadata": {}, "outputs": [], "source": [ - "color_map = {\n", - " \"Freshwater Forested/Shrub Wetland\": (0, 136, 55),\n", - " \"Freshwater Emergent Wetland\": (127, 195, 28),\n", - " \"Freshwater Pond\": (104, 140, 192),\n", - " \"Estuarine and Marine Wetland\": (102, 194, 165),\n", - " \"Riverine\": (1, 144, 191),\n", - " \"Lake\": (19, 0, 124),\n", - " \"Estuarine and Marine Deepwater\": (0, 124, 136),\n", - " \"Other\": (178, 134, 86),\n", - " }" + "color_map = {\n", + " \"Freshwater Forested/Shrub Wetland\": (0, 136, 55),\n", + " \"Freshwater Emergent Wetland\": (127, 195, 28),\n", + " \"Freshwater Pond\": (104, 140, 192),\n", + " \"Estuarine and Marine Wetland\": (102, 194, 165),\n", + " \"Riverine\": (1, 144, 191),\n", + " \"Lake\": (19, 0, 124),\n", + " \"Estuarine and Marine Deepwater\": (0, 124, 136),\n", + " \"Other\": (178, 134, 86),\n", + "}" ] }, { "cell_type": "code", "execution_count": null, + "id": "26", "metadata": {}, "outputs": [], "source": [ - "leafmap.view_vector(gdf, color_column='WETLAND_TYPE', color_map=color_map, opacity=0.5)" + "leafmap.view_vector(gdf, color_column=\"WETLAND_TYPE\", color_map=color_map, opacity=0.5)" ] }, { "cell_type": "markdown", + "id": "27", "metadata": {}, "source": [ "![vector-color](https://i.imgur.com/Ejh8hK6.png)" @@ -280,6 +306,7 @@ }, { "cell_type": "markdown", + "id": "28", "metadata": {}, "source": [ "Display a legend for the data." 
@@ -288,6 +315,7 @@ { "cell_type": "code", "execution_count": null, + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -296,6 +324,7 @@ }, { "cell_type": "markdown", + "id": "30", "metadata": {}, "source": [ "![legend](https://i.imgur.com/fxzHHFN.png)\n", @@ -307,6 +336,7 @@ }, { "cell_type": "markdown", + "id": "31", "metadata": {}, "source": [ "## Remote PMTiles\n", @@ -316,6 +346,7 @@ }, { "cell_type": "markdown", + "id": "32", "metadata": {}, "source": [ "### Overture data" @@ -324,6 +355,7 @@ { "cell_type": "code", "execution_count": null, + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -332,6 +364,7 @@ }, { "cell_type": "markdown", + "id": "34", "metadata": {}, "source": [ "Check the metadata of the PMTiles." @@ -340,6 +373,7 @@ { "cell_type": "code", "execution_count": null, + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -351,6 +385,7 @@ }, { "cell_type": "markdown", + "id": "36", "metadata": {}, "source": [ "Visualize the PMTiles." @@ -359,11 +394,12 @@ { "cell_type": "code", "execution_count": null, + "id": "37", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map()\n", - "m.add_basemap('CartoDB.DarkMatter')\n", + "m.add_basemap(\"CartoDB.DarkMatter\")\n", "\n", "style = {\n", " \"version\": 8,\n", @@ -371,7 +407,7 @@ " \"example_source\": {\n", " \"type\": \"vector\",\n", " \"url\": \"pmtiles://\" + url,\n", - " \"attribution\": 'PMTiles',\n", + " \"attribution\": \"PMTiles\",\n", " }\n", " },\n", " \"layers\": [\n", @@ -407,14 +443,14 @@ "}\n", "\n", "m.add_pmtiles(\n", - " url, name='PMTiles', style=style, overlay=True, show=True, zoom_to_layer=True\n", + " url, name=\"PMTiles\", style=style, overlay=True, show=True, zoom_to_layer=True\n", ")\n", "\n", "legend_dict = {\n", - " 'admins': 'BDD3C7',\n", - " 'buildings': 'FFFFB3',\n", - " 'places': 'BEBADA',\n", - " 'roads': 'FB8072',\n", + " \"admins\": \"BDD3C7\",\n", + " \"buildings\": \"FFFFB3\",\n", + " \"places\": \"BEBADA\",\n", + " \"roads\": \"FB8072\",\n", 
"}\n", "\n", "m.add_legend(legend_dict=legend_dict)\n", @@ -423,6 +459,7 @@ }, { "cell_type": "markdown", + "id": "38", "metadata": {}, "source": [ "![](https://i.imgur.com/fjWMOu3.png)" @@ -430,6 +467,7 @@ }, { "cell_type": "markdown", + "id": "39", "metadata": {}, "source": [ "### Source Cooperative\n", @@ -442,10 +480,11 @@ { "cell_type": "code", "execution_count": null, + "id": "40", "metadata": {}, "outputs": [], "source": [ - "url = 'https://data.source.coop/vida/google-microsoft-open-buildings/pmtiles/go_ms_building_footprints.pmtiles'\n", + "url = \"https://data.source.coop/vida/google-microsoft-open-buildings/pmtiles/go_ms_building_footprints.pmtiles\"\n", "metadata = leafmap.pmtiles_metadata(url)\n", "print(f\"layer names: {metadata['layer_names']}\")\n", "print(f\"bounds: {metadata['bounds']}\")" @@ -453,6 +492,7 @@ }, { "cell_type": "markdown", + "id": "41", "metadata": {}, "source": [ "Visualize the PMTiles." @@ -461,12 +501,13 @@ { "cell_type": "code", "execution_count": null, + "id": "42", "metadata": {}, "outputs": [], "source": [ "m = leafmap.Map(center=[20, 0], zoom=2)\n", - "m.add_basemap('CartoDB.DarkMatter')\n", - "m.add_basemap('Esri.WorldImagery', show=False)\n", + "m.add_basemap(\"CartoDB.DarkMatter\")\n", + "m.add_basemap(\"Esri.WorldImagery\", show=False)\n", "\n", "style = {\n", " \"version\": 8,\n", @@ -474,7 +515,7 @@ " \"example_source\": {\n", " \"type\": \"vector\",\n", " \"url\": \"pmtiles://\" + url,\n", - " \"attribution\": 'PMTiles',\n", + " \"attribution\": \"PMTiles\",\n", " }\n", " },\n", " \"layers\": [\n", @@ -489,7 +530,7 @@ "}\n", "\n", "m.add_pmtiles(\n", - " url, name='Buildings', style=style, overlay=True, show=True, zoom_to_layer=False\n", + " url, name=\"Buildings\", style=style, overlay=True, show=True, zoom_to_layer=False\n", ")\n", "\n", "m" @@ -497,6 +538,7 @@ }, { "cell_type": "markdown", + "id": "43", "metadata": {}, "source": [ "![](https://i.imgur.com/kT6ng6k.png)" @@ -504,6 +546,7 @@ }, { "cell_type": 
"markdown", + "id": "44", "metadata": {}, "source": [ "## Local PMTiles\n", @@ -516,15 +559,17 @@ { "cell_type": "code", "execution_count": null, + "id": "45", "metadata": {}, "outputs": [], "source": [ - "url = 'https://raw.githubusercontent.com/opengeos/open-data/main/datasets/libya/Derna_buildings.geojson'\n", - "leafmap.download_file(url, 'buildings.geojson')" + "url = \"https://raw.githubusercontent.com/opengeos/open-data/main/datasets/libya/Derna_buildings.geojson\"\n", + "leafmap.download_file(url, \"buildings.geojson\")" ] }, { "cell_type": "markdown", + "id": "46", "metadata": {}, "source": [ "Convert vector to PMTiles." @@ -533,21 +578,19 @@ { "cell_type": "code", "execution_count": null, + "id": "47", "metadata": {}, "outputs": [], "source": [ - "pmtiles = 'buildings.pmtiles'\n", + "pmtiles = \"buildings.pmtiles\"\n", "leafmap.geojson_to_pmtiles(\n", - " 'buildings.geojson',\n", - " pmtiles,\n", - " layer_name='buildings',\n", - " overwrite=True,\n", - " quiet=True\n", + " \"buildings.geojson\", pmtiles, layer_name=\"buildings\", overwrite=True, quiet=True\n", ")" ] }, { "cell_type": "markdown", + "id": "48", "metadata": {}, "source": [ "Start a HTTP Sever" @@ -556,6 +599,7 @@ { "cell_type": "code", "execution_count": null, + "id": "49", "metadata": {}, "outputs": [], "source": [ @@ -565,15 +609,17 @@ { "cell_type": "code", "execution_count": null, + "id": "50", "metadata": {}, "outputs": [], "source": [ - "url = f'http://127.0.0.1:8000/{pmtiles}'\n", + "url = f\"http://127.0.0.1:8000/{pmtiles}\"\n", "leafmap.pmtiles_metadata(url)" ] }, { "cell_type": "markdown", + "id": "51", "metadata": {}, "source": [ "Display the PMTiles on the map." 
@@ -582,6 +628,7 @@ { "cell_type": "code", "execution_count": null, + "id": "52", "metadata": {}, "outputs": [], "source": [ @@ -593,7 +640,7 @@ " \"example_source\": {\n", " \"type\": \"vector\",\n", " \"url\": \"pmtiles://\" + url,\n", - " \"attribution\": 'PMTiles',\n", + " \"attribution\": \"PMTiles\",\n", " }\n", " },\n", " \"layers\": [\n", @@ -607,12 +654,13 @@ " ],\n", "}\n", "\n", - "m.add_pmtiles(url, name='Buildings', show=True, zoom_to_layer=True, style=style)\n", + "m.add_pmtiles(url, name=\"Buildings\", show=True, zoom_to_layer=True, style=style)\n", "m" ] }, { "cell_type": "markdown", + "id": "53", "metadata": {}, "source": [ "![](https://i.imgur.com/lhGri31.png)" diff --git a/docs/examples/rastervision/semantic_segmentation.ipynb b/docs/examples/rastervision/semantic_segmentation.ipynb index d7001db..9762f83 100644 --- a/docs/examples/rastervision/semantic_segmentation.ipynb +++ b/docs/examples/rastervision/semantic_segmentation.ipynb @@ -21,7 +21,8 @@ " SemanticSegmentationGeoDataConfig,\n", " SemanticSegmentationLearnerConfig,\n", " SolverConfig,\n", - " SemanticSegmentationLearner)\n" + " SemanticSegmentationLearner,\n", + ")" ] }, { @@ -30,7 +31,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'" + "os.environ[\"AWS_NO_SIGN_REQUEST\"] = \"YES\"" ] }, { @@ -40,12 +41,14 @@ "outputs": [], "source": [ "class_config = ClassConfig(\n", - " names=['background', 'building'], \n", - " colors=['lightgray', 'darkred'],\n", - " null_class='background')\n", + " names=[\"background\", \"building\"],\n", + " colors=[\"lightgray\", \"darkred\"],\n", + " null_class=\"background\",\n", + ")\n", "\n", "viz = SemanticSegmentationVisualizer(\n", - " class_names=class_config.names, class_colors=class_config.colors)" + " class_names=class_config.names, class_colors=class_config.colors\n", + ")" ] }, { @@ -54,11 +57,11 @@ "metadata": {}, "outputs": [], "source": [ - "train_image_uri = 
's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'\n", - "train_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson'\n", + "train_image_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif\"\n", + "train_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson\"\n", "\n", - "val_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif'\n", - "val_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson'" + "val_image_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif\"\n", + "val_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson\"" ] }, { @@ -67,8 +70,8 @@ "metadata": {}, "outputs": [], "source": [ - "pred_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13.tif'\n", - "pred_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson'" + "pred_image_uri = 
\"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13.tif\"\n", + "pred_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson\"" ] }, { @@ -77,19 +80,23 @@ "metadata": {}, "outputs": [], "source": [ - "data_augmentation_transform = A.Compose([\n", - " A.Flip(),\n", - " A.ShiftScaleRotate(),\n", - " A.OneOf([\n", - " A.HueSaturationValue(hue_shift_limit=10),\n", - " A.RGBShift(),\n", - " A.ToGray(),\n", - " A.ToSepia(),\n", - " A.RandomBrightness(),\n", - " A.RandomGamma(),\n", - " ]),\n", - " A.CoarseDropout(max_height=32, max_width=32, max_holes=5)\n", - "])" + "data_augmentation_transform = A.Compose(\n", + " [\n", + " A.Flip(),\n", + " A.ShiftScaleRotate(),\n", + " A.OneOf(\n", + " [\n", + " A.HueSaturationValue(hue_shift_limit=10),\n", + " A.RGBShift(),\n", + " A.ToGray(),\n", + " A.ToSepia(),\n", + " A.RandomBrightness(),\n", + " A.RandomGamma(),\n", + " ]\n", + " ),\n", + " A.CoarseDropout(max_height=32, max_width=32, max_holes=5),\n", + " ]\n", + ")" ] }, { @@ -102,11 +109,12 @@ " class_config=class_config,\n", " image_uri=train_image_uri,\n", " label_vector_uri=train_label_uri,\n", - " label_vector_default_class_id=class_config.get_class_id('building'),\n", + " label_vector_default_class_id=class_config.get_class_id(\"building\"),\n", " size_lims=(150, 200),\n", " out_size=256,\n", " max_windows=400,\n", - " transform=data_augmentation_transform)\n", + " transform=data_augmentation_transform,\n", + ")\n", "\n", "len(train_ds)" ] @@ -131,10 +139,11 @@ " class_config=class_config,\n", " image_uri=val_image_uri,\n", " label_vector_uri=val_label_uri,\n", - " label_vector_default_class_id=class_config.get_class_id('building'),\n", + " label_vector_default_class_id=class_config.get_class_id(\"building\"),\n", " size=200,\n", " stride=100,\n", - " 
transform=A.Resize(256, 256))\n", + " transform=A.Resize(256, 256),\n", + ")\n", "len(val_ds)" ] }, @@ -159,7 +168,8 @@ " image_uri=pred_image_uri,\n", " size=200,\n", " stride=100,\n", - " transform=A.Resize(256, 256))\n", + " transform=A.Resize(256, 256),\n", + ")\n", "len(pred_ds)" ] }, @@ -170,15 +180,16 @@ "outputs": [], "source": [ "model = torch.hub.load(\n", - " 'AdeelH/pytorch-fpn:0.3',\n", - " 'make_fpn_resnet',\n", - " name='resnet18',\n", - " fpn_type='panoptic',\n", + " \"AdeelH/pytorch-fpn:0.3\",\n", + " \"make_fpn_resnet\",\n", + " name=\"resnet18\",\n", + " fpn_type=\"panoptic\",\n", " num_classes=len(class_config),\n", " fpn_channels=128,\n", " in_channels=3,\n", " out_size=(256, 256),\n", - " pretrained=True)" + " pretrained=True,\n", + ")" ] }, { @@ -190,7 +201,7 @@ "data_cfg = SemanticSegmentationGeoDataConfig(\n", " class_names=class_config.names,\n", " class_colors=class_config.colors,\n", - " num_workers=0, # increase to use multi-processing\n", + " num_workers=0, # increase to use multi-processing\n", ")" ] }, @@ -200,11 +211,7 @@ "metadata": {}, "outputs": [], "source": [ - "solver_cfg = SolverConfig(\n", - " batch_sz=8,\n", - " lr=3e-2,\n", - " class_loss_weights=[1., 10.]\n", - ")" + "solver_cfg = SolverConfig(batch_sz=8, lr=3e-2, class_loss_weights=[1.0, 10.0])" ] }, { @@ -224,7 +231,7 @@ "source": [ "learner = SemanticSegmentationLearner(\n", " cfg=learner_cfg,\n", - " output_dir='./train-demo/',\n", + " output_dir=\"./train-demo/\",\n", " model=model,\n", " train_ds=train_ds,\n", " valid_ds=val_ds,\n", @@ -255,7 +262,7 @@ "metadata": {}, "outputs": [], "source": [ - "%tensorboard --bind_all --logdir \"./train-demo/tb-logs\" --reload_interval 10 " + "%tensorboard --bind_all --logdir \"./train-demo/tb-logs\" --reload_interval 10" ] }, { @@ -282,7 +289,7 @@ "metadata": {}, "outputs": [], "source": [ - "learner.plot_predictions(split='valid', show=True)" + "learner.plot_predictions(split=\"valid\", show=True)" ] }, { @@ -301,8 +308,8 @@ 
"outputs": [], "source": [ "learner = SemanticSegmentationLearner.from_model_bundle(\n", - " model_bundle_uri='./train-demo/model-bundle.zip',\n", - " output_dir='./train-demo/',\n", + " model_bundle_uri=\"./train-demo/model-bundle.zip\",\n", + " output_dir=\"./train-demo/\",\n", " model=model,\n", ")" ] @@ -314,8 +321,8 @@ "outputs": [], "source": [ "learner = SemanticSegmentationLearner.from_model_bundle(\n", - " model_bundle_uri='./train-demo/model-bundle.zip',\n", - " output_dir='./train-demo/',\n", + " model_bundle_uri=\"./train-demo/model-bundle.zip\",\n", + " output_dir=\"./train-demo/\",\n", " model=model,\n", " train_ds=train_ds,\n", " valid_ds=val_ds,\n", @@ -338,7 +345,7 @@ "metadata": {}, "outputs": [], "source": [ - "learner.plot_predictions(split='valid', show=True)" + "learner.plot_predictions(split=\"valid\", show=True)" ] }, { @@ -352,7 +359,8 @@ " raw_out=True,\n", " numpy_out=True,\n", " predict_kw=dict(out_shape=(325, 325)),\n", - " progress_bar=True)" + " progress_bar=True,\n", + ")" ] }, { @@ -366,7 +374,8 @@ " predictions,\n", " smooth=True,\n", " extent=pred_ds.scene.extent,\n", - " num_classes=len(class_config))" + " num_classes=len(class_config),\n", + ")" ] }, { @@ -385,9 +394,10 @@ "outputs": [], "source": [ "pred_labels.save(\n", - " uri=f'predict',\n", + " uri=f\"predict\",\n", " crs_transformer=pred_ds.scene.raster_source.crs_transformer,\n", - " class_config=class_config)" + " class_config=class_config,\n", + ")" ] } ], diff --git a/docs/examples/samgeo/arcgis.ipynb b/docs/examples/samgeo/arcgis.ipynb index 2c45580..8055179 100644 --- a/docs/examples/samgeo/arcgis.ipynb +++ b/docs/examples/samgeo/arcgis.ipynb @@ -227,9 +227,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "sam.show_anns(axis=\"off\", alpha=1, output=\"ag_annotations.tif\")" @@ -427,7 +425,7 @@ "metadata": {}, "outputs": [], "source": [ - "sam.generate('agriculture.tif', 
output=\"ag_masks2.tif\", foreground=True)" + "sam.generate(\"agriculture.tif\", output=\"ag_masks2.tif\", foreground=True)" ] }, { diff --git a/docs/examples/samgeo/automatic_mask_generator_hq.ipynb b/docs/examples/samgeo/automatic_mask_generator_hq.ipynb index a510a0a..4f3c4a9 100644 --- a/docs/examples/samgeo/automatic_mask_generator_hq.ipynb +++ b/docs/examples/samgeo/automatic_mask_generator_hq.ipynb @@ -39,7 +39,13 @@ "source": [ "import os\n", "import leafmap\n", - "from samgeo.hq_sam import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff" + "from samgeo.hq_sam import (\n", + " SamGeo,\n", + " show_image,\n", + " download_file,\n", + " overlay_images,\n", + " tms_to_geotiff,\n", + ")" ] }, { @@ -147,7 +153,7 @@ "outputs": [], "source": [ "sam = SamGeo(\n", - " model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny\n", + " model_type=\"vit_h\", # can be vit_h, vit_b, vit_l, vit_tiny\n", " sam_kwargs=None,\n", ")" ] @@ -288,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/docs/examples/samgeo/box_prompts.ipynb b/docs/examples/samgeo/box_prompts.ipynb index 5858943..80fd433 100644 --- a/docs/examples/samgeo/box_prompts.ipynb +++ b/docs/examples/samgeo/box_prompts.ipynb @@ -241,7 +241,7 @@ "metadata": {}, "outputs": [], "source": [ - "m.add_raster('mask.tif', cmap='viridis', nodata=0, layer_name='Mask')\n", + "m.add_raster(\"mask.tif\", cmap=\"viridis\", nodata=0, layer_name=\"Mask\")\n", "m" ] }, @@ -260,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "url = 'https://opengeos.github.io/data/sam/tree_boxes.geojson'\n", + "url = \"https://opengeos.github.io/data/sam/tree_boxes.geojson\"\n", "geojson = \"tree_boxes.geojson\"\n", "leafmap.download_file(url, geojson)" ] @@ -350,8 +350,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 
diff --git a/docs/examples/samgeo/fast_sam.ipynb b/docs/examples/samgeo/fast_sam.ipynb index 153595a..14a98fa 100644 --- a/docs/examples/samgeo/fast_sam.ipynb +++ b/docs/examples/samgeo/fast_sam.ipynb @@ -142,6 +142,7 @@ "outputs": [], "source": [ "from samgeo.fast_sam import SamGeo\n", + "\n", "sam = SamGeo(model=\"FastSAM-x.pt\")" ] }, @@ -259,8 +260,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/examples/samgeo/input_prompts.ipynb b/docs/examples/samgeo/input_prompts.ipynb index d9c4b87..cf036e6 100644 --- a/docs/examples/samgeo/input_prompts.ipynb +++ b/docs/examples/samgeo/input_prompts.ipynb @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -260,8 +260,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/examples/samgeo/input_prompts_hq.ipynb b/docs/examples/samgeo/input_prompts_hq.ipynb index 11329aa..40c5d61 100644 --- a/docs/examples/samgeo/input_prompts_hq.ipynb +++ b/docs/examples/samgeo/input_prompts_hq.ipynb @@ -258,8 +258,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/examples/samgeo/maxar_open_data.ipynb b/docs/examples/samgeo/maxar_open_data.ipynb index f6faaad..347e6bc 100644 --- a/docs/examples/samgeo/maxar_open_data.ipynb +++ b/docs/examples/samgeo/maxar_open_data.ipynb @@ -63,7 +63,9 @@ }, "outputs": [], "source": [ - "url = 'https://drive.google.com/file/d/1jIIC5hvSPeJEC0fbDhtxVWk2XV9AxsQD/view?usp=sharing'" + "url = (\n", + " 
\"https://drive.google.com/file/d/1jIIC5hvSPeJEC0fbDhtxVWk2XV9AxsQD/view?usp=sharing\"\n", + ")" ] }, { @@ -74,7 +76,7 @@ }, "outputs": [], "source": [ - "leafmap.download_file(url, output='image.tif')" + "leafmap.download_file(url, output=\"image.tif\")" ] }, { @@ -94,7 +96,7 @@ "source": [ "m = leafmap.Map(height=\"600px\")\n", "m.add_basemap(\"SATELLITE\")\n", - "m.add_raster('image.tif', layer_name=\"Image\")\n", + "m.add_raster(\"image.tif\", layer_name=\"Image\")\n", "m.add_layer_manager()\n", "m" ] @@ -160,7 +162,7 @@ }, "outputs": [], "source": [ - "sam.generate('image.tif', output=\"mask.tif\", foreground=True)" + "sam.generate(\"image.tif\", output=\"mask.tif\", foreground=True)" ] }, { @@ -178,7 +180,7 @@ }, "outputs": [], "source": [ - "raster_to_vector('mask.tif', output='mask.shp')" + "raster_to_vector(\"mask.tif\", output=\"mask.shp\")" ] }, { @@ -235,7 +237,7 @@ "outputs": [], "source": [ "leafmap.image_comparison(\n", - " 'image.tif',\n", + " \"image.tif\",\n", " \"annotation.tif\",\n", " label1=\"Image\",\n", " label2=\"Segmentation\",\n", @@ -257,7 +259,7 @@ }, "outputs": [], "source": [ - "overlay_images('image.tif', \"annotation.tif\", backend=\"TkAgg\")" + "overlay_images(\"image.tif\", \"annotation.tif\", backend=\"TkAgg\")" ] }, { @@ -273,8 +275,8 @@ "metadata": {}, "outputs": [], "source": [ - "m.add_raster('mask.tif', layer_name='Mask', nodata=0)\n", - "m.add_raster('annotation.tif', layer_name='Annotation')\n", + "m.add_raster(\"mask.tif\", layer_name=\"Mask\", nodata=0)\n", + "m.add_raster(\"annotation.tif\", layer_name=\"Annotation\")\n", "m" ] }, @@ -286,7 +288,7 @@ }, "outputs": [], "source": [ - "m.add_vector('mask.shp', layer_name='Vector', info_mode=None)" + "m.add_vector(\"mask.shp\", layer_name=\"Vector\", info_mode=None)" ] }, { diff --git a/docs/examples/samgeo/swimming_pools.ipynb b/docs/examples/samgeo/swimming_pools.ipynb index d49e349..d301e53 100644 --- a/docs/examples/samgeo/swimming_pools.ipynb +++ 
b/docs/examples/samgeo/swimming_pools.ipynb @@ -200,9 +200,9 @@ "outputs": [], "source": [ "sam.show_anns(\n", - " cmap='Blues',\n", - " box_color='red',\n", - " title='Automatic Segmentation of Swimming Pools',\n", + " cmap=\"Blues\",\n", + " box_color=\"red\",\n", + " title=\"Automatic Segmentation of Swimming Pools\",\n", " blend=True,\n", ")" ] @@ -228,10 +228,10 @@ "outputs": [], "source": [ "sam.show_anns(\n", - " cmap='Blues',\n", + " cmap=\"Blues\",\n", " add_boxes=False,\n", " alpha=0.5,\n", - " title='Automatic Segmentation of Swimming Pools',\n", + " title=\"Automatic Segmentation of Swimming Pools\",\n", ")" ] }, @@ -256,12 +256,12 @@ "outputs": [], "source": [ "sam.show_anns(\n", - " cmap='Greys_r',\n", + " cmap=\"Greys_r\",\n", " add_boxes=False,\n", " alpha=1,\n", - " title='Automatic Segmentation of Swimming Pools',\n", + " title=\"Automatic Segmentation of Swimming Pools\",\n", " blend=False,\n", - " output='pools.tif',\n", + " output=\"pools.tif\",\n", ")" ] }, @@ -360,8 +360,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/examples/samgeo/text_prompts.ipynb b/docs/examples/samgeo/text_prompts.ipynb index 19de7f5..11d2719 100644 --- a/docs/examples/samgeo/text_prompts.ipynb +++ b/docs/examples/samgeo/text_prompts.ipynb @@ -200,9 +200,9 @@ "outputs": [], "source": [ "sam.show_anns(\n", - " cmap='Greens',\n", - " box_color='red',\n", - " title='Automatic Segmentation of Trees',\n", + " cmap=\"Greens\",\n", + " box_color=\"red\",\n", + " title=\"Automatic Segmentation of Trees\",\n", " blend=True,\n", ")" ] @@ -228,10 +228,10 @@ "outputs": [], "source": [ "sam.show_anns(\n", - " cmap='Greens',\n", + " cmap=\"Greens\",\n", " add_boxes=False,\n", " alpha=0.5,\n", - " title='Automatic Segmentation of Trees',\n", + " title=\"Automatic Segmentation of Trees\",\n", ")" ] }, @@ -256,12 +256,12 @@ "outputs": [], "source": [ 
"sam.show_anns(\n", - " cmap='Greys_r',\n", + " cmap=\"Greys_r\",\n", " add_boxes=False,\n", " alpha=1,\n", - " title='Automatic Segmentation of Trees',\n", + " title=\"Automatic Segmentation of Trees\",\n", " blend=False,\n", - " output='trees.tif',\n", + " output=\"trees.tif\",\n", ")" ] }, @@ -353,8 +353,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/examples/samgeo/text_prompts_batch.ipynb b/docs/examples/samgeo/text_prompts_batch.ipynb index 7f5c160..cbb100e 100644 --- a/docs/examples/samgeo/text_prompts_batch.ipynb +++ b/docs/examples/samgeo/text_prompts_batch.ipynb @@ -198,13 +198,13 @@ "outputs": [], "source": [ "sam.predict_batch(\n", - " images='tiles',\n", - " out_dir='masks',\n", + " images=\"tiles\",\n", + " out_dir=\"masks\",\n", " text_prompt=text_prompt,\n", " box_threshold=0.24,\n", " text_threshold=0.24,\n", " mask_multiplier=255,\n", - " dtype='uint8',\n", + " dtype=\"uint8\",\n", " merge=True,\n", " verbose=True,\n", ")" @@ -223,7 +223,7 @@ "metadata": {}, "outputs": [], "source": [ - "m.add_raster('masks/merged.tif', cmap='viridis', nodata=0, layer_name='Mask')\n", + "m.add_raster(\"masks/merged.tif\", cmap=\"viridis\", nodata=0, layer_name=\"Mask\")\n", "m.add_layer_manager()\n", "m" ] @@ -253,8 +253,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/geoai.md b/docs/geoai.md index 0f0df0b..1285125 100644 --- a/docs/geoai.md +++ b/docs/geoai.md @@ -1,4 +1,4 @@ - + # geoai module ::: geoai.geoai \ No newline at end of file diff --git a/geoai/__init__.py b/geoai/__init__.py index 952765f..53a3208 100644 --- a/geoai/__init__.py +++ b/geoai/__init__.py @@ -1,5 +1,5 @@ """Top-level package for geoai.""" __author__ = """Qiusheng Wu""" -__email__ = 'giswqs@gmail.com' -__version__ = '0.1.0' +__email__ = 
"giswqs@gmail.com" +__version__ = "0.1.0" diff --git a/geoai/common.py b/geoai/common.py index 6967be7..9681b86 100644 --- a/geoai/common.py +++ b/geoai/common.py @@ -1,7 +1,7 @@ """The common module contains common functions and classes used by the other modules. """ + def hello_world(): - """Prints "Hello World!" to the console. - """ - print("Hello World!") \ No newline at end of file + """Prints "Hello World!" to the console.""" + print("Hello World!") diff --git a/geoai/segmentation.py b/geoai/segmentation.py new file mode 100644 index 0000000..44c64e1 --- /dev/null +++ b/geoai/segmentation.py @@ -0,0 +1,330 @@ +import os +import numpy as np +from PIL import Image +import torch +import matplotlib.pyplot as plt +from torch.utils.data import Dataset, Subset +import torch.nn.functional as F +from sklearn.model_selection import train_test_split +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from transformers import ( + Trainer, + TrainingArguments, + SegformerForSemanticSegmentation, + DefaultDataCollator, +) + + +class CustomDataset(Dataset): + """Custom Dataset for loading images and masks.""" + + def __init__( + self, + images_dir: str, + masks_dir: str, + transform: A.Compose = None, + target_size: tuple = (256, 256), + num_classes: int = 2, + ): + """ + Args: + images_dir (str): Directory containing images. + masks_dir (str): Directory containing masks. + transform (A.Compose, optional): Transformations to be applied on the images and masks. + target_size (tuple, optional): Target size for resizing images and masks. + num_classes (int, optional): Number of classes in the masks. 
+ """ + self.images_dir = images_dir + self.masks_dir = masks_dir + self.transform = transform + self.target_size = target_size + self.num_classes = num_classes + self.images = sorted(os.listdir(images_dir)) + self.masks = sorted(os.listdir(masks_dir)) + + def __len__(self) -> int: + """Returns the total number of samples.""" + return len(self.images) + + def __getitem__(self, idx: int) -> dict: + """ + Args: + idx (int): Index of the sample to fetch. + + Returns: + dict: A dictionary with 'pixel_values' and 'labels'. + """ + img_path = os.path.join(self.images_dir, self.images[idx]) + mask_path = os.path.join(self.masks_dir, self.masks[idx]) + image = Image.open(img_path).convert("RGB") + mask = Image.open(mask_path).convert("L") + + image = image.resize(self.target_size) + mask = mask.resize(self.target_size) + + image = np.array(image) + mask = np.array(mask) + + mask = (mask > 127).astype(np.uint8) + + if self.transform: + transformed = self.transform(image=image, mask=mask) + image = transformed["image"] + mask = transformed["mask"] + + assert ( + mask.max() < self.num_classes + ), f"Mask values should be less than {self.num_classes}, but found {mask.max()}" + assert ( + mask.min() >= 0 + ), f"Mask values should be greater than or equal to 0, but found {mask.min()}" + + mask = mask.clone().detach().long() + + return {"pixel_values": image, "labels": mask} + + +def get_transform() -> A.Compose: + """ + Returns: + A.Compose: A composition of image transformations. + """ + return A.Compose( + [ + A.Resize(256, 256), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + A.RandomRotate90(p=0.5), + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ToTensorV2(), + ] + ) + + +def prepare_datasets( + images_dir: str, + masks_dir: str, + transform: A.Compose, + test_size: float = 0.2, + random_state: int = 42, +) -> tuple: + """ + Args: + images_dir (str): Directory containing images. + masks_dir (str): Directory containing masks. 
+ transform (A.Compose): Transformations to be applied. + test_size (float, optional): Proportion of the dataset to include in the validation split. + random_state (int, optional): Random seed for shuffling the dataset. + + Returns: + tuple: Training and validation datasets. + """ + dataset = CustomDataset(images_dir, masks_dir, transform) + train_indices, val_indices = train_test_split( + list(range(len(dataset))), test_size=test_size, random_state=random_state + ) + train_dataset = Subset(dataset, train_indices) + val_dataset = Subset(dataset, val_indices) + return train_dataset, val_dataset + + +def train_model( + train_dataset: Dataset, + val_dataset: Dataset, + pretrained_model: str = "nvidia/segformer-b0-finetuned-ade-512-512", + model_save_path: str = "./model", + output_dir: str = "./results", + num_epochs: int = 10, + batch_size: int = 8, + learning_rate: float = 5e-5, +) -> str: + """ + Trains the model and saves the fine-tuned model to the specified path. + + Args: + train_dataset (Dataset): Training dataset. + val_dataset (Dataset): Validation dataset. + pretrained_model (str, optional): Pretrained model to fine-tune. + model_save_path (str): Path to save the fine-tuned model. Defaults to './model'. + output_dir (str, optional): Directory to save training outputs. + num_epochs (int, optional): Number of training epochs. + batch_size (int, optional): Batch size for training and evaluation. + learning_rate (float, optional): Learning rate for training. + + Returns: + str: Path to the saved fine-tuned model. 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model).to( + device + ) + data_collator = DefaultDataCollator(return_tensors="pt") + + training_args = TrainingArguments( + output_dir=output_dir, + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + eval_strategy="epoch", + save_strategy="epoch", + logging_dir="./logs", + learning_rate=learning_rate, + ) + + trainer = Trainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=val_dataset, + ) + + trainer.train() + model.save_pretrained(model_save_path) + print(f"Model saved to {model_save_path}") + return model_save_path + + +def load_model( + model_path: str, device: torch.device +) -> SegformerForSemanticSegmentation: + """ + Loads the fine-tuned model from the specified path. + + Args: + model_path (str): Path to the model. + device (torch.device): Device to load the model on. + + Returns: + SegformerForSemanticSegmentation: Loaded model. + """ + model = SegformerForSemanticSegmentation.from_pretrained(model_path) + model.to(device) + model.eval() + return model + + +def preprocess_image(image_path: str, target_size: tuple = (256, 256)) -> torch.Tensor: + """ + Preprocesses the input image for prediction. + + Args: + image_path (str): Path to the input image. + target_size (tuple, optional): Target size for resizing the image. + + Returns: + torch.Tensor: Preprocessed image tensor. 
+ """ + image = Image.open(image_path).convert("RGB") + transform = A.Compose( + [ + A.Resize(target_size[0], target_size[1]), + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ToTensorV2(), + ] + ) + image = np.array(image) + transformed = transform(image=image) + return transformed["image"].unsqueeze(0) + + +def predict_image( + model: SegformerForSemanticSegmentation, + image_tensor: torch.Tensor, + original_size: tuple, + device: torch.device, +) -> np.ndarray: + """ + Predicts the segmentation mask for the input image. + + Args: + model (SegformerForSemanticSegmentation): Fine-tuned model. + image_tensor (torch.Tensor): Preprocessed image tensor. + original_size (tuple): Original size of the image (width, height). + device (torch.device): Device to perform inference on. + + Returns: + np.ndarray: Predicted segmentation mask. + """ + with torch.no_grad(): + image_tensor = image_tensor.to(device) + outputs = model(pixel_values=image_tensor) + logits = outputs.logits + upsampled_logits = F.interpolate( + logits, size=original_size[::-1], mode="bilinear", align_corners=False + ) + predictions = torch.argmax(upsampled_logits, dim=1).cpu().numpy() + return predictions[0] + + +def segment_image( + image_path: str, + model_path: str, + target_size: tuple = (256, 256), + device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), +) -> np.ndarray: + """ + Segments the input image using the fine-tuned model. + + Args: + image_path (str): Path to the input image. + model_path (str): Path to the fine-tuned model. + target_size (tuple, optional): Target size for resizing the image. + device (torch.device, optional): Device to perform inference on. + + Returns: + np.ndarray: Predicted segmentation mask. 
+ """ + model = load_model(model_path, device) + image = Image.open(image_path).convert("RGB") + original_size = image.size + image_tensor = preprocess_image(image_path, target_size) + predictions = predict_image(model, image_tensor, original_size, device) + return predictions + + +def visualize_predictions( + image_path: str, + segmented_mask: np.ndarray, + target_size: tuple = (256, 256), + reference_image_path: str = None, +) -> None: + """ + Visualizes the original image, segmented mask, and optionally the reference image. + + Args: + image_path (str): Path to the original image. + segmented_mask (np.ndarray): Predicted segmentation mask. + target_size (tuple, optional): Target size for resizing images. + reference_image_path (str, optional): Path to the reference image. + """ + original_image = Image.open(image_path).convert("RGB") + original_image = original_image.resize(target_size) + segmented_image = Image.fromarray((segmented_mask * 255).astype(np.uint8)) + + if reference_image_path: + reference_image = Image.open(reference_image_path).convert("RGB") + reference_image = reference_image.resize(target_size) + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + axes[1].imshow(reference_image) + axes[1].set_title("Reference Image") + axes[1].axis("off") + else: + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + + axes[0].imshow(original_image) + axes[0].set_title("Original Image") + axes[0].axis("off") + + if reference_image_path: + axes[2].imshow(segmented_image, cmap="gray") + axes[2].set_title("Segmented Image") + axes[2].axis("off") + else: + axes[1].imshow(segmented_image, cmap="gray") + axes[1].set_title("Segmented Image") + axes[1].axis("off") + + plt.tight_layout() + plt.show() diff --git a/requirements.txt b/requirements.txt index 053928e..e6935e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ -segment-geospatial \ No newline at end of file +albumentations +torch +scikit-learn +segment-geospatial +transformers diff --git
a/requirements_dev.txt b/requirements_dev.txt index 22ac8e7..90f13f8 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,13 +1,13 @@ black black[jupyter] -pip bump2version -wheel -watchdog -flake8 -tox +codespell coverage +flake8 +pip Sphinx +tox twine -codespell +watchdog +wheel diff --git a/requirements_docs.txt b/requirements_docs.txt index bd75819..9e0c3a2 100644 --- a/requirements_docs.txt +++ b/requirements_docs.txt @@ -3,22 +3,22 @@ coverage flake8 ipykernel livereload +mkdocs +mkdocs-git-revision-date-localized-plugin +mkdocs-git-revision-date-plugin +mkdocs-jupyter>=0.24.0 +mkdocs-material>=9.1.3 +mkdocs-pdf-export-plugin +mkdocstrings +mkdocstrings-crystal +mkdocstrings-python-legacy nbconvert nbformat pip +pygments +pymdown-extensions sphinx tox twine watchdog wheel -mkdocs -mkdocs-git-revision-date-plugin -mkdocs-git-revision-date-localized-plugin -mkdocs-jupyter>=0.24.0 -mkdocs-material>=9.1.3 -mkdocs-pdf-export-plugin -mkdocstrings -mkdocstrings-crystal -mkdocstrings-python-legacy -pygments -pymdown-extensions \ No newline at end of file diff --git a/setup.py b/setup.py index 9228ae4..0f3f202 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ from os import path as op from setuptools import setup, find_packages -with open('README.md') as readme_file: +with open("README.md") as readme_file: readme = readme_file.read() here = op.abspath(op.dirname(__file__)) @@ -18,40 +18,40 @@ install_requires = [x.strip() for x in all_reqs if "git+" not in x] dependency_links = [x.strip().replace("git+", "") for x in all_reqs if "git+" not in x] -requirements = [ ] +requirements = [] -setup_requirements = [ ] +setup_requirements = [] -test_requirements = [ ] +test_requirements = [] setup( author="Qiusheng Wu", - author_email='giswqs@gmail.com', - python_requires='>=3.8', + author_email="giswqs@gmail.com", + python_requires=">=3.8", classifiers=[ - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Natural Language :: 
English', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], description="A Python package for using Artificial Intelligence (AI) with geospatial data", install_requires=install_requires, dependency_links=dependency_links, license="MIT license", long_description=readme, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", include_package_data=True, - keywords='geoai', - name='geoai-py', - packages=find_packages(include=['geoai', 'geoai.*']), + keywords="geoai", + name="geoai-py", + packages=find_packages(include=["geoai", "geoai.*"]), setup_requires=setup_requirements, - test_suite='tests', + test_suite="tests", tests_require=test_requirements, - url='https://github.com/opengeos/geoai', - version='0.1.0', + url="https://github.com/opengeos/geoai", + version="0.1.0", zip_safe=False, )