diff --git a/photonix/classifiers/runners.py b/photonix/classifiers/runners.py index f091d949..d5fb1f3f 100644 --- a/photonix/classifiers/runners.py +++ b/photonix/classifiers/runners.py @@ -19,7 +19,7 @@ def get_or_create_tag(library, name, type, source, parent=None, ordering=None): return tag -def get_photo_by_any_type(photo_id): +def get_photo_by_any_type(photo_id, model=None): is_photo_instance = False photo = None @@ -33,7 +33,7 @@ def get_photo_by_any_type(photo_id): # Is an individual filename so return the prediction if not is_photo_instance: - return None, model.predict(photo_id) + return None # Is a Photo model instance so needs saving if not photo: @@ -50,6 +50,9 @@ def get_photo_by_any_type(photo_id): def results_for_model_on_photo(model, photo_id): - photo = get_photo_by_any_type(photo_id) - results = model.predict(photo.base_image_path) + photo = get_photo_by_any_type(photo_id, model) + if photo: + results = model.predict(photo.base_image_path) + else: + results = model.predict(photo_id) return photo, results diff --git a/tests/factories.py b/tests/factories.py index 0cab88d7..03904f75 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -24,9 +24,13 @@ class Meta: model = Library name = factory.Sequence(lambda n: f'Test Library {n}') + classification_color_enabled = True classification_location_enabled = True + classification_style_enabled = True classification_object_enabled = True classification_face_enabled = True + setup_stage_completed = True + class LibraryUserFactory(factory.django.DjangoModelFactory): class Meta: @@ -36,6 +40,7 @@ class Meta: user = factory.SubFactory(UserFactory) owner = True + class PhotoFactory(factory.django.DjangoModelFactory): class Meta: model = Photo @@ -76,3 +81,4 @@ class Meta: type = 'classify.style' status = 'P' + library = factory.SubFactory(LibraryFactory) diff --git a/tests/test_classifier_batch.py b/tests/test_classifier_batch.py index 159c4f0d..c18de390 100644 --- a/tests/test_classifier_batch.py +++ b/tests/test_classifier_batch.py @@ -22,7 +22,7 @@ def test_classifier_batch(): photo = PhotoFactory() PhotoFileFactory(photo=photo) - for i in range(4): + for _ in range(4): TaskFactory(subject_id=photo.id) start = time() diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 5055426f..b8ee12f6 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -193,6 +193,7 @@ def test_library_setting_data(self): classificationStyleEnabled classificationObjectEnabled classificationLocationEnabled + classificationFaceEnabled } sourceFolder } @@ -202,14 +203,15 @@ def test_library_setting_data(self): data = get_graphql_content(response) assert response.status_code == 200 self.assertEqual(data['data']['librarySetting']['library']['name'], self.defaults['library'].name) - self.assertFalse(data['data']['librarySetting']['library']['classificationColorEnabled']) - self.assertFalse(data['data']['librarySetting']['library']['classificationStyleEnabled']) - assert data['data']['librarySetting']['library']['classificationObjectEnabled'] - assert data['data']['librarySetting']['library']['classificationLocationEnabled'] + self.assertTrue(data['data']['librarySetting']['library']['classificationColorEnabled']) + self.assertTrue(data['data']['librarySetting']['library']['classificationStyleEnabled']) + self.assertTrue(data['data']['librarySetting']['library']['classificationObjectEnabled']) + self.assertTrue(data['data']['librarySetting']['library']['classificationLocationEnabled']) + self.assertTrue(data['data']['librarySetting']['library']['classificationFaceEnabled']) self.assertEqual(data['data']['librarySetting']['sourceFolder'], self.defaults['library'].paths.all()[0].path) - def test_library_update_style_enabled_mutaion(self): - """Test library updateStyleEnabled mutaion response.""" + def test_library_update_style_enabled_mutation(self): + """Test library updateStyleEnabled mutation response.""" mutation = """ mutation updateStyleEnabled( $classificationStyleEnabled: Boolean! @@ -230,8 +232,8 @@ def test_library_update_style_enabled_mutaion(self): assert response.status_code == 200 assert tuple(tuple(data.values())[0].values())[0].get('classificationStyleEnabled') - def test_library_update_color_enabled_mutaion(self): - """Test library updateColorEnabled mutaion response.""" + def test_library_update_color_enabled_mutation(self): + """Test library updateColorEnabled mutation response.""" mutation = """ mutation updateColorEnabled( $classificationColorEnabled: Boolean! @@ -252,8 +254,8 @@ def test_library_update_color_enabled_mutaion(self): assert response.status_code == 200 assert tuple(tuple(data.values())[0].values())[0].get('classificationColorEnabled') - def test_library_update_location_enabled_mutaion(self): - """Test library updateLocationEnabled mutaion response.""" + def test_library_update_location_enabled_mutation(self): + """Test library updateLocationEnabled mutation response.""" mutation = """ mutation updateLocationEnabled( $classificationLocationEnabled: Boolean! @@ -274,8 +276,8 @@ def test_library_update_location_enabled_mutaion(self): assert response.status_code == 200 self.assertFalse(tuple(tuple(data.values())[0].values())[0].get('classificationLocationEnabled')) - def test_library_update_object_enabled_mutaion(self): - """Test library updateObjectEnabled mutaion response.""" + def test_library_update_object_enabled_mutation(self): + """Test library updateObjectEnabled mutation response.""" mutation = """ mutation updateObjectEnabled( $classificationObjectEnabled: Boolean! @@ -296,8 +298,8 @@ def test_library_update_object_enabled_mutaion(self): assert response.status_code == 200 self.assertFalse(tuple(tuple(data.values())[0].values())[0].get('classificationObjectEnabled')) - def test_library_update_source_folder_mutaion(self): - """Test library updateSourceFolder mutaion response.""" + def test_library_update_source_folder_mutation(self): + """Test library updateSourceFolder mutation response.""" mutation = """ mutation updateSourceFolder($sourceFolder: String!, $libraryId: ID) { updateSourceFolder( @@ -313,7 +315,7 @@ def test_library_update_source_folder_mutaion(self): self.assertEqual(tuple(tuple(data.values())[0].values())[0].get('sourceFolder'),self.defaults['library'].paths.all()[0].path) def test_change_password_mutation(self): - """Test change password mutaion response.""" + """Test change password mutation response.""" mutation = """ mutation changePassword ( $oldPassword: String!, @@ -831,8 +833,12 @@ def test_onboarding_steps(self): assert User.objects.first().has_configured_importing self.assertFalse(User.objects.first().has_configured_image_analysis) mutation = """ - mutation ($classificationColorEnabled: Boolean!,$classificationStyleEnabled: Boolean!, - $classificationObjectEnabled: Boolean!,$classificationLocationEnabled: Boolean!, + mutation ( + $classificationColorEnabled: Boolean!, + $classificationStyleEnabled: Boolean!, + $classificationObjectEnabled: Boolean!, + $classificationLocationEnabled: Boolean!, + $classificationFaceEnabled: Boolean!, $userId: ID!,$libraryId: ID!, ) { imageAnalysis(input:{ @@ -840,6 +846,7 @@ def test_onboarding_steps(self): classificationStyleEnabled:$classificationStyleEnabled, classificationObjectEnabled:$classificationObjectEnabled, classificationLocationEnabled:$classificationLocationEnabled, + classificationFaceEnabled:$classificationFaceEnabled, userId:$userId, libraryId:$libraryId, }) { @@ -851,8 +858,11 @@ def test_onboarding_steps(self): library_id = data['data']['PhotoImporting']['libraryId'] response = self.api_client.post_graphql( mutation, { - 'classificationColorEnabled': True, 'classificationStyleEnabled': True, - 'classificationObjectEnabled': False, 'classificationLocationEnabled': False, + 'classificationColorEnabled': True, + 'classificationStyleEnabled': True, + 'classificationObjectEnabled': False, + 'classificationLocationEnabled': False, + 'classificationFaceEnabled': False, 'userId': data['data']['PhotoImporting']['userId'], 'libraryId': data['data']['PhotoImporting']['libraryId'], }) diff --git a/tests/test_task_queue.py b/tests/test_task_queue.py index 702dd472..4f900f8f 100644 --- a/tests/test_task_queue.py +++ b/tests/test_task_queue.py @@ -74,7 +74,7 @@ def test_tasks_created_updated(photo_fixture_snow): process_classify_images_tasks() task = Task.objects.get(type='classify_images', subject_id=photo_fixture_snow.id) assert task.status == 'S' - assert task.children.count() == 4 + assert task.children.count() == 6 assert task.complete_with_children == True # Completing all the child processes should set the parent task to completed diff --git a/tests/test_thumbnails.py b/tests/test_thumbnails.py index 7f798d0f..9cc85725 100644 --- a/tests/test_thumbnails.py +++ b/tests/test_thumbnails.py @@ -59,7 +59,7 @@ def test_view(photo_fixture_snow): # Now we should get the actual thumbnail image file assert response.status_code == 200 assert response.content[:10] == b'\xff\xd8\xff\xe0\x00\x10JFIF' - assert response._headers['content-type'][1] == 'image/jpeg' + assert response.headers['Content-Type'] == 'image/jpeg' response_length = len(response.content) assert response_length > 5929 * 0.8 assert response_length < 5929 * 1.2