i.sam2: SamGeo2 model #1244

Merged
merged 22 commits into from
Feb 15, 2025
7 changes: 7 additions & 0 deletions src/imagery/i.sam2/Makefile
@@ -0,0 +1,7 @@
MODULE_TOPDIR = ../..

PGM = i.sam2

include $(MODULE_TOPDIR)/include/Make/Script.make

default: script
50 changes: 50 additions & 0 deletions src/imagery/i.sam2/i.sam2.html
@@ -0,0 +1,50 @@
<h2>DESCRIPTION</h2>

<em>i.sam2</em> allows users to segment orthoimagery based on text prompts using <a href="https://samgeo.gishub.org/">SamGeo</a>.

<h2>REQUIREMENTS</h2>

<ul>
<li><a href="https://pillow.readthedocs.io/en/stable/">Pillow>=10.2.0</a></li>
<li><a href="https://numpy.org/">numpy>=1.26.1</a></li>
<li><a href="https://pytorch.org/">torch>=2.5.1</a></li>
<li><a href="https://samgeo.gishub.org/">segment-geospatial>=0.12.3</a></li>
</ul>

<div class="code">
<pre>
pip install pillow numpy torch segment-geospatial
</pre>
</div>

<h2>EXAMPLES</h2>

Segment orthoimagery using SamGeo2:

<div class="code">
<pre>
i.sam2 group=rgb_255 output=tree_mask text_prompt="trees"
</pre>
</div>

<img src="./i_sam2_trees.jpg" height="600" alt="i.sam2 example" />

<h2>NOTES</h2>
The first run takes longer because the model needs to be downloaded. Subsequent runs will be faster.
CUDA is required for GPU acceleration. If no CUDA-enabled GPU is available, the module runs on the CPU. To force CPU execution on a machine that has a GPU, set the environment variable <code>CUDA_VISIBLE_DEVICES</code> to <code>-1</code>.

<h2>REFERENCES</h2>
<ul>
<li>Wu, Q., & Osco, L. (2023). samgeo: A Python package for segmenting geospatial data with the Segment Anything Model (SAM). Journal of Open Source Software, 8(89), 5663. <a href="https://doi.org/10.21105/joss.05663">https://doi.org/10.21105/joss.05663</a></li>
<li>Osco, L. P., Wu, Q., de Lemos, E. L., Gonçalves, W. N., Ramos, A. P. M., Li, J., & Junior, J. M. (2023). The Segment Anything Model (SAM) for remote sensing applications: From zero to one shot. International Journal of Applied Earth Observation and Geoinformation, 124, 103540. <a href="https://doi.org/10.1016/j.jag.2023.103540">https://doi.org/10.1016/j.jag.2023.103540</a></li>
</ul>

<h2>SEE ALSO</h2>
<em>
<a href="i.segment.gsoc.html">i.segment.gsoc</a> for region growing and merging segmentation,
<a href="i.segment.hierarchical.html">i.segment.hierarchical</a> for hierarchical segmentation,
<a href="i.superpixels.slic.html">i.superpixels.slic</a> for superpixel segmentation.
</em>

<h2>AUTHOR</h2>
Corey T. White (NCSU GeoForAll Lab & OpenPlains Inc.)
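The CPU fallback described in the NOTES section can be sketched in Python as follows. This is a minimal illustration and not part of the PR: `force_cpu` and `pick_device` are hypothetical helper names, and returning "cpu" when torch is not installed is an assumption made only so the sketch runs anywhere.

```python
import os


def force_cpu():
    """Hide all CUDA devices from libraries imported after this call."""
    # Must run before torch initializes CUDA; "-1" matches no device.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def pick_device():
    """Return "cuda" when torch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # Assumption for this sketch only: fall back to CPU without torch.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


force_cpu()
print(pick_device())
```

Setting the variable in the shell before starting GRASS has the same effect and avoids any dependence on import order.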
256 changes: 256 additions & 0 deletions src/imagery/i.sam2/i.sam2.py
@@ -0,0 +1,256 @@
#!/usr/bin/env python3

############################################################################
#
# MODULE: i.sam2
# AUTHOR: Corey T. White, OpenPlains Inc.
# PURPOSE: Uses the SAMGeo model for segmentation in GRASS GIS.
# COPYRIGHT: (C) 2023-2025 Corey White
# This program is free software under the GNU General
# Public License (>=v2). Read the file COPYING that
# comes with GRASS for details.
#
#############################################################################

# %module
# % description: Integrates SAMGeo model with text prompt for segmentation in GRASS GIS.
# % keyword: imagery
# % keyword: segmentation
# % keyword: object recognition
# % keyword: deep learning
# %end

# %option G_OPT_I_GROUP
# % key: group
# % description: Name of input imagery group
# % required: yes
# %end

# %option G_OPT_R_OUTPUT
# % key: output
# % description: Name of output segmented raster map
# % required: yes
# %end

# %option G_OPT_M_DIR
# % key: checkpoint_dir
# % description: Path to the SAMGeo model checkpoint directory (optional if using default model)
# % required: no
# %end

# %option
# % key: text_prompt
# % type: string
# % description: Text prompt to guide segmentation
# % required: no
# %end

# %option
# % key: text_threshold
# % type: double
# % answer: 0.24
# % description: Text threshold for text segmentation
# % required: no
# % multiple: no
# %end

# %option
# % key: box_threshold
# % type: double
# % answer: 0.24
# % description: Box threshold for text segmentation
# % required: no
# % multiple: no
# %end

import os
import sys
import grass.script as gs
import torch
import numpy as np
from PIL import Image
from grass.script import array as garray


def get_device():
    """
    Determines the available device for computation (CUDA or CPU).

    This function checks if a CUDA-enabled GPU is available and returns "cuda" if it is,
    otherwise it returns "cpu". If CUDA is available, it also clears the CUDA cache.

    Returns:
        str: "cuda" if a CUDA-enabled GPU is available, otherwise "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Format after translation so the message template stays translatable
    gs.message(_("Running computation on {device}...").format(device=device))
    if device == "cuda":
        torch.cuda.empty_cache()
    return device


def read_raster_group(group):
    """
    Reads a group of raster maps and returns them as a list of numpy arrays.

    Parameters:
        group (str): The name of the raster group to read.

    Returns:
        list: A list of numpy arrays, each representing a raster map in the group.
    """
    gs.message(_("Reading imagery group..."))
    rasters = gs.read_command("i.group", group=group, flags="lg")
    # One map name per output line; skip blank lines
    raster_list = [line.strip() for line in rasters.splitlines() if line.strip()]
    return [garray.array(raster, dtype=np.uint8) for raster in raster_list]


def normalize_rgb_array(rgb_array):
    """
    Normalizes an RGB array to the range [0, 255].

    This function takes an RGB array and normalizes its values to the range [0, 255].
    If the input array is not of type np.uint8, it scales the values to fit within
    this range and converts the array to np.uint8.

    Parameters:
        rgb_array (numpy.ndarray): The input RGB array to be normalized.

    Returns:
        numpy.ndarray: The normalized RGB array with values in the range [0, 255] and type np.uint8.
    """
    if rgb_array.dtype != np.uint8:
        gs.message(_("Converting RGB array to uint8..."))
        min_val = rgb_array.min()
        max_val = rgb_array.max()

        # Avoid potential division by zero error
        if min_val == max_val:
            gs.warning(_("RGB array has a constant value, returning uniform array."))
            rgb_array = np.full_like(rgb_array, 0, dtype=np.uint8)
        else:
            scale = 255 / (max_val - min_val)
            rgb_array = ((rgb_array - min_val) * scale).astype(np.uint8)

    return rgb_array


def run_langsam_segmentation(
    np_image, text_prompt, box_threshold, text_threshold, device
):
    """
    Runs LangSAM segmentation on the given image using the specified text prompt and thresholds.

    Parameters:
        np_image (PIL.Image.Image): The input image (main() passes a PIL image built from the RGB array).
        text_prompt (str): The text prompt to guide the segmentation.
        box_threshold (float): The threshold for box predictions.
        text_threshold (float): The threshold for text predictions.
        device (str): The device to run the segmentation on (e.g., 'cpu' or 'cuda').

    Returns:
        list: A list of masks generated by the segmentation.
    """
    from samgeo.text_sam import LangSAM
    from torch.amp.autocast_mode import autocast

    gs.message(_("Running LangSAM segmentation..."))
    sam = LangSAM(model_type="sam2-hiera-large")
    with autocast(device_type=device):
        masks, boxes, phrases, logits = sam.predict(
            image=np_image,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            return_results=True,
        )
    return masks


def run_samgeo_segmentation(rgb_array, checkpoint_dir, device):
    """
    Runs SAMGeo segmentation on an input image and returns the resulting masks.

    Parameters:
        rgb_array (numpy.ndarray): The input image as a NumPy array.
        checkpoint_dir (str): The path to the SAMGeo model checkpoint directory.
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').

    Returns:
        list: A list of masks generated by the segmentation.
    """
    from samgeo import SamGeo

    gs.message(_("Running SAMGeo segmentation..."))
    sam = SamGeo(model_type="vit_h", checkpoint_dir=checkpoint_dir, device=device)
    sam.generate(source=rgb_array)
    masks = sam.objects
    return masks


def write_raster(input_np_array, output_raster_masks):
    """
    Writes a segmented raster into GRASS GIS.

    Parameters:
        input_np_array (list of numpy.ndarray): A list of numpy arrays representing the masks of the input image.
        output_raster_masks (str): The name of the output raster map to be created in GRASS GIS.

    Exits through gs.fatal if the input is empty or if the masks do not have the same shape.

    This function merges multiple raster masks into a single raster, where each mask is assigned a unique value.
    The merged raster is then written to a GRASS GIS raster map.
    """

    gs.message(_("Importing the segmented raster into GRASS GIS..."))

    if len(input_np_array) == 0:
        gs.fatal(_("No masks found."))

    merged_raster = np.zeros_like(input_np_array[0], dtype=np.int32)
    for idx, band in enumerate(input_np_array):
        if band.shape != input_np_array[0].shape:
            gs.fatal(_("All masks must have the same shape."))
        unique_value = idx + 1
        mask = band != 0
        # Later masks overwrite earlier ones where they overlap
        merged_raster[mask] = unique_value

    mask_raster = garray.array()
    mask_raster[...] = merged_raster
    mask_raster.write(mapname=output_raster_masks)


def main():
    group = options["group"]
    output_raster_masks = options["output"]
    checkpoint_dir = options.get("checkpoint_dir")
    text_prompt = options.get("text_prompt")
    text_threshold = float(options.get("text_threshold"))
    box_threshold = float(options.get("box_threshold"))

    input_image_np = read_raster_group(group)
    rgb_array = normalize_rgb_array(np.stack(input_image_np, axis=-1))
    np_image = Image.fromarray(rgb_array[:, :, :3])

    device = get_device()

    try:
        if text_prompt:
            masks = run_langsam_segmentation(
                np_image, text_prompt, box_threshold, text_threshold, device
            )
        else:
            masks = run_samgeo_segmentation(rgb_array, checkpoint_dir, device)
    except Exception as e:
        # Format after translation so the message template stays translatable
        gs.fatal(_("Error while running SAMGeo: {error}").format(error=e))
        return 1

    gs.message(_("Segmentation complete."))
    write_raster(masks, output_raster_masks)
    return 0


if __name__ == "__main__":
    options, flags = gs.parser()
    sys.exit(main())
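The two array transformations in the script, min-max scaling to the 0..255 range and merging masks into one labeled raster, can be sketched without any dependencies. This is an illustration under stated assumptions rather than the module's code: plain Python lists stand in for NumPy arrays, the helper names `scale_to_uint8` and `merge_masks` are hypothetical, and errors are raised as `ValueError` where the module calls `gs.fatal`.

```python
def scale_to_uint8(values):
    """Min-max scale a flat list of numbers to integers in 0..255."""
    lo, hi = min(values), max(values)
    if lo == hi:
        # Constant input: mirror the script and return all zeros
        return [0 for _ in values]
    scale = 255 / (hi - lo)
    # int() truncates toward zero, like astype(np.uint8) on in-range floats
    return [int((v - lo) * scale) for v in values]


def merge_masks(masks):
    """Merge 2D masks into one label grid; mask i gets label i + 1."""
    if not masks:
        raise ValueError("No masks found.")
    rows, cols = len(masks[0]), len(masks[0][0])
    merged = [[0] * cols for _ in range(rows)]
    for label, mask in enumerate(masks, start=1):
        if len(mask) != rows or any(len(row) != cols for row in mask):
            raise ValueError("All masks must have the same shape.")
        for r in range(rows):
            for c in range(cols):
                if mask[r][c] != 0:
                    # Later masks overwrite earlier ones where they overlap
                    merged[r][c] = label
    return merged


print(scale_to_uint8([10, 20, 30]))  # [0, 127, 255]
print(merge_masks([[[1, 1], [0, 0]], [[0, 1], [0, 1]]]))  # [[1, 2], [0, 2]]
```

The second call shows the overlap rule: the cell covered by both masks ends up with label 2, because labels are assigned in list order and the last mask wins.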
Binary file added src/imagery/i.sam2/i_sam2_trees.jpg
4 changes: 4 additions & 0 deletions src/imagery/i.sam2/requirements.txt
@@ -0,0 +1,4 @@
Pillow>=10.2.0
numpy>=1.26.1
torch>=2.5.1
segment-geospatial>=0.12.3
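One way to check an environment against these minimum versions before running the module is sketched below. It uses only the standard library; `version_tuple` and `check_requirements` are hypothetical helpers, and the version parsing is a deliberate simplification that compares only the leading numeric release (it ignores pre-release ordering and local suffixes such as `+cu121`).

```python
from importlib import metadata

# Minimum versions copied from requirements.txt
MINIMUMS = {
    "Pillow": "10.2.0",
    "numpy": "1.26.1",
    "torch": "2.5.1",
    "segment-geospatial": "0.12.3",
}


def version_tuple(version):
    """Parse the leading numeric release of a version string into ints."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_requirements(minimums=MINIMUMS):
    """Return {package: status}; status is "ok", "too old", or "missing"."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        ok = version_tuple(installed) >= version_tuple(minimum)
        report[name] = "ok" if ok else "too old"
    return report


for name, status in check_requirements().items():
    print(f"{name}: {status}")
```

For strict PEP 440 comparison, a dedicated version-parsing library would be the safer choice; the tuple comparison here is enough for a quick sanity check.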
67 changes: 67 additions & 0 deletions src/imagery/i.sam2/testsuite/test_i_sam2.py
@@ -0,0 +1,67 @@
# import os
# import sys
# import pytest
# import numpy as np
# from unittest.mock import patch, MagicMock
# from grass.script import array as garray
# from PIL import Image
# from grass.gunittest.case import TestCase
# from grass.gunittest.main import test


# @pytest.fixture(scope="module")
# def mock_torch():
#     with patch('torch.cuda.is_available') as mock_is_available:
#         mock_is_available.return_value = False
#         yield mock_is_available


# @pytest.fixture(scope="module")
# def mock_run_langsam_segmentation():
#     with patch('i.sam2.run_langsam_segmentation') as mock_run_langsam:
#         # Define the mock return value
#         mock_run_langsam.return_value = [np.random.randint(0, 2, (100, 100), dtype=np.uint8) for _ in range(3)]
#         yield mock_run_langsam


# class TestISam2(TestCase):

#     RED_BAND = "lsat7_2002_30"
#     GREEN_BAND = "lsat7_2002_20"
#     BLUE_BAND = "lsat7_2002_10"

#     @classmethod
#     def _create_imagery_group(cls):
#         cls.runModule("i.group", group="test_group", input=','.join([cls.RED_BAND, cls.GREEN_BAND, cls.BLUE_BAND]))

#     @classmethod
#     def setUpClass(cls):
#         """Ensures expected computational region"""
#         # to not override mapset's region (which might be used by other tests)
#         cls.use_temp_region()
#         cls.runModule("g.region", raster="elev_lid792_1m", res=30)
#         cls._create_imagery_group()

#     def tearDown(self):
#         """
#         Remove the outputs created by the i.sam2 module
#         This is executed after each test run.
#         """
#         self.runModule("g.remove", flags="f", type="raster", name="test_output")

#     @pytest.mark.usefixtures("mock_torch", "mock_run_langsam_segmentation")
#     def test_main_with_text_prompt(self):
#         options = {
#             "group": "test_group",
#             "output": "test_output",
#             "checkpoint_dir": None,
#             "text_prompt": "Waterbodies",
#             "text_threshold": "0.24",
#             "box_threshold": "0.24"
#         }

#         self.assertModule("i.sam2", **options)


# if __name__ == "__main__":
#     test()
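The disabled fixtures above aim to force the CPU code path by mocking `torch.cuda.is_available`. A dependency-free way to exercise that idea is sketched below. It is an illustration only: `fake_torch` and `device_with_fake_torch` are hypothetical helpers, the local `get_device` is a simplified copy of the module's function without the GRASS messaging, and a stand-in `torch` module is placed in `sys.modules` so the sketch runs without PyTorch installed.

```python
import sys
import types
from unittest.mock import MagicMock, patch


def get_device():
    """Simplified copy of the module's device selection."""
    import torch  # resolved from sys.modules, so a stand-in can be injected

    return "cuda" if torch.cuda.is_available() else "cpu"


def fake_torch(cuda_available):
    """Build a stand-in torch module whose cuda.is_available() is fixed."""
    fake = types.ModuleType("torch")
    fake.cuda = MagicMock()
    fake.cuda.is_available.return_value = cuda_available
    return fake


def device_with_fake_torch(cuda_available):
    """Run get_device() with the stand-in module temporarily installed."""
    with patch.dict(sys.modules, {"torch": fake_torch(cuda_available)}):
        return get_device()


print(device_with_fake_torch(True))   # cuda
print(device_with_fake_torch(False))  # cpu
```

Patching of this kind only affects code running in the same Python process. A test that launches the module as a separate process, as `assertModule` does, would not see the mock.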