Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1926: return annots/preds in sample data #7178

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions label_studio/projects/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from django_filters import CharFilter, FilterSet
from django_filters.rest_framework import DjangoFilterBackend
from drf_yasg.utils import swagger_auto_schema
from label_studio_sdk.label_interface.interface import LabelInterface
from ml.serializers import MLBackendSerializer
from projects.functions.next_task import get_next_task
from projects.functions.stream_history import get_label_stream_history
Expand Down Expand Up @@ -832,6 +833,17 @@ def post(self, request, *args, **kwargs):
raise RestValidationError('Label config is not set or is empty')

project = self.get_object()

try:
label_interface = LabelInterface(label_config)
complete_task = label_interface.generate_complete_sample_task(raise_on_failure=True)
return Response({'sample_task': complete_task}, status=200)
except Exception as e:
logger.error(
f'Error generating enhanced sample task, falling back to original method: {str(e)}. Label config: {label_config}'
)

        # Fall back to project.get_sample_task when LabelInterface.generate_complete_sample_task fails
return Response({'sample_task': project.get_sample_task(label_config)}, status=200)


Expand Down
155 changes: 155 additions & 0 deletions label_studio/projects/tests/test_project_sample_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import json
from unittest.mock import patch

import projects.api
import pytest
from django.test import TestCase
from django.urls import reverse
from projects.tests.factories import ProjectFactory
from rest_framework.test import APIClient


@pytest.mark.django_db
class TestProjectSampleTask(TestCase):
    """Tests for the project sample-task endpoint.

    The endpoint first tries ``LabelInterface.generate_complete_sample_task``
    and falls back to ``Project.get_sample_task`` when generation fails.
    """

    # Shared label config: hoisted to a class constant because all three
    # tests previously duplicated the identical literal.
    LABEL_CONFIG = """
        <View>
        <Text name='text' value='$text'/>
        <Choices name='sentiment' toName='text'>
        <Choice value='Positive'/>
        <Choice value='Negative'/>
        <Choice value='Neutral'/>
        </Choices>
        </View>
        """

    @classmethod
    def setUpTestData(cls):
        # One project shared by all tests; created once per class.
        cls.project = ProjectFactory()

    @property
    def url(self):
        return reverse('projects:api:project-sample-task', kwargs={'pk': self.project.id})

    def _post_sample_task(self):
        """POST the shared label config to the endpoint as the project owner."""
        client = APIClient()
        client.force_authenticate(user=self.project.created_by)
        return client.post(
            self.url,
            data=json.dumps({'label_config': self.LABEL_CONFIG}),
            content_type='application/json',
        )

    @staticmethod
    def _assert_sample_task(response, expected):
        """Assert a 200 response whose payload carries the expected sample task."""
        assert response.status_code == 200
        response_data = response.json()
        assert 'sample_task' in response_data
        assert response_data['sample_task'] == expected

    def test_sample_task_with_happy_path(self):
        """Test that ProjectSampleTask.post successfully creates a complete sample task with annotations and predictions"""
        sample_prediction = {
            'model_version': 'sample model version',
            'result': [
                {
                    'id': 'abc123',
                    'from_name': 'sentiment',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['Positive']},
                }
            ],
            'score': 0.95,
        }
        sample_annotation = {
            'was_cancelled': False,
            'ground_truth': False,
            'result': [
                {
                    'id': 'def456',
                    'from_name': 'sentiment',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['Positive']},
                }
            ],
            'completed_by': -1,
        }
        sample_task = {
            'id': 1,
            'data': {'text': 'This is a sample task for labeling.'},
            'predictions': [sample_prediction],
            'annotations': [sample_annotation],
        }

        with patch.object(
            projects.api.LabelInterface,
            'generate_complete_sample_task',
            return_value=sample_task,
        ):
            response = self._post_sample_task()

        self._assert_sample_task(response, sample_task)

    def test_sample_task_fallback_when_generate_task_fails(self):
        """Test fallback to project.get_sample_task when LabelInterface.generate_complete_sample_task fails"""
        fallback_data = {'id': 999, 'data': {'text': 'Fallback task'}}

        with (
            patch.object(
                projects.api.LabelInterface,
                'generate_complete_sample_task',
                side_effect=ValueError('Failed to generate sample task'),
            ),
            patch('projects.api.Project.get_sample_task', return_value=fallback_data),
        ):
            response = self._post_sample_task()

        self._assert_sample_task(response, fallback_data)

    def test_sample_task_fallback_when_prediction_generation_fails(self):
        """Test fallback to project.get_sample_task when LabelInterface.generate_sample_prediction raises an exception"""
        fallback_data = {'id': 999, 'data': {'text': 'Fallback task'}}

        with (
            # Returning None from generate_sample_prediction makes the real
            # generate_complete_sample_task raise (raise_on_failure=True),
            # which should trigger the endpoint's fallback path.
            patch.object(
                projects.api.LabelInterface,
                'generate_sample_prediction',
                return_value=None,
            ),
            patch('projects.api.Project.get_sample_task', return_value=fallback_data),
        ):
            response = self._post_sample_task()

        self._assert_sample_task(response, fallback_data)
Loading
Loading