Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sam2 multi polygons #593

Merged
merged 25 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ec95a3f
Multi polygons
probicheaux Aug 22, 2024
6829daf
Style
probicheaux Aug 22, 2024
a597c44
Merge branch 'sam2-id-free-caching' into sam2-multi-polygons
probicheaux Aug 22, 2024
1b6e391
Saving work
probicheaux Aug 22, 2024
39d997d
Multi polygon responses
probicheaux Aug 22, 2024
55dd199
Merge branch 'main' into sam2-multi-polygons
probicheaux Aug 22, 2024
b529fcb
Better multipoly support
probicheaux Aug 23, 2024
7792159
Merge branch 'sam2-multi-polygons' of github.com:roboflow/inference i…
probicheaux Aug 23, 2024
4b09090
Style
probicheaux Aug 23, 2024
33576c6
Merge branch 'main' into sam2-multi-polygons
probicheaux Aug 23, 2024
6d8c321
Add tests
probicheaux Aug 23, 2024
97f389c
Update integration tests
probicheaux Aug 23, 2024
d26f5ad
Testing stuff
probicheaux Aug 23, 2024
157459a
Test was caching wrong thing facepalm
probicheaux Aug 23, 2024
f751c47
Remove prints
probicheaux Aug 23, 2024
88bf2c8
Move test to models predictions
probicheaux Aug 23, 2024
9e45d34
Somehow had this line twice
probicheaux Aug 23, 2024
36cde23
Merge branch 'main' into sam2-multi-polygons
probicheaux Aug 23, 2024
f934968
Fix bad merge resolution
probicheaux Aug 23, 2024
fdc915f
Fix bad merge resolution
probicheaux Aug 23, 2024
e764ed7
Merge branch 'main' into sam2-multi-polygons
probicheaux Aug 26, 2024
ad29728
Address pr comments
probicheaux Aug 27, 2024
3e1cfd9
Merge branch 'sam2-multi-polygons' of github.com:roboflow/inference i…
probicheaux Aug 27, 2024
3d6e94d
Merge branch 'main' into sam2-multi-polygons
PawelPeczek-Roboflow Aug 27, 2024
9584889
Merge branch 'main' into sam2-multi-polygons
probicheaux Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion inference/core/entities/responses/sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Sam2SegmentationPrediction(BaseModel):
time (float): The time in seconds it took to produce the segmentation including preprocessing.
"""

mask: List[List[int]] = Field(
masks: List[List[List[int]]] = Field(
description="The set of points for output mask as polygon. Each element of list represents single point.",
)
confidence: float = Field(description="Masks confidences")
Expand Down
34 changes: 34 additions & 0 deletions inference/core/utils/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ def masks2poly(masks: np.ndarray) -> List[np.ndarray]:
return segments


def masks2multipoly(masks: np.ndarray) -> List[List[np.ndarray]]:
    """Convert a batch of binary masks to multi-polygon segments.

    Each mask may contain several disjoint regions, so every mask maps to a
    *list* of polygons rather than a single one (contrast with ``masks2poly``).

    Args:
        masks (np.ndarray): A set of binary masks; values are scaled to the
            0-255 range and cast to uint8 before contour extraction.

    Returns:
        List[List[np.ndarray]]: One entry per input mask; each entry is the
        list of polygons found in that mask by ``mask2multipoly``.
    """
    # cv2-style uint8 input expected by the contour finder downstream
    masks = (masks * 255.0).astype(np.uint8)
    return [mask2multipoly(mask) for mask in masks]


def mask2poly(mask: np.ndarray) -> np.ndarray:
"""
Find contours in the mask and return them as a float32 array.
Expand All @@ -61,6 +77,24 @@ def mask2poly(mask: np.ndarray) -> np.ndarray:
return contours.astype("float32")


def mask2multipoly(mask: np.ndarray) -> List[np.ndarray]:
    """
    Find all external contours in the mask and return them as float32 arrays.

    Unlike ``mask2poly``, which keeps only one contour, this preserves every
    external contour so disjoint regions of the mask are not lost.

    Args:
        mask (np.ndarray): A binary mask (uint8, as produced by masks2multipoly).

    Returns:
        List[np.ndarray]: One float32 array of shape (N, 2) per contour; a
        single empty (0, 2) array when the mask contains no contours, so the
        "list of polygons" shape is kept even for empty masks.
    """
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    if contours:
        return [contour.reshape(-1, 2).astype("float32") for contour in contours]
    # allocate with the target dtype directly instead of zeros(...).astype(...)
    return [np.zeros((0, 2), dtype="float32")]


