-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunet_draft.py
47 lines (35 loc) · 1.71 KB
/
unet_draft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from tensorflow.keras.models import Model
# Load your dataset and preprocess it
def load_and_preprocess_data():
# Load remote sensing images and their ground truth segmentation masks
# Modify this function according to your dataset structure
images, masks = [], []
for img_file, mask_file in zip(image_files, mask_files):
# Load image and mask, then preprocess (e.g., normalize pixel values)
# Use appropriate image loading function depending on your data format (e.g., PIL, OpenCV, etc.)
pass
return np.array(images), np.array(masks)
def unet(input_size=(256, 256, 3)):
inputs = Input(input_size)
conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
# Repeat this pattern for each encoding and decoding block, increasing the number of filters
# For example: conv2, conv3, conv4, ..., up6, up7, up8
conv9 = Conv2D(2, (3, 3), activation='relu', padding='same')(up8)
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
model = Model(inputs=[inputs], outputs=[conv10])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
# Load and preprocess your data
train_images, train_masks = load_and_preprocess_data()
# Create and compile the U-Net model
model = unet()
# Train the model
model.fit(train_images, train_masks, batch_size=32, epochs=100, validation_split=0.1)
# Save the model
model.save('unet_remote_sensing.h5')