Merge pull request #214 from roboflow/fix/improve_keyless_models_handling

Fix problem with keyless access and Active Learning
PawelPeczek-Roboflow authored Dec 29, 2023
2 parents 820da06 + 9aa14ae commit 2faf2af
Showing 10 changed files with 152 additions and 25 deletions.
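In short, the change makes keyless usage degrade gracefully: when no Roboflow API key is supplied, Active Learning is forced off instead of breaking inference, and the stream-management components stop assuming the key is always present. A minimal sketch of the intended usage after this change (the model ID and video source below are placeholders, not values taken from this commit):

    from inference.core.interfaces.stream.inference_pipeline import InferencePipeline

    # No api_key is passed, so the pipeline is expected to log that Active Learning
    # is forced off and keep running plain inference on the video source.
    pipeline = InferencePipeline.init(
        model_id="some-public-model/1",          # placeholder model id
        video_reference="rtsp://127.0.0.1:554",  # placeholder video source
        on_prediction=lambda predictions, frame: print(predictions),
        active_learning_enabled=True,            # honoured only when an api_key is given
    )
    pipeline.start()
    pipeline.join()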
2 changes: 1 addition & 1 deletion inference/__init__.py
@@ -1,3 +1,3 @@
from inference.core.interfaces.stream.stream import Stream
from inference.core.interfaces.stream.stream import Stream # isort:skip
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
from inference.models.utils import get_roboflow_model
9 changes: 8 additions & 1 deletion inference/core/interfaces/stream/inference_pipeline.py
@@ -127,7 +127,9 @@ def init(
active_learning_enabled (Optional[bool]): Flag to enable / disable Active Learning middleware (setting it
true does not guarantee any data to be collected, as data collection is controlled by Roboflow backend -
it just enables middleware intercepting predictions). If not given, env variable
`ACTIVE_LEARNING_ENABLED` will be used.
`ACTIVE_LEARNING_ENABLED` will be used. Note that Active Learning will be forcibly
disabled when no Roboflow API key is given, as a Roboflow account is required
for this feature to be operational.
Other ENV variables involved in low-level configuration:
* INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching
@@ -170,6 +172,11 @@ def init(
f"with value: {ACTIVE_LEARNING_ENABLED}"
)
active_learning_enabled = ACTIVE_LEARNING_ENABLED
if api_key is None:
logger.info(
f"Roboflow API key not given - Active Learning is forced to be disabled."
)
active_learning_enabled = False
if active_learning_enabled is True:
active_learning_middleware = ThreadingActiveLearningMiddleware.init(
api_key=api_key,
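The decision order implemented by the two hunks above can be summarised in a standalone sketch (illustration only, not library code; the names are made up):

    def resolve_active_learning(flag, env_value, api_key):
        # An explicit argument wins; otherwise fall back to ACTIVE_LEARNING_ENABLED.
        if flag is None:
            flag = env_value
        # Keyless usage forces Active Learning off, whatever was requested.
        if api_key is None:
            return False
        return flag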
11 changes: 6 additions & 5 deletions inference/core/managers/active_learning.py
@@ -31,10 +31,10 @@ async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible:
if not active_learning_eligible or request.api_key is None:
return prediction
self.register(prediction=prediction, model_id=model_id, request=request)
return prediction
@@ -108,11 +108,12 @@ class BackgroundTaskActiveLearningManager(ActiveLearningManager):
async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False # disabling AL in super-classes
prediction = await super().infer_from_request(
model_id=model_id, request=request
model_id=model_id, request=request, **kwargs
)
active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
if not active_learning_eligible:
if not active_learning_eligible or request.api_key is None:
return prediction
if BACKGROUND_TASKS_PARAM not in kwargs:
logger.warning(
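Both managers now apply the same guard before registering a prediction for Active Learning. A rough standalone sketch of that condition (the literal key stands in for ACTIVE_LEARNING_ELIGIBLE_PARAM; names are illustrative):

    def should_register_for_active_learning(kwargs: dict, request) -> bool:
        # Register only when the caller marked the request as eligible
        # AND the request actually carries a Roboflow API key.
        eligible = kwargs.get("active_learning_eligible", False)
        return eligible and request.api_key is not None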
2 changes: 1 addition & 1 deletion inference/enterprise/stream_management/api/entities.py
@@ -51,7 +51,7 @@ class PipelineInitialisationRequest(BaseModel):
sink_configuration: UDPSinkConfiguration = Field(
description="Configuration of the sink."
)
api_key: str = Field(description="Roboflow API key")
api_key: Optional[str] = Field(description="Roboflow API key", default=None)
max_fps: Optional[Union[float, int]] = Field(
description="Limit of FPS in video processing.", default=None
)
@@ -115,7 +115,7 @@ def _initialise_pipeline(self, request_id: str, payload: dict) -> None:
model_id=payload["model_id"],
video_reference=payload["video_reference"],
on_prediction=sink,
api_key=payload["api_key"],
api_key=payload.get("api_key"),
max_fps=payload.get("max_fps"),
watchdog=watchdog,
source_buffer_filling_strategy=source_buffer_filling_strategy,
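Switching from payload["api_key"] to payload.get("api_key") is what lets the manager accept initialisation payloads with the key omitted. A tiny illustration of the difference (hypothetical payload, not taken from this commit):

    payload = {"model_id": "some/1", "video_reference": "rtsp://some:543"}

    payload.get("api_key")  # -> None, so the pipeline is initialised keyless
    payload["api_key"]      # -> raises KeyError, which the old lookup would have hit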
2 changes: 1 addition & 1 deletion inference_sdk/http/client.py
@@ -478,7 +478,7 @@ def clip_compare(
)
payload = self.__initialise_payload()
payload["subject_type"] = subject_type
payload["prompt_type"] = subject_type
payload["prompt_type"] = prompt_type
if subject_type == "image":
encoded_image = load_static_inference_input(
inference_input=subject,
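The one-character bug above meant prompt_type was silently overwritten with subject_type. With the fix, a mixed call such as the sketch below (placeholder server URL and key) ends up sending subject_type="image" and prompt_type="text", which the new unit test further down verifies:

    from inference_sdk import InferenceHTTPClient

    client = InferenceHTTPClient(api_url="http://localhost:9001", api_key="my-api-key")
    result = client.clip_compare(
        subject="/some/image.jpg",   # defaults to subject_type="image"
        prompt=["dog", "house"],     # defaults to prompt_type="text"
    )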
@@ -177,6 +177,44 @@ def test_initialise_pipeline_when_valid_payload_given(
}, "CommandResponse must be serialised directly to JSON response"


@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
def test_initialise_pipeline_when_valid_payload_given_without_api_key(
stream_manager_client: AsyncMock,
) -> None:
# given
client = TestClient(app.app)
stream_manager_client.initialise_pipeline.return_value = CommandResponse(
status="success",
context=CommandContext(request_id="my_request", pipeline_id="my_pipeline"),
)

# when
response = client.post(
"/initialise",
json={
"model_id": "some/1",
"video_reference": "rtsp://some:543",
"sink_configuration": {
"type": "udp_sink",
"host": "127.0.0.1",
"port": 9090,
},
"model_configuration": {"type": "object-detection"},
"active_learning_enabled": True,
},
)

# then
assert response.status_code == 200, "Status code for success must be 200"
assert response.json() == {
"status": "success",
"context": {
"request_id": "my_request",
"pipeline_id": "my_pipeline",
},
}, "CommandResponse must be serialised directly to JSON response"


@mock.patch.object(app, "STREAM_MANAGER_CLIENT", new_callable=AsyncMock)
def test_pause_pipeline_when_successful_response_expected(
stream_manager_client: AsyncMock,
@@ -506,6 +506,52 @@ async def test_stream_manager_client_can_successfully_initialise_pipeline(
)


@pytest.mark.asyncio
@mock.patch.object(stream_manager_client, "establish_socket_connection")
async def test_stream_manager_client_can_successfully_initialise_pipeline_without_api_key(
establish_socket_connection_mock: AsyncMock,
) -> None:
# given
reader = assembly_socket_reader(
message={
"request_id": "my_request",
"pipeline_id": "new_pipeline",
"response": {"status": "success"},
},
header_size=4,
)
writer = DummyStreamWriter()
establish_socket_connection_mock.return_value = (reader, writer)
initialisation_request = PipelineInitialisationRequest(
model_id="some/1",
video_reference="rtsp://some:543",
sink_configuration=UDPSinkConfiguration(
type="udp_sink",
host="127.0.0.1",
port=9090,
),
model_configuration=ObjectDetectionModelConfiguration(type="object_detection"),
)
client = StreamManagerClient.init(
host="127.0.0.1",
port=7070,
operations_timeout=1.0,
header_size=4,
buffer_size=16438,
)

# when
result = await client.initialise_pipeline(
initialisation_request=initialisation_request
)

# then
assert result == CommandResponse(
status="success",
context=CommandContext(request_id="my_request", pipeline_id="new_pipeline"),
)


@pytest.mark.asyncio
@mock.patch.object(stream_manager_client, "establish_socket_connection")
async def test_stream_manager_client_can_successfully_terminate_pipeline(
@@ -60,7 +60,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested(

@pytest.mark.timeout(30)
@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_without_api_key(
pipeline_init_mock: MagicMock,
) -> None:
# given
Expand All @@ -70,7 +70,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
command_queue=command_queue, responses_queue=responses_queue
)
init_payload = assembly_valid_init_payload()
del init_payload["model_configuration"]
del init_payload["api_key"]

# when
command_queue.put(("1", init_payload))
Expand All @@ -82,24 +82,18 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
status_2 = responses_queue.get()

# then
assert (
status_1[0] == "1"
), "First request should be reported in responses_queue at first"
assert (
status_1[1]["status"] == OperationStatus.FAILURE
), "Init operation should fail"
assert (
status_1[1]["error_type"] == ErrorType.INVALID_PAYLOAD
), "Invalid Payload error is expected"
assert status_1 == (
"1",
{"status": OperationStatus.SUCCESS},
), "Initialisation operation must succeed"
assert status_2 == (
"2",
{"status": OperationStatus.SUCCESS},
), "Termination of pipeline must happen"

), "Termination operation must succeed"

@pytest.mark.timeout(30)
@mock.patch.object(inference_pipeline_manager.InferencePipeline, "init")
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_api_key_not_given(
def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_but_invalid_payload_sent(
pipeline_init_mock: MagicMock,
) -> None:
# given
Expand All @@ -109,7 +103,7 @@ def test_inference_pipeline_manager_when_init_pipeline_operation_is_requested_bu
command_queue=command_queue, responses_queue=responses_queue
)
init_payload = assembly_valid_init_payload()
del init_payload["api_key"]
del init_payload["model_configuration"]

# when
command_queue.put(("1", init_payload))
41 changes: 41 additions & 0 deletions tests/inference_sdk/unit_tests/http/test_client.py
@@ -1561,6 +1561,47 @@ def test_clip_compare_when_both_prompt_and_subject_are_texts(
}, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_mixed_input_is_given(
load_static_inference_input_mock: MagicMock,
requests_mock: Mocker,
) -> None:
# given
api_url = "http://some.com"
http_client = InferenceHTTPClient(api_key="my-api-key", api_url=api_url)
load_static_inference_input_mock.side_effect = [
[("base64_image_1", 0.5)]
]
requests_mock.post(
f"{api_url}/clip/compare",
json={
"frame_id": None,
"time": 0.1435863340011565,
"similarity": [0.8963012099266052, 0.8830886483192444],
},
)

# when
result = http_client.clip_compare(
subject="/some/image.jpg",
prompt=["dog", "house"],
)

# then
assert result == {
"frame_id": None,
"time": 0.1435863340011565,
"similarity": [0.8963012099266052, 0.8830886483192444],
}, "Result must match the value returned by HTTP endpoint"
assert requests_mock.request_history[0].json() == {
"api_key": "my-api-key",
"subject": {"type": "base64", "value": "base64_image_1"},
"prompt": ["dog", "house"],
"prompt_type": "text",
"subject_type": "image",
}, "Request must contain API key, subject and prompt types as text, exact values of subject and list of prompt values"


@mock.patch.object(client, "load_static_inference_input")
def test_clip_compare_when_both_prompt_and_subject_are_images(
load_static_inference_input_mock: MagicMock,