def post_process_bboxes(
predictions: List[List[List[float]]],
infer_shape: Tuple[int, int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
K = TypeVar("K")

DETECTIONS_CLASS_NAME_FIELD = "class_name"
DETECTION_ID_FIELD = "detection_id"

LONG_DESCRIPTION = """
Run Segment Anything 2, a zero-shot instance segmentation model, on an image.
Expand Down Expand Up @@ -195,13 +196,15 @@ def run_locally(
for single_image, boxes_for_image in zip(images, boxes):
prompt_class_ids: List[Optional[int]] = []
prompt_class_names: List[str] = []
prompt_detection_ids: List[Optional[str]] = []

prompts = []
if boxes_for_image is not None:
for xyxy, _, confidence, class_id, _, bbox_data in boxes_for_image:
x1, y1, x2, y2 = xyxy
prompt_class_ids.append(class_id)
prompt_class_names.append(bbox_data[DETECTIONS_CLASS_NAME_FIELD])
prompt_detection_ids.append(bbox_data[DETECTION_ID_FIELD])
width = x2 - x1
height = y2 - y1
cx = x1 + width / 2
Expand Down Expand Up @@ -239,6 +242,7 @@ def run_locally(
image=single_image,
prompt_class_ids=prompt_class_ids,
prompt_class_names=prompt_class_names,
prompt_detection_ids=prompt_detection_ids,
threshold=threshold,
)
predictions.append(prediction)
Expand Down Expand Up @@ -273,6 +277,7 @@ def convert_sam2_segmentation_response_to_inference_instances_seg_response(
image: WorkflowImageData,
prompt_class_ids: List[Optional[int]],
prompt_class_names: List[Optional[str]],
prompt_detection_ids: List[Optional[str]],
threshold: float,
) -> InstanceSegmentationInferenceResponse:
image_width = image.numpy_image.shape[1]
Expand All @@ -283,39 +288,43 @@ def convert_sam2_segmentation_response_to_inference_instances_seg_response(
prompt_class_names = [
"foreground" for _ in range(len(sam2_segmentation_predictions))
]
for prediction, class_id, class_name in zip(
sam2_segmentation_predictions, prompt_class_ids, prompt_class_names
prompt_detection_ids = [None for _ in range(len(sam2_segmentation_predictions))]
for prediction, class_id, class_name, detection_id in zip(
sam2_segmentation_predictions,
prompt_class_ids,
prompt_class_names,
prompt_detection_ids,
):
if len(prediction.mask) == 0:
# skipping empty masks
continue
if prediction.confidence < threshold:
# skipping maks below threshold
continue
x_coords = [coord[0] for coord in prediction.mask]
y_coords = [coord[1] for coord in prediction.mask]
min_x = np.min(x_coords)
max_x = np.max(x_coords)
min_y = np.min(y_coords)
max_y = np.max(y_coords)
center_x = (min_x + max_x) / 2
center_y = (min_y + max_y) / 2
predictions.append(
InstanceSegmentationPrediction(
**{
"x": center_x,
"y": center_y,
"width": max_x - min_x,
"height": max_y - min_y,
"points": [
Point(x=point[0], y=point[1]) for point in prediction.mask
],
"confidence": prediction.confidence,
"class": class_name,
"class_id": class_id,
}
for mask in prediction.masks:
if len(mask) == 0:
# skipping empty masks
continue
if prediction.confidence < threshold:
# skipping maks below threshold
continue
x_coords = [coord[0] for coord in mask]
y_coords = [coord[1] for coord in mask]
min_x = np.min(x_coords)
max_x = np.max(x_coords)
min_y = np.min(y_coords)
max_y = np.max(y_coords)
center_x = (min_x + max_x) / 2
center_y = (min_y + max_y) / 2
predictions.append(
InstanceSegmentationPrediction(
**{
"x": center_x,
"y": center_y,
"width": max_x - min_x,
"height": max_y - min_y,
"points": [Point(x=point[0], y=point[1]) for point in mask],
"confidence": prediction.confidence,
"class": class_name,
"class_id": class_id,
probicheaux marked this conversation as resolved.
Show resolved Hide resolved
"parent_id": detection_id,
}
)
)
)
return InstanceSegmentationInferenceResponse(
predictions=predictions,
image=InferenceResponseImage(width=image_width, height=image_height),
Expand Down
8 changes: 5 additions & 3 deletions inference/models/sam2/segment_anything2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly
from inference.core.utils.postprocess import masks2multipoly

if DEVICE is None:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -303,6 +303,8 @@ def segment_image(
)

args = pad_points(args)
if not any(args.values()):
args = {"point_coords": [[0, 0]], "point_labels": [-1], "box": None}
masks, scores, low_resolution_logits = self.predictor.predict(
mask_input=mask_input,
multimask_output=multimask_output,
Expand Down Expand Up @@ -484,10 +486,10 @@ def turn_segmentation_results_into_api_response(
inference_start_timestamp: float,
) -> Sam2SegmentationResponse:
predictions = []
masks_plygons = masks2poly(masks >= mask_threshold)
masks_plygons = masks2multipoly(masks >= mask_threshold)
for mask_polygon, score in zip(masks_plygons, scores):
prediction = Sam2SegmentationPrediction(
mask=mask_polygon.tolist(),
masks=[mask.tolist() for mask in mask_polygon],
confidence=score.item(),
)
predictions.append(prediction)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"predictions": [{"masks": [[[239, 298], [238, 299], [239, 300], [239, 301], [238, 302], [238, 338], [237, 339], [238, 340], [238, 361], [239, 362], [238, 363], [238, 365], [239, 366], [241, 366], [243, 364], [243, 361], [242, 360], [242, 355], [243, 354], [242, 353], [242, 336], [241, 335], [241, 326], [240, 325], [240, 311], [239, 310]], [[239, 295]], [[551, 266], [550, 267], [546, 267], [546, 273], [547, 274], [545, 276], [545, 309], [544, 310], [544, 325], [543, 326], [543, 328], [544, 329], [543, 330], [543, 346], [542, 347], [542, 348], [543, 349], [542, 350], [542, 372], [543, 373], [543, 376], [548, 376], [549, 375], [551, 375], [552, 376], [553, 375], [553, 373], [554, 372], [554, 369], [555, 368], [555, 361], [556, 360], [556, 359], [555, 358], [555, 356], [556, 355], [555, 354], [555, 353], [556, 352], [556, 350], [555, 349], [556, 348], [556, 347], [555, 346], [555, 345], [556, 344], [556, 339], [555, 338], [555, 336], [556, 335], [555, 334], [555, 332], [556, 331], [556, 330], [555, 329], [556, 328], [556, 318], [555, 317], [555, 315], [556, 314], [556, 287], [557, 286], [556, 285], [556, 283], [557, 282], [556, 281], [556, 266]], [[592, 265], [591, 266], [590, 266], [589, 267], [589, 271], [588, 272], [588, 321], [587, 322], [587, 323], [588, 324], [588, 330], [587, 331], [588, 332], [588, 342], [587, 343], [587, 348], [588, 349], [588, 362], [587, 363], [587, 364], [588, 365], [588, 400], [589, 401], [589, 406], [593, 406], [594, 407], [598, 407], [599, 408], [600, 408], [601, 409], [602, 409], [603, 410], [603, 418], [604, 419], [604, 420], [603, 421], [603, 423], [604, 424], [604, 426], [603, 427], [603, 428], [601, 430], [600, 430], [599, 431], [598, 431], [597, 432], [592, 432], [591, 433], [591, 434], [592, 435], [592, 436], [595, 439], [597, 439], [599, 441], [600, 441], [607, 448], [608, 448], [610, 450], [611, 450], [612, 451], [614, 451], [615, 452], [616, 452], [617, 453], [619, 453], [620, 454], [623, 454], [627, 458], [628, 458], [630, 
460], [632, 460], [633, 461], [636, 461], [637, 462], [640, 462], [642, 464], [643, 464], [644, 465], [646, 465], [648, 467], [649, 467], [651, 469], [652, 469], [654, 471], [654, 472], [657, 475], [657, 477], [658, 478], [658, 479], [659, 480], [659, 481], [660, 482], [661, 482], [662, 483], [662, 484], [667, 489], [669, 489], [670, 490], [672, 490], [675, 493], [676, 493], [677, 494], [678, 494], [680, 496], [680, 497], [681, 497], [682, 498], [683, 498], [684, 499], [685, 499], [686, 500], [688, 500], [689, 501], [690, 501], [691, 502], [692, 502], [693, 503], [694, 503], [695, 502], [696, 503], [698, 503], [699, 504], [700, 503], [704, 503], [704, 501], [705, 500], [708, 500], [709, 499], [710, 499], [711, 498], [713, 498], [714, 497], [716, 497], [717, 496], [718, 496], [719, 495], [722, 495], [723, 494], [728, 494], [729, 493], [732, 493], [733, 492], [735, 492], [736, 491], [737, 491], [738, 490], [749, 490], [750, 489], [752, 489], [753, 488], [757, 488], [758, 487], [764, 487], [765, 486], [769, 486], [770, 485], [772, 485], [774, 483], [775, 483], [776, 482], [779, 482], [779, 267], [777, 267], [776, 266], [774, 266], [773, 267], [694, 267], [693, 266], [689, 266], [688, 267], [678, 267], [677, 266], [676, 267], [607, 267], [606, 266], [605, 266], [604, 267], [598, 267], [597, 266], [596, 266], [595, 265]], [[260, 264], [259, 265], [256, 265], [255, 266], [255, 279], [256, 280], [256, 281], [255, 282], [256, 283], [256, 286], [257, 287], [269, 287], [270, 288], [272, 288], [273, 289], [274, 289], [278, 293], [278, 299], [277, 300], [277, 318], [276, 319], [276, 324], [277, 325], [276, 326], [276, 348], [275, 349], [275, 353], [276, 354], [275, 355], [275, 356], [276, 357], [276, 360], [278, 362], [288, 362], [289, 363], [299, 363], [300, 362], [301, 363], [311, 363], [312, 362], [314, 362], [315, 363], [316, 363], [317, 362], [318, 362], [319, 363], [340, 363], [341, 362], [351, 362], [352, 363], [387, 363], [388, 364], [395, 364], [396, 363], [412, 363], 
[413, 364], [442, 364], [443, 365], [444, 364], [447, 364], [448, 365], [457, 365], [458, 364], [459, 364], [460, 365], [502, 365], [503, 366], [504, 366], [505, 365], [512, 365], [513, 366], [514, 365], [514, 361], [516, 359], [516, 342], [517, 341], [516, 340], [516, 334], [517, 333], [516, 332], [516, 327], [517, 326], [517, 324], [516, 323], [516, 320], [517, 319], [517, 305], [518, 304], [518, 300], [519, 299], [519, 290], [518, 289], [519, 288], [519, 279], [520, 278], [520, 277], [519, 276], [519, 273], [520, 272], [520, 268], [519, 267], [501, 267], [500, 266], [494, 266], [493, 267], [487, 267], [486, 266], [469, 266], [468, 267], [466, 267], [465, 266], [453, 266], [452, 267], [449, 267], [448, 266], [439, 266], [438, 265], [420, 265], [419, 266], [417, 266], [416, 265], [401, 265], [400, 264], [396, 264], [395, 265], [316, 265], [315, 264], [314, 265], [305, 265], [304, 264], [285, 264], [284, 265], [283, 264]], [[7, 264], [6, 265], [0, 265], [0, 486], [1, 486], [3, 488], [3, 489], [9, 489], [10, 490], [11, 489], [17, 489], [17, 488], [19, 486], [21, 486], [22, 487], [23, 487], [24, 488], [25, 488], [26, 489], [27, 489], [28, 488], [31, 488], [32, 487], [33, 488], [42, 488], [43, 487], [47, 487], [48, 486], [49, 486], [50, 485], [52, 485], [53, 484], [53, 483], [57, 479], [58, 479], [59, 478], [60, 478], [61, 477], [62, 477], [64, 475], [65, 475], [65, 474], [67, 472], [67, 471], [68, 470], [69, 470], [70, 469], [71, 469], [75, 465], [76, 465], [77, 464], [78, 464], [79, 463], [81, 463], [82, 462], [86, 462], [87, 461], [87, 458], [89, 456], [89, 455], [91, 453], [92, 453], [94, 451], [95, 451], [96, 450], [97, 450], [98, 449], [101, 449], [102, 448], [103, 448], [104, 447], [105, 447], [106, 446], [111, 446], [112, 445], [113, 445], [114, 444], [116, 444], [117, 443], [119, 443], [120, 442], [124, 442], [125, 441], [126, 441], [126, 440], [127, 439], [128, 439], [129, 438], [130, 438], [132, 436], [133, 436], [134, 435], [136, 435], [137, 434], [138, 
434], [139, 433], [140, 433], [141, 432], [142, 432], [144, 430], [145, 430], [146, 429], [149, 429], [153, 425], [155, 425], [156, 424], [157, 424], [159, 422], [160, 422], [160, 421], [161, 420], [162, 420], [164, 418], [165, 418], [164, 418], [162, 416], [162, 415], [161, 414], [161, 408], [160, 407], [160, 405], [161, 404], [161, 401], [168, 394], [168, 393], [171, 390], [172, 390], [174, 388], [175, 388], [176, 387], [186, 387], [187, 386], [194, 386], [196, 388], [197, 387], [200, 387], [201, 386], [202, 386], [203, 385], [203, 383], [204, 382], [204, 378], [205, 377], [204, 376], [205, 375], [205, 373], [204, 372], [204, 371], [205, 370], [205, 344], [206, 343], [206, 341], [205, 340], [205, 338], [206, 337], [205, 336], [205, 328], [206, 327], [206, 326], [205, 325], [205, 322], [206, 321], [206, 274], [207, 273], [207, 268], [205, 266], [205, 265], [188, 265], [187, 266], [186, 266], [185, 265], [127, 265], [126, 266], [125, 265], [12, 265], [11, 264]]], "confidence": 0.9681398272514343}], "time": 2.0950593883171678}
22 changes: 21 additions & 1 deletion tests/inference/models_predictions_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import pytest
import requests
import json
from typing import Dict

from inference.core.env import MODEL_CACHE_DIR

Expand All @@ -22,8 +24,17 @@
TRUCK_IMAGE_PATH = os.path.join(ASSETS_DIR, "truck.jpg")
SAM2_TRUCK_LOGITS = os.path.join(ASSETS_DIR, "low_res_logits.npy")
SAM2_TRUCK_MASK_FROM_CACHE = os.path.join(ASSETS_DIR, "mask_from_cached_logits.npy")
SAM2_MULTI_POLY_RESPONSE_PATH = os.path.join(
ASSETS_DIR, "sam2_multipolygon_response.json"
)



@pytest.fixture(scope="function")
def sam2_multipolygon_response() -> Dict:
    """Return the stored SAM2 multi-polygon response parsed from its JSON asset."""
    with open(SAM2_MULTI_POLY_RESPONSE_PATH) as response_file:
        raw_response = response_file.read()
    return json.loads(raw_response)

@pytest.fixture(scope="function")
def example_image() -> np.ndarray:
    """Load the example test image from the assets directory as a numpy array
    (BGR channel order, per cv2.imread's default)."""
    return cv2.imread(EXAMPLE_IMAGE_PATH)
Expand Down Expand Up @@ -186,6 +197,16 @@ def sam2_small_model() -> Generator[str, None, None]:
yield model_id
shutil.rmtree(model_cache_dir)

@pytest.fixture(scope="function")
def sam2_tiny_model() -> Generator[str, None, None]:
    """Provide the SAM2 tiny model id for a test, downloading the model package
    into the local cache first and removing the cache directory on teardown.

    NOTE(review): requires network access to fetch the package from GCS.
    """
    model_id = "sam2/hiera_tiny"
    # fetches and unpacks the package; the returned path is deleted on teardown
    model_cache_dir = fetch_and_place_model_in_cache(
        model_id=model_id,
        model_package_url="https://storage.googleapis.com/roboflow-tests-assets/sam2_tiny.zip",
    )
    yield model_id
    # teardown: drop the cached model files after the test completes
    shutil.rmtree(model_cache_dir)


@pytest.fixture(scope="function")
def sam2_small_truck_logits() -> Generator[np.ndarray, None, None]:
Expand All @@ -196,7 +217,6 @@ def sam2_small_truck_logits() -> Generator[np.ndarray, None, None]:
def sam2_small_truck_mask_from_cached_logits() -> Generator[np.ndarray, None, None]:
    """Yield the reference truck mask produced from cached logits, loaded from
    its .npy asset file."""
    yield np.load(SAM2_TRUCK_MASK_FROM_CACHE)


def fetch_and_place_model_in_cache(
model_id: str,
model_package_url: str,
Expand Down
Loading
Loading