Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rename Classes UQL Operation #656

Merged
merged 11 commits into from
Sep 20, 2024
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Literal, Union
from typing import Any, Dict, List, Literal, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
Expand Down Expand Up @@ -450,6 +450,27 @@ class Divide(OperationDefinition):
other: Union[int, float]


class DetectionsRename(OperationDefinition):
model_config = ConfigDict(
json_schema_extra={
"description": "Renames classes in detections based on provided mapping",
"compound": False,
"input_kind": [
OBJECT_DETECTION_PREDICTION_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
KEYPOINT_DETECTION_PREDICTION_KIND,
],
"output_kind": [
OBJECT_DETECTION_PREDICTION_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
KEYPOINT_DETECTION_PREDICTION_KIND,
],
},
)
type: Literal["DetectionsRename"]
class_map: Dict[str, str]


AllOperationsType = Annotated[
Union[
StringToLowerCase,
Expand All @@ -469,6 +490,7 @@ class Divide(OperationDefinition):
DetectionsFilter,
DetectionsOffset,
DetectionsShift,
DetectionsRename,
RandomNumber,
StringMatches,
ExtractImageProperty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
extract_detections_property,
filter_detections,
offset_detections,
rename_detections,
select_detections,
shift_detections,
sort_detections,
Expand Down Expand Up @@ -187,6 +188,7 @@ def build_detections_filter_operation(
"DetectionsSelection": select_detections,
"SortDetections": sort_detections,
"ClassificationPropertyExtract": extract_classification_property,
"DetectionsRename": rename_detections,
}

REGISTERED_COMPOUND_OPERATIONS_BUILDERS = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,30 @@ def sort_detections(
if not ascending:
sorted_indices = sorted_indices[::-1]
return value[sorted_indices]


def rename_detections(
detections: Any,
class_map: Dict[str, str],
**kwargs,
) -> sv.Detections:
if not isinstance(detections, sv.Detections):
value_as_str = safe_stringify(value=detections)
raise InvalidInputTypeError(
public_message=f"Executing rename_detections(...), expected sv.Detections object as value, "
f"got {value_as_str} of type {type(detections)}",
context="step_execution | roboflow_query_language_evaluation",
)

detections_copy = deepcopy(detections)
class_names = detections_copy.data.get("class_name", []).tolist()

for i, class_name in enumerate(class_names):
try:
class_names[i] = class_map[class_name]
except KeyError:
# If the class is not in the class_map, keep the original class
pass

detections_copy.data["class_name"] = np.array(class_names, dtype=object)
return detections_copy
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions tests/workflows/integration_tests/execution/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def red_image() -> np.ndarray:
return cv2.imread(os.path.join(ASSETS_DIR, "red_image.png"))


@pytest.fixture(scope="function")
def fruit_image() -> np.ndarray:
return cv2.imread(os.path.join(ASSETS_DIR, "multi-fruit.jpg"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please report image credits

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complete



@pytest.fixture(scope="function")
def left_scissors_right_paper() -> np.ndarray:
return cv2.imread(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great tests coverage

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PawelPeczek-Roboflow Hey Pawel, I couldn't figure out or find a good example of passing Input Parameters into UQL operations for tests. Instead I created a hacky work around with parameterized tests to replace the Workflow Specification to test scenarios. Feel free to change this if you'd like; as I would also like to learn how to pass input parameters into UQL operations for future work.

import supervision as sv

from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.core import ExecutionEngine

CLASS_RENAME_WORKFLOW = {
"version": "1.0",
"inputs": [
{
"type": "WorkflowImage",
"name": "image"
},
{
"type": "WorkflowParameter",
"name": "model_id"
},
{
"type": "WorkflowParameter",
"name": "confidence",
"default_value": 0.4
},
{
"type": "WorkflowParameter",
"name": "classes"
},
],
"steps": [
{
"type": "ObjectDetectionModel",
"name": "model",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": "$inputs.confidence",
},
{
"type": "DetectionsTransformation",
"name": "class_rename",
"predictions": "$steps.model.predictions",
"operations": [
{
"type": "DetectionsRename",
"class_map": {
"orange": "fruit",
"banana": "fruit",
"apple": "fruit"
}
}
],
},
],
"outputs": [
{
"type": "JsonField",
"name": "original_predictions",
"selector": "$steps.model.predictions",
},
{
"type": "JsonField",
"name": "renamed_predictions",
"selector": "$steps.class_rename.predictions",
},
],
}


EXPECTED_ORIGINAL_CLASSES = np.array(
[
"apple",
"apple",
"apple",
"orange",
"banana"
]
)
EXPECTED_RENAMED_CLASSES = np.array(
[
"fruit",
"fruit",
"fruit",
"fruit",
"fruit"
]
)


def test_class_rename_workflow_to_have_correct_classes(
model_manager: ModelManager,
fruit_image: np.ndarray,
) -> None:
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.api_key": None,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=CLASS_RENAME_WORKFLOW,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": fruit_image,
"model_id": "yolov8n-640",
}
)

# then
assert isinstance(result, list), "Expected result to be list"
assert len(result) == 1, "Single image provided - single output expected"

original_predictions: sv.Detections = result[0]["original_predictions"]
renamed_predictions: sv.Detections = result[0]["renamed_predictions"]

assert len(original_predictions) == len(EXPECTED_ORIGINAL_CLASSES), "length of original predictions match expected length"
assert len(renamed_predictions) == len(EXPECTED_RENAMED_CLASSES), "length of renamed predictions match expected length "

assert np.array_equal(EXPECTED_ORIGINAL_CLASSES, original_predictions.data["class_name"]), "Expected original classes to match predicted classes"
assert np.array_equal(EXPECTED_RENAMED_CLASSES, renamed_predictions.data["class_name"]), "Expected renamed classes to match block class renaming"
Loading