Skip to content

Commit

Permalink
Fix issue with malformed payload for clip compare
Browse files Browse the repository at this point in the history
  • Loading branch information
PawelPeczek-Roboflow committed Dec 29, 2023
1 parent 1d19ceb commit 9aa14ae
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 5 deletions.
9 changes: 5 additions & 4 deletions inference/core/managers/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible or request.api_key is None:
Expand Down Expand Up @@ -108,11 +108,12 @@ class BackgroundTaskActiveLearningManager(ActiveLearningManager):
async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False # disabling AL in super-classes
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible:
if not active_learning_eligible or request.api_key is None:
return prediction
if BACKGROUND_TASKS_PARAM not in kwargs:
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion inference_sdk/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def clip_compare(
)
payload = self.__initialise_payload()
payload["subject_type"] = subject_type
payload["prompt_type"] = subject_type
payload["prompt_type"] = prompt_type
if subject_type == "image":
encoded_image = load_static_inference_input(
inference_input=subject,
Expand Down
41 changes: 41 additions & 0 deletions tests/inference_sdk/unit_tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,47 @@ def test_clip_compare_when_both_prompt_and_subject_are_texts(
}, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_mixed_input_is_given(
    load_static_inference_input_mock: MagicMock,
    requests_mock: Mocker,
) -> None:
    """Verify clip_compare() builds a correct payload for mixed input types:
    an image subject combined with a list of text prompts.

    Regression test: previously the payload's "prompt_type" field was
    mistakenly populated from subject_type, producing a malformed request.
    """
    # given
    api_url = "http://some.com"
    http_client = InferenceHTTPClient(api_key="my-api-key", api_url=api_url)
    # load_static_inference_input yields a list of (base64_payload, scaling) tuples
    load_static_inference_input_mock.side_effect = [
        [("base64_image_1", 0.5)]
    ]
    requests_mock.post(
        f"{api_url}/clip/compare",
        json={
            "frame_id": None,
            "time": 0.1435863340011565,
            "similarity": [0.8963012099266052, 0.8830886483192444],
        },
    )

    # when
    result = http_client.clip_compare(
        subject="/some/image.jpg",
        prompt=["dog", "house"],
    )

    # then
    assert result == {
        "frame_id": None,
        "time": 0.1435863340011565,
        "similarity": [0.8963012099266052, 0.8830886483192444],
    }, "Result must match the value returned by HTTP endpoint"
    # prompt_type must be "text" while subject_type is "image" — the exact
    # mismatch the fixed clip_compare() payload builder must now produce
    assert requests_mock.request_history[0].json() == {
        "api_key": "my-api-key",
        "subject": {"type": "base64", "value": "base64_image_1"},
        "prompt": ["dog", "house"],
        "prompt_type": "text",
        "subject_type": "image",
    }, "Request must contain API key, subject type as image, prompt type as text, encoded subject image and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_both_prompt_and_subject_are_images(
load_static_inference_input_mock: MagicMock,
Expand Down

0 comments on commit 9aa14ae

Please sign in to comment.