Skip to content

Commit

Permalink
Update torch name
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Jul 27, 2024
1 parent 5a614a1 commit c857a72
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions geoai/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,22 @@ def visualize_predictions(

plt.tight_layout()
plt.show()


# Example usage
if __name__ == "__main__":
images_dir = "../datasets/Water-Bodies-Dataset/Images"
masks_dir = "../datasets/Water-Bodies-Dataset/Masks"
transform = get_transform()
train_dataset, val_dataset = prepare_datasets(images_dir, masks_dir, transform)

model_save_path = "./fine_tuned_model"
train_model(train_dataset, val_dataset, model_save_path)

image_path = "../datasets/Water-Bodies-Dataset/Images/water_body_44.jpg"
reference_image_path = image_path.replace("Images", "Masks")
segmented_mask = segment_image(image_path, model_save_path)

visualize_predictions(
image_path, segmented_mask, reference_image_path=reference_image_path
)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
albumentations
pytorch
scikit-learn
segment-geospatial
torch
transformers

0 comments on commit c857a72

Please sign in to comment.