Commit ea14ffd: initial commit
Zhenyang (Daniel) Feng committed Jan 21, 2025
Showing 36 changed files with 7,332 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
__pycache__/
*.jpg
*.png
*.pt
*.pth
*.pkl
*.so
57 changes: 57 additions & 0 deletions README.md
@@ -0,0 +1,57 @@
# Static Segmentation by Tracking: A Frustratingly Label-Efficient Approach to Fine-Grained Segmentation
[Imageomics Institute](https://imageomics.osu.edu/)

[Zhenyang Feng](https://defisch.github.io/), Zihe Wang, Saul Ibaven Bueno, Tomasz Frelek, Advikaa Ramesh, Jingyan Bai, Lemeng Wang, [Zanming Huang](https://tzmhuang.github.io/), [Jianyang Gu](https://vimar-gu.github.io/), [Jinsu Yoo](https://jinsuyoo.info/), [Tai-Yu Pan](https://tydpan.github.io/), Arpita Chowdhury, Michelle Ramirez, [Elizabeth G Campolongo](https://u.osu.edu/campolongo-4/), Matthew J Thompson, [Christopher G. Lawrence](https://eeb.princeton.edu/people/christopher-lawrence), [Sydne Record](https://umaine.edu/wle/faculty-staff-directory/sydne-record/), [Neil Rosser](https://people.miami.edu/profile/74f02be76bd3ae57ed9edfdad0a3f76d), [Anuj Karpatne](https://anujkarpatne.github.io/), [Daniel Rubenstein](https://eeb.princeton.edu/people/daniel-rubenstein), [Hilmar Lapp](https://lappland.io/), [Charles V. Stewart](https://www.cs.rpi.edu/~stewart/), Tanya Berger-Wolf, [Yu Su](https://ysu1989.github.io/), [Wei-Lun Chao](https://sites.google.com/view/wei-lun-harry-chao)
[[arXiv]](https://arxiv.org/abs/2501.06749) [[Dataset]](https://github.com/Imageomics/NEON_beetles_masks.git) [[BibTeX]](#-citation)
![main figure](assets/main.png)


## 🛠️ Installation
To use SST, the following setup must be run on a GPU-enabled machine. The code requires `torch>=2.5.0`; `python=3.10.14` is recommended.

Example Conda Environment Setup:
```bash
# Create conda environment
conda create --name sst python=3.10.14
conda activate sst
# Install the matching torch and torchvision versions for your CUDA setup
...
# Install the required Python packages
pip install -r requirements.txt --no-dependencies
# Download model checkpoints
(cd checkpoints && ./download_ckpts.sh)
(cd checkpoints && wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth)
# Install SAM 2
(cd sam2 && pip install -e .)
```
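
Before running SST, it can help to confirm that the environment actually sees a GPU. The snippet below is only a quick sanity check, not part of the repository:
```python
# Optional sanity check: verify the PyTorch version and GPU visibility.
import torch

print(torch.__version__)           # expected to be >= 2.5.0
print(torch.cuda.is_available())   # should print True on a GPU-enabled machine
```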

## 🧑‍💻 Usage

Given one annotated support image and its greyscale mask, SST propagates the segmentation to every image in the query folder:
```bash
python code/segment.py --support_image /path/to/sample/image.png \
    --support_mask /path/to/greyscale_mask.png \
    --query_images /path/to/query/images/folder \
    --output /path/to/output/folder \
    [--output_format png/gif]
```
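
The support mask is a single greyscale image aligned with the support image. As a rough illustration (not part of the repository), per-part binary masks could be merged into one greyscale mask as sketched below; the one-integer-label-per-part convention and the file names are assumptions, so check `code/segment.py` for the exact encoding it expects:
```python
# Sketch: merge hypothetical per-part binary masks into one greyscale support mask.
# Assumption: each part is encoded as a distinct integer pixel value (1, 2, ...).
import numpy as np
from PIL import Image

part_files = ["part_head.png", "part_thorax.png"]   # hypothetical inputs
part_masks = [np.array(Image.open(p).convert("L")) > 0 for p in part_files]

support_mask = np.zeros(part_masks[0].shape, dtype=np.uint8)
for label, part in enumerate(part_masks, start=1):
    support_mask[part] = label

Image.fromarray(support_mask).save("greyscale_mask.png")
```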

## 📊 Dataset
The beetle part segmentation dataset is out! It is available [here](https://github.com/Imageomics/NEON_beetles_masks.git).
We will release our butterfly trait segmentation datasets in the near future!

## ❤️ Acknowledgements
This project makes use of the [SAM2](https://github.com/facebookresearch/sam2) and [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) codebases. We are grateful to the developers and maintainers of these projects for their contributions to the open-source community.
We also thank the authors of [LoRA](https://github.com/microsoft/LoRA) for their great work.


## 📝 Citation
If you find our work helpful for your research, please consider citing it with the following BibTeX entry:
```bibtex
@article{feng2025static,
  title={Static Segmentation by Tracking: A Frustratingly Label-Efficient Approach to Fine-Grained Segmentation},
  author={Feng, Zhenyang and Wang, Zihe and Bueno, Saul Ibaven and Frelek, Tomasz and Ramesh, Advikaa and Bai, Jingyan and Wang, Lemeng and Huang, Zanming and Gu, Jianyang and Yoo, Jinsu and others},
  journal={arXiv preprint arXiv:2501.06749},
  year={2025}
}
```
31 changes: 31 additions & 0 deletions checkpoints/download_ckpts.sh
@@ -0,0 +1,31 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# Define the URLs for the checkpoints
BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
sam2_hiera_t_url="${BASE_URL}sam2_hiera_tiny.pt"
sam2_hiera_s_url="${BASE_URL}sam2_hiera_small.pt"
sam2_hiera_b_plus_url="${BASE_URL}sam2_hiera_base_plus.pt"
sam2_hiera_l_url="${BASE_URL}sam2_hiera_large.pt"


# Download each of the four checkpoints using wget
echo "Downloading sam2_hiera_tiny.pt checkpoint..."
wget $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }

echo "Downloading sam2_hiera_small.pt checkpoint..."
wget $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }

echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
wget $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }

echo "Downloading sam2_hiera_large.pt checkpoint..."
wget $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }

echo "All checkpoints are downloaded successfully."
275 changes: 275 additions & 0 deletions code/sam_utils.py
@@ -0,0 +1,275 @@
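"""Utility functions for SST: GroundingDINO-based box detection, SAM 2 mask
loading and propagation, and matplotlib visualization helpers."""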
import cv2
import groundingdino.util.inference as DINO_inf
import groundingdino.datasets.transforms as T
import torch

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.build_sam import build_sam2_video_predictor
import sam2
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

import argparse

def load_DINO_model(model_cfg_path="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", model_pretrained_path="checkpoints/groundingdino_swint_ogc.pth"):
    model = DINO_inf.load_model(model_cfg_path, model_pretrained_path)
    return model

def DINO_image_detection(img, text_prompt, box_conf=0.5, model=None, top_1=False):
    if model is None:
        GD_root = "/fs/scratch/PAS2099/danielf/SAM2/GroundingDINO"
        model_cfg_path = os.path.join(GD_root, "groundingdino/config/GroundingDINO_SwinT_OGC.py")
        model_pretrained_path = os.path.join(GD_root, "weights/groundingdino_swint_ogc.pth")
        model = DINO_inf.load_model(model_cfg_path, model_pretrained_path)
    TEXT_PROMPT = text_prompt
    BOX_THRESHOLD = box_conf
    TEXT_THRESHOLD = 0.25

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    img_pil = Image.fromarray(img)
    image_transformed, _ = transform(img_pil, None)

    boxes, logits, phrases = DINO_inf.predict(
        model=model,
        image=image_transformed,
        caption=TEXT_PROMPT,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
        remove_combined=True
    )
    # Scale normalized (cx, cy, w, h) boxes to pixels and convert to (x1, y1, x2, y2)
    h, w = img.shape[:2]
    boxes_cxcywh = boxes * torch.Tensor([w, h, w, h])
    boxes_cxcywh = boxes_cxcywh.numpy()
    boxes_xyxy = boxes_cxcywh.copy()
    boxes_xyxy[:, 0] = boxes_cxcywh[:, 0] - boxes_cxcywh[:, 2] / 2
    boxes_xyxy[:, 1] = boxes_cxcywh[:, 1] - boxes_cxcywh[:, 3] / 2
    boxes_xyxy[:, 2] = boxes_cxcywh[:, 0] + boxes_cxcywh[:, 2] / 2
    boxes_xyxy[:, 3] = boxes_cxcywh[:, 1] + boxes_cxcywh[:, 3] / 2
    if top_1:
        if len(boxes_xyxy) == 0:
            return None
        return boxes_xyxy[np.argmax(logits)]
    return boxes_xyxy

def area(mask):
    if mask.size == 0:
        return 0
    return np.count_nonzero(mask) / mask.size

def show_mask(mask, ax, obj_id=None, random_color=False, borders=True, alpha=0.5):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, alpha])
    if not random_color and obj_id is not None:
        color = np.array([*plt.get_cmap("tab10")(obj_id)[:3], alpha])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def nms_bbox_removal(boxes_xyxy, iou_thresh=0.25):
    # Greedy overlap suppression: when two boxes overlap beyond the threshold,
    # drop the one with the lower intersection-over-own-area (i.e., the larger box).
    remove_indices = []
    for i, box in enumerate(boxes_xyxy):
        for j in range(i+1, len(boxes_xyxy)):
            box2 = boxes_xyxy[j]
            iou1 = compute_iou(box, box2)
            iou2 = compute_iou(box2, box)
            if iou1 > iou_thresh or iou2 > iou_thresh:
                if iou1 > iou2:
                    remove_indices.append(j)
                else:
                    remove_indices.append(i)
    return [box for i, box in enumerate(boxes_xyxy) if i not in remove_indices]

def load_SAM2(ckpt_path, model_cfg_path):
    if torch.cuda.is_available():
        print("Using CUDA")
        device = "cuda"
    else:
        print("CUDA device not found, using CPU instead")
        device = "cpu"
    sam2_model = build_sam2(model_cfg_path, ckpt_path, device=device, apply_postprocessing=False)
    return sam2_model

def compute_iou(box1, box2):
    # intersection area divided by the area of box1 (how much of box1 is covered by box2)
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    x5, y5 = max(x1, x3), max(y1, y3)
    x6, y6 = min(x2, x4), min(y2, y4)
    if x5 >= x6 or y5 >= y6:
        return 0
    intersection = (x6 - x5) * (y6 - y5)
    box1_area = (x2 - x1) * (y2 - y1)
    return intersection / box1_area

def show_anns(anns, color=None, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].squeeze().shape[0], sorted_anns[0]['segmentation'].squeeze().shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation'].squeeze()
        if color is None:
            color_mask = np.concatenate([np.random.random(3), [0.75]])
        else:
            color_mask = color
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=2)

    ax.imshow(img)

def build_sam2_predictor(checkpoint="../../checkpoints/sam2_hiera_large.pt", model_cfg="../../sam2_configs/sam2_hiera_l.yaml"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    video_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device, apply_postprocessing=False)
    return video_predictor

def load_masks(video_predictor, query_images, support_image, support_masks, offload_video_to_cpu=True, offload_state_to_cpu=True, verbose=False):
    '''
    video_predictor: sam2 predictor
    query_images: list of np.array of shape (H, W, 3)
    support_image: np.array of shape (H, W, 3)
    support_masks: list of np.array of shape (H, W)
    offload_video_to_cpu: for long video sequences, offload the video to the CPU to save GPU memory
    offload_state_to_cpu: save GPU memory by offloading the state to the CPU
    returns: the initialized inference state, with the support image prepended as frame 0
             and one prompt mask registered per object
    '''
    query_images.insert(0, support_image)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = video_predictor.init_state(None, image_inputs=query_images, async_loading_frames=False, offload_video_to_cpu=offload_video_to_cpu, offload_state_to_cpu=offload_state_to_cpu, verbose=verbose)
        video_predictor.reset_state(state)
        for i, patch_mask in enumerate(support_masks):
            ann_frame_idx = 0
            ann_obj_id = i  # give a unique id to each object we interact with
            patch_mask = np.array(patch_mask, dtype=np.uint8)
            patch_mask = cv2.resize(patch_mask, (1024, 1024))
            _, _, _ = video_predictor.add_new_mask(
                inference_state=state,
                frame_idx=ann_frame_idx,
                obj_id=ann_obj_id,
                mask=patch_mask,
            )
    return state

def propagate_masks(video_predictor, state, verbose=False):
    """
    returns: list[dict] with keys 'obj_ids', 'segmentation', 'area', one dict per frame
             'segmentation': np.array of shape (num_objects, H, W) with dtype bool
    """
    frame_info = []
    # run propagation throughout the video and collect the results in a dict
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        for _, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(state, verbose=verbose):
            out_mask_logits = (out_mask_logits > 0).cpu().numpy().squeeze()
            if out_mask_logits.ndim == 2:
                out_mask_logits = np.expand_dims(out_mask_logits, axis=0)
            frame_info.append({'obj_ids': out_obj_ids, 'segmentation': out_mask_logits, 'area': area(out_mask_logits)})
    return frame_info

def show_video_masks(image, frame_info):
    img_resized = cv2.resize(image, (1024, 1024))
    plt.imshow(img_resized)
    for obj_id, mask in zip(frame_info['obj_ids'], frame_info['segmentation']):
        mask = cv2.resize(mask.astype(np.uint8), (1024, 1024))
        show_mask(mask, plt.gca(), obj_id=obj_id, borders=True, alpha=0.75)
    plt.axis('off')
    plt.show()

def get_parser(inputs):
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args(inputs)
    return args

def auto_segment_SAM(boxes_xyxy, img, iou_thresh=0.9, stability_score_thresh=0.95, min_mask_region_area=10000, verbose=False):
    checkpoint = "../../checkpoints/sam2_hiera_large.pt"
    model_cfg = "../../sam2_configs/sam2_hiera_l.yaml"
    sam2_model = load_SAM2(checkpoint, model_cfg)
    auto_mask_predictor = SAM2AutomaticMaskGenerator(sam2_model,
                                                     points_per_batch=128,
                                                     pred_iou_thresh=iou_thresh,
                                                     stability_score_thresh=stability_score_thresh,
                                                     min_mask_region_area=min_mask_region_area,
                                                     multimask_output=True)
    masks_list = []
    for box_xyxy in boxes_xyxy:
        wing = img[int(box_xyxy[1]):int(box_xyxy[3]), int(box_xyxy[0]):int(box_xyxy[2])]
        mask = auto_mask_predictor.generate(wing)
        # for mask_dict in mask:
        #     mask_dict['segmentation'] = np.bitwise_not(mask_dict['segmentation'])
        if verbose:
            plt.imshow(wing)
            show_anns(mask)
            # remove axis
            plt.axis('off')
            plt.show()
        # translate the mask to the original image
        binary_masks = [e['segmentation'] for e in mask]

        for e in binary_masks:
            new_mask = np.zeros((img.shape[0], img.shape[1]), dtype=bool)
            new_mask[int(box_xyxy[1]):int(box_xyxy[3]), int(box_xyxy[0]):int(box_xyxy[2])] = e
            new_mask_dict = {
                'segmentation': new_mask,
                'area': area(new_mask)
            }
            masks_list.append(new_mask_dict)
    return masks_list

def show_masks(masks_list, img, verbose=True, imshow=True, grey=False):
    if imshow:
        if grey:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            plt.imshow(img, cmap='gray')
        else:
            plt.imshow(img)
        plt.axis('off')
    show_anns(masks_list)
    if verbose:
        plt.show()

def show_individual_masks(masks_list, img):
    for mask in masks_list:
        plt.imshow(img)
        plt.axis('off')
        show_anns([mask])
        plt.show()
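
# ---------------------------------------------------------------------------
# Example usage (a minimal sketch): propagate support masks to query images
# with the helpers above. The image and mask paths are hypothetical
# placeholders, and the default checkpoint/config paths of
# build_sam2_predictor are assumed to exist.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    predictor = build_sam2_predictor()

    support_image = np.array(Image.open("support.png").convert("RGB"))
    support_masks = [np.array(Image.open("support_part0_mask.png"))]   # one binary mask per part
    query_images = [np.array(Image.open("query0.png").convert("RGB"))]

    # load_masks prepends the support image as frame 0 and registers one prompt mask per part
    state = load_masks(predictor, query_images, support_image, support_masks)
    frame_info = propagate_masks(predictor, state)

    # frame_info[0] is the support frame; frames 1..N correspond to the query images
    show_video_masks(query_images[1], frame_info[1])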