Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f0f8e1b

Browse files
committedJan 31, 2025·
feat(restapi): add draft commit workflow
This commit adds a new workflow for committing drafts. Committing a draft will create a new resource snapshot and delete the draft in a single transaction. A draft can be either a draft of a new resource or draft modifications to an existing resource. For draft modifications, the commit will be rejected if the snapshot the draft was based off of is not longer the most recent snapshot. In this case an error is raised that contains detailed information about the draft, the base snapshot, and the current snapshot so that the user can reconcile their desired changes. A draft commit can also be rejected if the resource fails to be created for any reason (e.g. name collisions or associations to invalid resources). It also adds the workflow to the Python client and implements new tests for the workflow.
1 parent c8a4c8d commit f0f8e1b

20 files changed

+668
-53
lines changed
 

‎src/dioptra/client/entrypoints.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@
3939
"queues",
4040
"plugins",
4141
}
42+
MODIFY_DRAFT_FIELDS: Final[set[str]] = {
43+
"name",
44+
"description",
45+
"taskGraph",
46+
"parameters",
47+
"queues",
48+
}
4249
FIELD_NAMES_TO_CAMEL_CASE: Final[dict[str, str]] = {
4350
"task_graph": "taskGraph",
4451
}
@@ -291,7 +298,7 @@ def __init__(self, session: DioptraSession[T]) -> None:
291298
self._modify_resource_drafts = ModifyResourceDraftsSubCollectionClient[T](
292299
session=session,
293300
validate_fields_fn=make_draft_fields_validator(
294-
draft_fields=DRAFT_FIELDS,
301+
draft_fields=MODIFY_DRAFT_FIELDS,
295302
resource_name=self.name,
296303
),
297304
root_collection=self,

‎src/dioptra/client/workflows.py

+20
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
T = TypeVar("T")
2323

2424
JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload"
25+
DRAFT_COMMIT: Final[str] = "draftCommit"
2526

2627

2728
class WorkflowsCollectionClient(CollectionClient[T]):
@@ -86,3 +87,22 @@ def download_job_files(
8687
return self._session.download(
8788
self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params
8889
)
90+
91+
def commit_draft(
92+
self,
93+
draft_id: str | int,
94+
) -> T:
95+
"""
96+
Commit a draft as a new resource snapshot.
97+
98+
The draft can be a draft of a new resource or a draft modifications to an
99+
existing resource.
100+
101+
Args:
102+
draft_id: The draft id, an intiger.
103+
104+
Returns:
105+
A dictionary containing the contents of the new resource.
106+
"""
107+
108+
return self._session.post(self.url, DRAFT_COMMIT, str(draft_id))

‎src/dioptra/restapi/errors.py

+43
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from flask_restx import Api
2929
from structlog.stdlib import BoundLogger
3030

31+
from dioptra.restapi.db import models
32+
from dioptra.restapi.v1 import utils
33+
3134
LOGGER: BoundLogger = structlog.stdlib.get_logger()
3235

3336

@@ -206,6 +209,26 @@ def __init__(self, type: str, id: int):
206209
self.resource_id = id
207210

208211

212+
class DraftResourceModificationsCommitError(DioptraError):
213+
"""The draft modifications to a resource could not be committed"""
214+
215+
def __init__(
216+
self,
217+
resource_type: str,
218+
resource_id: int,
219+
draft: models.DraftResource,
220+
base_snapshot: models.ResourceSnapshot,
221+
curr_snapshot: models.ResourceSnapshot,
222+
):
223+
super().__init__(
224+
f"Draft modifications for a [{resource_type}] with id: {resource_id} "
225+
"could not be commited."
226+
)
227+
self.draft = draft
228+
self.base_snapshot = base_snapshot
229+
self.curr_snapshot = curr_snapshot
230+
231+
209232
class InvalidDraftBaseResourceSnapshotError(DioptraError):
210233
"""The draft's base snapshot identifier is invalid."""
211234

@@ -432,6 +455,26 @@ def handle_draft_already_exists(error: DraftAlreadyExistsError):
432455
log.debug(error.to_message())
433456
return error_result(error, http.HTTPStatus.BAD_REQUEST, {})
434457

458+
@api.errorhandler(DraftResourceModificationsCommitError)
459+
def handle_draft_resource_modifications_commit_error(
460+
error: DraftResourceModificationsCommitError,
461+
):
462+
log.debug(error.to_message())
463+
464+
return error_result(
465+
error,
466+
http.HTTPStatus.BAD_REQUEST,
467+
{
468+
"reason": f"The {error.draft.resource_type} has been modified since "
469+
"this draft was created.",
470+
"draft": error.draft.payload["resource_data"],
471+
"base_snapshot_id": error.base_snapshot.resource_snapshot_id,
472+
"curr_snapshot_id": error.curr_snapshot.resource_snapshot_id,
473+
"base_snapshot": utils.build_resource(error.base_snapshot),
474+
"curr_snapshot": utils.build_resource(error.curr_snapshot),
475+
},
476+
)
477+
435478
@api.errorhandler(InvalidDraftBaseResourceSnapshotError)
436479
def handle_invalid_draft_base_resource_snapshot(
437480
error: InvalidDraftBaseResourceSnapshotError,

‎src/dioptra/restapi/v1/entrypoints/controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def delete(self, id: int, queueId):
408408
EntrypointIdDraftResource = generate_resource_id_draft_endpoint(
409409
api,
410410
resource_name=RESOURCE_TYPE,
411-
request_schema=EntrypointDraftSchema(exclude=["groupId"]),
411+
request_schema=EntrypointDraftSchema(exclude=["groupId", "pluginIds"]),
412412
)
413413

414414
EntrypointSnapshotsResource = generate_resource_snapshots_endpoint(

‎src/dioptra/restapi/v1/plugins/controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,11 +317,11 @@ def post(self, id: int):
317317
parsed_obj = request.parsed_obj # type: ignore # noqa: F841
318318

319319
plugin_file = self._plugin_id_file_service.create(
320+
plugin_id=id,
320321
filename=parsed_obj["filename"],
321322
contents=parsed_obj["contents"],
322323
description=parsed_obj["description"],
323324
tasks=parsed_obj["tasks"],
324-
plugin_id=id,
325325
log=log,
326326
)
327327
return utils.build_plugin_file(plugin_file)

‎src/dioptra/restapi/v1/plugins/service.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -709,11 +709,11 @@ def __init__(
709709

710710
def create(
711711
self,
712+
plugin_id: int,
712713
filename: str,
713714
contents: str,
714715
description: str,
715716
tasks: list[dict[str, Any]],
716-
plugin_id: int,
717717
commit: bool = True,
718718
**kwargs,
719719
) -> utils.PluginFileDict:
@@ -723,11 +723,11 @@ def create(
723723
with the PluginFile. The creator will be the current user.
724724
725725
Args:
726+
plugin_id: The unique id of the plugin containing the plugin file.
726727
filename: The name of the plugin file.
727728
contents: The contents of the plugin file.
728729
description: The description of the plugin file.
729730
tasks: The tasks associated with the plugin file.
730-
plugin_id: The unique id of the plugin containing the plugin file.
731731
commit: If True, commit the transaction. Defaults to True.
732732
733733
Returns:

‎src/dioptra/restapi/v1/shared/drafts/service.py

-9
Original file line numberDiff line numberDiff line change
@@ -467,15 +467,6 @@ def modify(
467467
if draft is None:
468468
return None, num_other_drafts
469469

470-
# NOTE: This check disables the ability to change the base snapshot ID.
471-
# It is scheduled to be removed as part of the draft commit workflow feature.
472-
if draft.payload["resource_snapshot_id"] != payload["resource_snapshot_id"]:
473-
raise InvalidDraftBaseResourceSnapshotError(
474-
"The provided resource snapshot must match the base resource snapshot",
475-
base_resource_snapshot_id=draft.payload["resource_snapshot_id"],
476-
provided_resource_snapshot_id=payload["resource_snapshot_id"],
477-
)
478-
479470
if draft.payload["resource_snapshot_id"] > payload["resource_snapshot_id"]:
480471
raise InvalidDraftBaseResourceSnapshotError(
481472
"The provided resource snapshot must be greater than or equal to "
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# This Software (Dioptra) is being made available as a public service by the
2+
# National Institute of Standards and Technology (NIST), an Agency of the United
3+
# States Department of Commerce. This software was developed in part by employees of
4+
# NIST and in part by NIST contractors. Copyright in portions of this software that
5+
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
6+
# to Title 17 United States Code Section 105, works of NIST employees are not
7+
# subject to copyright protection in the United States. However, NIST may hold
8+
# international copyright in software created by its employees and domestic
9+
# copyright (or licensing rights) in portions of software that were assigned or
10+
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
11+
# being made available under the Creative Commons Attribution 4.0 International
12+
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
13+
# of the software developed or licensed by NIST.
14+
#
15+
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
16+
# https://creativecommons.org/licenses/by/4.0/legalcode
17+
"""The server-side functions that perform resource operations."""
18+
from __future__ import annotations
19+
20+
from typing import Any
21+
22+
import structlog
23+
from injector import inject
24+
from structlog.stdlib import BoundLogger
25+
26+
from dioptra.restapi.errors import DioptraError
27+
from dioptra.restapi.v1.artifacts.service import ArtifactIdService, ArtifactService
28+
from dioptra.restapi.v1.entrypoints.service import (
29+
EntrypointIdService,
30+
EntrypointService,
31+
)
32+
from dioptra.restapi.v1.experiments.service import (
33+
ExperimentIdService,
34+
ExperimentService,
35+
)
36+
from dioptra.restapi.v1.models.service import ModelIdService, ModelService
37+
from dioptra.restapi.v1.plugin_parameter_types.service import (
38+
PluginParameterTypeIdService,
39+
PluginParameterTypeService,
40+
)
41+
from dioptra.restapi.v1.plugins.service import (
42+
PluginIdFileIdService,
43+
PluginIdFileService,
44+
PluginIdService,
45+
PluginService,
46+
)
47+
from dioptra.restapi.v1.queues.service import QueueIdService, QueueService
48+
49+
LOGGER: BoundLogger = structlog.stdlib.get_logger()
50+
51+
52+
class ResourceService(object):
53+
"""The service methods for creating resources."""
54+
55+
@inject
56+
def __init__(
57+
self,
58+
artifact_service: ArtifactService,
59+
entrypoint_service: EntrypointService,
60+
experiment_service: ExperimentService,
61+
model_service: ModelService,
62+
plugin_service: PluginService,
63+
plugin_id_file_service: PluginIdFileService,
64+
plugin_parameter_type_service: PluginParameterTypeService,
65+
queue_service: QueueService,
66+
) -> None:
67+
"""Initialize the queue service.
68+
69+
All arguments are provided via dependency injection.
70+
71+
Args:
72+
queue_name_service: A QueueNameService object.
73+
"""
74+
75+
self._services: dict[str, object] = {
76+
"artifact": artifact_service,
77+
"entry_point": entrypoint_service,
78+
"experiment": experiment_service,
79+
"ml_model": model_service,
80+
"plugin": plugin_service,
81+
"plugin_file": plugin_id_file_service,
82+
"plugin_task_parameter_type": plugin_parameter_type_service,
83+
"queue": queue_service,
84+
}
85+
86+
def create(
87+
self,
88+
*resource_ids: int,
89+
resource_type: str,
90+
resource_data: dict,
91+
group_id: int,
92+
commit: bool = True,
93+
**kwargs,
94+
) -> Any:
95+
"""Create a new queue.
96+
97+
Args:
98+
resource_type: The type of resource to create, e.g. queue
99+
resource_data: Arguments passed to create services as kwargs
100+
group_id: The group that will own the queue.
101+
commit: If True, commit the transaction. Defaults to True.
102+
103+
Returns:
104+
The newly created queue object.
105+
106+
Raises:
107+
EntityExistsError: If a queue with the given name already exists.
108+
EntityDoesNotExistError: If the group with the provided ID does not exist.
109+
"""
110+
log: BoundLogger = kwargs.get("log", LOGGER.new())
111+
112+
if resource_type not in self._services:
113+
raise DioptraError(f"Invalid resource type: {resource_type}")
114+
115+
return self._services[resource_type].create( # type: ignore
116+
*resource_ids, group_id=group_id, **resource_data, commit=commit, log=log
117+
)
118+
119+
120+
class ResourceIdService(object):
121+
"""The service methods for creating resources."""
122+
123+
@inject
124+
def __init__(
125+
self,
126+
artifact_id_service: ArtifactIdService,
127+
entrypoint_id_service: EntrypointIdService,
128+
experiment_id_service: ExperimentIdService,
129+
model_id_service: ModelIdService,
130+
plugin_id_service: PluginIdService,
131+
plugin_id_file_id_service: PluginIdFileIdService,
132+
plugin_parameter_type_id_service: PluginParameterTypeIdService,
133+
queue_id_service: QueueIdService,
134+
) -> None:
135+
"""Initialize the resource id service.
136+
137+
All arguments are provided via dependency injection.
138+
139+
Args:
140+
queue_name_service: A QueueNameService object.
141+
"""
142+
143+
self._services = {
144+
"artifact": artifact_id_service,
145+
"entry_point": entrypoint_id_service,
146+
"experiment": experiment_id_service,
147+
"ml_model": model_id_service,
148+
"plugin": plugin_id_service,
149+
"plugin_file": plugin_id_file_id_service,
150+
"plugin_task_parameter_type": plugin_parameter_type_id_service,
151+
"queue": queue_id_service,
152+
}
153+
154+
def modify(
155+
self,
156+
*resource_ids: int,
157+
resource_type: str,
158+
resource_data: dict,
159+
group_id: int,
160+
commit: bool = True,
161+
**kwargs,
162+
) -> Any:
163+
"""Create a new queue.
164+
165+
Args:
166+
resource_type: The type of resource to create, e.g. queue
167+
resource_data: Arguments passed to create services as kwargs
168+
group_id: The group that will own the queue.
169+
commit: If True, commit the transaction. Defaults to True.
170+
171+
Returns:
172+
The newly created queue object.
173+
174+
Raises:
175+
EntityExistsError: If a queue with the given name already exists.
176+
EntityDoesNotExistError: If the group with the provided ID does not exist.
177+
"""
178+
log: BoundLogger = kwargs.get("log", LOGGER.new())
179+
180+
if resource_type not in self._services:
181+
raise DioptraError(f"Invalid resource type: {resource_type}")
182+
183+
return self._services[resource_type].modify( # type: ignore
184+
*resource_ids, group_id=group_id, **resource_data, commit=commit, log=log
185+
)

‎src/dioptra/restapi/v1/utils.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
from dioptra.restapi.db import models
2424
from dioptra.restapi.routes import V1_ROOT
2525

26+
from dioptra.restapi.v1.artifacts.schema import ArtifactSchema
27+
from dioptra.restapi.v1.entrypoints.schema import EntrypointSchema
28+
from dioptra.restapi.v1.experiments.schema import ExperimentSchema
29+
from dioptra.restapi.v1.models.schema import ModelSchema
30+
from dioptra.restapi.v1.plugin_parameter_types.schema import PluginParameterTypeSchema
31+
from dioptra.restapi.v1.plugins.schema import PluginFileSchema, PluginSchema
32+
from dioptra.restapi.v1.queues.schema import QueueSchema
33+
2634
ARTIFACTS: Final[str] = "artifacts"
2735
ENTRYPOINTS: Final[str] = "entrypoints"
2836
EXPERIMENTS: Final[str] = "experiments"
@@ -523,6 +531,40 @@ def build_group(group: models.Group) -> dict[str, Any]:
523531
}
524532

525533

534+
def build_resource(resource_snapshot: models.ResourceSnapshot) -> dict[str, Any]:
535+
"""Build a Resource response dictionary.
536+
Args:
537+
resource_snapshot: The resource snapshot ORM object to convert into a Resource
538+
response dictionary.
539+
Returns:
540+
The Resource response dictionary.
541+
"""
542+
543+
build_fn = {
544+
"artifact": build_artifact,
545+
"entry_point": build_entrypoint,
546+
"experiment": build_experiment,
547+
"ml_model": build_model,
548+
"plugin": build_plugin,
549+
"plugin_file": build_plugin_file,
550+
"plugin_task_parameter_type": build_plugin_parameter_type,
551+
"queue": build_queue,
552+
}.get(resource_snapshot.resource_type)
553+
554+
schema = {
555+
"artifact": ArtifactSchema(),
556+
"entry_point": EntrypointSchema(),
557+
"experiment": ExperimentSchema(),
558+
"ml_model": ModelSchema(),
559+
"plugin": PluginSchema(),
560+
"plugin_file": PluginFileSchema(),
561+
"plugin_task_parameter_type": PluginParameterTypeSchema(),
562+
"queue": QueueSchema(),
563+
}.get(resource_snapshot.resource_type)
564+
565+
return schema.dump(build_fn({resource_snapshot.resource_type: resource_snapshot})) # type: ignore
566+
567+
526568
def build_experiment(experiment_dict: ExperimentDict) -> dict[str, Any]:
527569
"""Build an Experiment response dictionary.
528570
@@ -939,7 +981,7 @@ def build_plugin_parameter_type(
939981
The Plugin Parameter Type response dictionary.
940982
"""
941983
plugin_parameter_type = plugin_parameter_type_dict["plugin_task_parameter_type"]
942-
has_draft = plugin_parameter_type_dict["has_draft"]
984+
has_draft = plugin_parameter_type_dict.get("has_draft", None)
943985

944986
data = {
945987
"id": plugin_parameter_type.resource_id,

‎src/dioptra/restapi/v1/workflows/controller.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919

2020
import structlog
2121
from flask import request, send_file
22-
from flask_accepts import accepts
22+
from flask_accepts import accepts, responds
2323
from flask_login import login_required
2424
from flask_restx import Namespace, Resource
2525
from injector import inject
2626
from structlog.stdlib import BoundLogger
2727

28+
from dioptra.restapi.v1.schemas import IdStatusResponseSchema
29+
2830
from .schema import FileTypes, JobFilesDownloadQueryParametersSchema
29-
from .service import JobFilesDownloadService
31+
from .service import DraftCommitService, JobFilesDownloadService
3032

3133
LOGGER: BoundLogger = structlog.stdlib.get_logger()
3234

@@ -78,3 +80,30 @@ def get(self):
7880
mimetype=mimetype[parsed_query_params["file_type"]],
7981
download_name=download_name[parsed_query_params["file_type"]],
8082
)
83+
84+
85+
@api.route("/draftCommit/<int:id>")
86+
@api.param("id", "ID for the Draft resource.")
87+
class DraftCommitEndpoint(Resource):
88+
@inject
89+
def __init__(
90+
self, draft_commit_service: DraftCommitService, *args, **kwargs
91+
) -> None:
92+
"""Initialize the workflow resource.
93+
94+
All arguments are provided via dependency injection.
95+
96+
Args:
97+
draft_commit_service: A DraftCommitService object.
98+
"""
99+
self._draft_commit_service = draft_commit_service
100+
super().__init__(*args, **kwargs)
101+
102+
@login_required
103+
@responds(schema=IdStatusResponseSchema, api=api)
104+
def post(self, id: int):
105+
"""Commit a draft as a new resource""" # noqa: B950
106+
log = LOGGER.new(
107+
request_id=str(uuid.uuid4()), resource="DraftCommit", request_type="POST"
108+
) # noqa: F841
109+
return self._draft_commit_service.commit_draft(draft_id=id, log=log)

‎src/dioptra/restapi/v1/workflows/lib/views.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from structlog.stdlib import BoundLogger
2020

2121
from dioptra.restapi.db import db, models
22-
from dioptra.restapi.errors import EntityDoesNotExistError
22+
from dioptra.restapi.errors import DioptraError, EntityDoesNotExistError
2323
from dioptra.restapi.v1.entrypoints.service import (
2424
RESOURCE_TYPE as ENTRYPONT_RESOURCE_TYPE,
2525
)
@@ -163,3 +163,91 @@ def get_plugin_parameter_types(
163163
)
164164
)
165165
return list(db.session.scalars(plugin_parameter_types_stmt).all())
166+
167+
168+
def get_resource(
169+
resource_id: int, logger: BoundLogger | None = None
170+
) -> models.Resource | None:
171+
"""Run a query to get a resource
172+
173+
Args:
174+
resource_id: The identifier of the Resource
175+
logger: A structlog logger object to use for logging. A new logger will be
176+
created if None.
177+
178+
Returns:
179+
The retrieved DraftResource ORM object
180+
"""
181+
log = logger or LOGGER.new() # noqa: F841
182+
183+
resource_stmt = select(models.Resource).where(
184+
models.Resource.resource_id == resource_id
185+
)
186+
return db.session.scalar(resource_stmt)
187+
188+
189+
def get_resource_snapshot(
190+
resource_type: str, snapshot_id: int, logger: BoundLogger | None = None
191+
) -> models.ResourceSnapshot:
192+
"""Run a query to get the latest snapshot for a resource
193+
194+
Args:
195+
resource_type: The type of Resource.
196+
snapshot_id: The ID of snapshot to retrieve
197+
logger: A structlog logger object to use for logging. A new logger will be
198+
created if None.
199+
200+
Returns:
201+
The resource snapshot
202+
"""
203+
log = logger or LOGGER.new() # noqa: F841
204+
205+
snapshot_model = {
206+
"artifact": models.Artifact,
207+
"entry_point": models.EntryPoint,
208+
"experiment": models.Experiment,
209+
"ml_model": models.MlModel,
210+
"plugin": models.Plugin,
211+
"plugin_file": models.PluginFile,
212+
"plugin_task_parameter_type": models.PluginTaskParameterType,
213+
"queue": models.Queue,
214+
}.get(resource_type, None)
215+
216+
if snapshot_model is None:
217+
raise DioptraError(f"Invalid resource type: {resource_type}")
218+
219+
snapshot_stmt = (
220+
select(snapshot_model)
221+
.join(models.Resource)
222+
.where(
223+
snapshot_model.resource_snapshot_id == snapshot_id,
224+
models.Resource.is_deleted == False, # noqa: E712
225+
)
226+
)
227+
snapshot = db.session.scalar(snapshot_stmt)
228+
229+
if snapshot is None:
230+
raise EntityDoesNotExistError(resource_type, snapshot_id=snapshot_id)
231+
232+
return snapshot
233+
234+
235+
def get_draft_resource(
236+
draft_id: int, logger: BoundLogger | None = None
237+
) -> models.DraftResource | None:
238+
"""Run a query to get the draft of a resource
239+
240+
Args:
241+
draft_id: The identifier of the DraftResource
242+
logger: A structlog logger object to use for logging. A new logger will be
243+
created if None.
244+
245+
Returns:
246+
The retrieved DraftResource ORM object
247+
"""
248+
log = logger or LOGGER.new() # noqa: F841
249+
250+
draft_stmt = select(models.DraftResource).where(
251+
models.DraftResource.draft_resource_id == draft_id
252+
)
253+
return db.session.scalar(draft_stmt)

‎src/dioptra/restapi/v1/workflows/service.py

+122-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,23 @@
1515
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
1616
# https://creativecommons.org/licenses/by/4.0/legalcode
1717
"""The server-side functions that perform workflows endpoint operations."""
18-
from typing import IO, Final
18+
from typing import IO, Final, cast
1919

2020
import structlog
21+
from injector import inject
2122
from structlog.stdlib import BoundLogger
2223

24+
from dioptra.restapi.db import db
25+
from dioptra.restapi.errors import (
26+
DraftDoesNotExistError,
27+
DraftResourceModificationsCommitError,
28+
EntityDoesNotExistError,
29+
)
30+
from dioptra.restapi.v1.shared.resource_service import (
31+
ResourceIdService,
32+
ResourceService,
33+
)
34+
2335
from .lib import views
2436
from .lib.package_job_files import package_job_files
2537
from .schema import FileTypes
@@ -65,3 +77,112 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]:
6577
file_type=file_type,
6678
logger=log,
6779
)
80+
81+
82+
class DraftCommitService(object):
83+
"""The service methods for commiting a Draft as a new ResourceSnapshot."""
84+
85+
@inject
86+
def __init__(
87+
self,
88+
resource_service: ResourceService,
89+
resource_id_service: ResourceIdService,
90+
) -> None:
91+
"""Initialize the queue service.
92+
93+
All arguments are provided via dependency injection.
94+
95+
Args:
96+
resource_service: A ResourceService object.
97+
resource_id_service: A ResourceIdService object.
98+
"""
99+
self._resource_service = resource_service
100+
self._resource_id_service = resource_id_service
101+
102+
def commit_draft(self, draft_id: int, **kwargs) -> dict:
103+
"""Commit the Draft as a new ResourceSnapshot
104+
105+
Args:
106+
draft_id: The identifier of the draft.
107+
108+
Returns:
109+
The packaged job files returned as a named temporary file.
110+
"""
111+
log: BoundLogger = kwargs.get("log", LOGGER.new())
112+
log.debug("commit draft", draft_id=draft_id)
113+
114+
draft = views.get_draft_resource(draft_id, logger=log)
115+
if draft is None:
116+
raise DraftDoesNotExistError(draft_resource_id=draft_id)
117+
118+
if draft.payload["resource_id"] is None:
119+
resource_ids = (
120+
[draft.payload["base_resource_id"]]
121+
if draft.payload["base_resource_id"] is not None
122+
else []
123+
)
124+
125+
resource_dict = self._resource_service.create(
126+
*resource_ids,
127+
resource_type=draft.resource_type,
128+
resource_data=draft.payload["resource_data"],
129+
group_id=draft.group_id,
130+
commit=False,
131+
log=log,
132+
)
133+
else: # the draft contains modifications to an existing resource
134+
resource = views.get_resource(draft.payload["resource_id"])
135+
if resource is None:
136+
raise EntityDoesNotExistError(
137+
draft.resource_type, resource_id=draft.payload["resource_id"]
138+
)
139+
140+
# if the underlying resource was modified since the draft was created,
141+
# raise an error with the information necessary to reconcile the draft.
142+
if draft.payload["resource_snapshot_id"] != resource.latest_snapshot_id:
143+
base_snapshot = views.get_resource_snapshot(
144+
draft.resource_type, draft.payload["resource_snapshot_id"]
145+
)
146+
if base_snapshot is None:
147+
raise EntityDoesNotExistError(
148+
draft.resource_type,
149+
snapshot_id=draft.payload["resource_snapshot_id"],
150+
)
151+
152+
curr_snapshot = views.get_resource_snapshot(
153+
draft.resource_type, cast(int, resource.latest_snapshot_id)
154+
)
155+
if curr_snapshot is None:
156+
raise EntityDoesNotExistError(
157+
draft.resource_type, resource_id=draft.payload["resource_id"]
158+
)
159+
160+
raise DraftResourceModificationsCommitError(
161+
resource_type=draft.resource_type,
162+
resource_id=draft.payload["resource_id"],
163+
draft=draft,
164+
base_snapshot=base_snapshot,
165+
curr_snapshot=curr_snapshot,
166+
)
167+
168+
resource_ids = [draft.payload["resource_id"]]
169+
if draft.payload["base_resource_id"] is not None:
170+
resource_ids = [draft.payload["base_resource_id"]] + resource_ids
171+
172+
resource_dict = self._resource_id_service.modify(
173+
*resource_ids,
174+
resource_type=draft.resource_type,
175+
resource_data=draft.payload["resource_data"],
176+
group_id=draft.group_id,
177+
commit=False,
178+
log=log,
179+
)
180+
181+
db.session.delete(draft)
182+
183+
db.session.commit()
184+
185+
resource_ids = [resource_dict[draft.resource_type].resource_id]
186+
if draft.payload["base_resource_id"] is not None:
187+
resource_ids = [draft.payload["base_resource_id"]] + resource_ids
188+
return {"status": "Success", "id": resource_ids}

‎tests/unit/restapi/lib/asserts.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def assert_base_resource_contents_match_expectations(response: dict[str, Any]) -
4141
assert helpers.is_iso_format(response["lastModifiedOn"])
4242

4343

44+
def assert_resource_contents_match_expectations(
45+
response: dict[str, Any], expected: dict[str, Any]
46+
) -> None:
47+
for key, value in expected.items():
48+
assert key in response
49+
if isinstance(value, (int, float, bool, str)):
50+
assert response[key] == value
51+
52+
4453
def assert_user_ref_contents_matches_expectations(
4554
user: dict[str, Any], expected_user_id: int
4655
) -> None:
@@ -284,9 +293,7 @@ def assert_creating_another_existing_draft_fails(
284293
Raises:
285294
AssertionError: If the response status code is not 400.
286295
"""
287-
response = drafts_client.create(
288-
*resource_ids, **payload
289-
)
296+
response = drafts_client.create(*resource_ids, **payload)
290297
assert response.status_code == HTTPStatus.BAD_REQUEST
291298

292299

‎tests/unit/restapi/lib/routines.py

+48-17
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,22 @@
1616
# https://creativecommons.org/licenses/by/4.0/legalcode
1717
from typing import Any
1818

19-
from dioptra.client.base import DioptraResponseProtocol
19+
from dioptra.client.base import CollectionClient, DioptraResponseProtocol
2020
from dioptra.client.drafts import (
2121
ModifyResourceDraftsSubCollectionClient,
2222
NewResourceDraftsSubCollectionClient,
2323
)
2424
from dioptra.client.snapshots import SnapshotsSubCollectionClient
2525
from dioptra.client.tags import TagsSubCollectionClient
26+
from dioptra.client.workflows import WorkflowsCollectionClient
2627

2728
from . import asserts
2829

2930

3031
def run_new_resource_drafts_tests(
31-
client: NewResourceDraftsSubCollectionClient[DioptraResponseProtocol],
32+
resource_client: CollectionClient[DioptraResponseProtocol],
33+
draft_client: NewResourceDraftsSubCollectionClient[DioptraResponseProtocol],
34+
workflow_client: WorkflowsCollectionClient[DioptraResponseProtocol],
3235
*resource_ids: str | int,
3336
drafts: dict[str, Any],
3437
draft1_mod: dict[str, Any],
@@ -38,83 +41,111 @@ def run_new_resource_drafts_tests(
3841
group_id: int | None = None,
3942
) -> None:
4043
# Creation operation tests
41-
draft1_response = client.create(
44+
draft1_response = draft_client.create(
4245
*resource_ids, group_id=group_id, **drafts["draft1"]
4346
).json()
4447
asserts.assert_draft_response_contents_matches_expectations(
4548
draft1_response, draft1_expected
4649
)
4750
asserts.assert_retrieving_draft_by_id_works(
48-
client,
51+
draft_client,
4952
*resource_ids,
5053
draft_id=draft1_response["id"],
5154
expected=draft1_response,
5255
)
5356

54-
draft2_response = client.create(
57+
draft2_response = draft_client.create(
5558
*resource_ids, group_id=group_id, **drafts["draft2"]
5659
).json()
5760
asserts.assert_draft_response_contents_matches_expectations(
5861
draft2_response, draft2_expected
5962
)
6063
asserts.assert_retrieving_draft_by_id_works(
61-
client,
64+
draft_client,
6265
*resource_ids,
6366
draft_id=draft2_response["id"],
6467
expected=draft2_response,
6568
)
6669
asserts.assert_retrieving_drafts_works(
67-
client,
70+
draft_client,
6871
*resource_ids,
6972
expected=[draft1_response, draft2_response],
7073
)
7174

7275
# Modify operation tests
73-
response = client.modify(
76+
response = draft_client.modify(
7477
*resource_ids, draft_id=draft1_response["id"], **draft1_mod
7578
).json()
7679
asserts.assert_draft_response_contents_matches_expectations(
7780
response, draft1_mod_expected
7881
)
7982

8083
# Delete operation tests
81-
client.delete(*resource_ids, draft_id=draft1_response["id"])
84+
draft_client.delete(*resource_ids, draft_id=draft1_response["id"])
85+
asserts.assert_new_draft_is_not_found(
86+
draft_client, *resource_ids, draft_id=draft1_response["id"]
87+
)
88+
89+
# Commit operation tests
90+
commit_response = workflow_client.commit_draft(draft2_response["id"]).json()
8291
asserts.assert_new_draft_is_not_found(
83-
client, *resource_ids, draft_id=draft1_response["id"]
92+
draft_client, *resource_ids, draft_id=draft2_response["id"]
93+
)
94+
resource_response = resource_client.get_by_id(*commit_response["id"]).json()
95+
asserts.assert_resource_contents_match_expectations(
96+
resource_response, draft2_expected["payload"]
8497
)
8598

8699

87100
def run_existing_resource_drafts_tests(
88-
client: ModifyResourceDraftsSubCollectionClient[DioptraResponseProtocol],
101+
resource_client: CollectionClient[DioptraResponseProtocol],
102+
draft_client: ModifyResourceDraftsSubCollectionClient[DioptraResponseProtocol],
103+
workflow_client: WorkflowsCollectionClient[DioptraResponseProtocol],
89104
*resource_ids: str | int,
90105
draft: dict[str, Any],
91106
draft_mod: dict[str, Any],
92107
draft_expected: dict[str, Any],
93108
draft_mod_expected: dict[str, Any],
94109
) -> None:
95110
# Creation operation tests
96-
response = client.create(*resource_ids, **draft).json()
111+
response = draft_client.create(*resource_ids, **draft).json()
97112
asserts.assert_draft_response_contents_matches_expectations(
98113
response, draft_expected
99114
)
100115
asserts.assert_retrieving_draft_by_resource_id_works(
101-
client, *resource_ids, expected=response
116+
draft_client, *resource_ids, expected=response
102117
)
103118
asserts.assert_creating_another_existing_draft_fails(
104-
client, *resource_ids, payload=draft
119+
draft_client, *resource_ids, payload=draft
105120
)
106121

107122
# Modify operation tests
108-
response = client.modify(
123+
response = draft_client.modify(
109124
*resource_ids, resource_snapshot_id=response["resourceSnapshot"], **draft_mod
110125
).json()
111126
asserts.assert_draft_response_contents_matches_expectations(
112127
response, draft_mod_expected
113128
)
114129

115130
# Delete operation tests
116-
client.delete(*resource_ids)
117-
asserts.assert_existing_draft_is_not_found(client, *resource_ids)
131+
draft_client.delete(*resource_ids)
132+
asserts.assert_existing_draft_is_not_found(draft_client, *resource_ids)
133+
134+
# Commit operation tests
135+
draft_response = draft_client.create(*resource_ids, **draft).json()
136+
resource_response = resource_client.modify_by_id(*resource_ids, **draft_mod).json()
137+
commit_response = workflow_client.commit_draft(draft_response["id"]).json()
138+
draft_response = draft_client.modify(
139+
*resource_ids,
140+
resource_snapshot_id=commit_response["detail"]["curr_snapshot_id"],
141+
**draft,
142+
).json()
143+
commit_response = workflow_client.commit_draft(draft_response["id"]).json()
144+
asserts.assert_existing_draft_is_not_found(draft_client, *resource_ids)
145+
resource_response = resource_client.get_by_id(*commit_response["id"]).json()
146+
asserts.assert_resource_contents_match_expectations(
147+
resource_response, draft_expected["payload"]
148+
)
118149

119150

120151
def run_resource_snapshots_tests(

‎tests/unit/restapi/v1/test_entrypoint.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,6 @@ def test_manage_existing_entrypoint_draft(
842842
"parameterType": "string",
843843
}
844844
]
845-
plugin_ids = [plugin["id"] for plugin in entrypoint["plugins"]]
846845
queue_ids = [queue["id"] for queue in entrypoint["queues"]]
847846

848847
# test creation
@@ -851,15 +850,13 @@ def test_manage_existing_entrypoint_draft(
851850
"description": description,
852851
"task_graph": task_graph,
853852
"parameters": parameters,
854-
"plugins": plugin_ids,
855853
"queues": queue_ids,
856854
}
857855
draft_mod = {
858856
"name": new_name,
859857
"description": description,
860858
"task_graph": task_graph,
861859
"parameters": parameters,
862-
"plugins": plugin_ids,
863860
"queues": queue_ids,
864861
}
865862

@@ -875,7 +872,6 @@ def test_manage_existing_entrypoint_draft(
875872
"description": description,
876873
"taskGraph": task_graph,
877874
"parameters": parameters,
878-
"plugins": plugin_ids,
879875
"queues": queue_ids,
880876
},
881877
}
@@ -890,14 +886,15 @@ def test_manage_existing_entrypoint_draft(
890886
"description": description,
891887
"taskGraph": task_graph,
892888
"parameters": parameters,
893-
"plugins": plugin_ids,
894889
"queues": queue_ids,
895890
},
896891
}
897892

898893
# Run routine: existing resource drafts tests
899894
routines.run_existing_resource_drafts_tests(
895+
dioptra_client.entrypoints,
900896
dioptra_client.entrypoints.modify_resource_drafts,
897+
dioptra_client.workflows,
901898
entrypoint["id"],
902899
draft=draft,
903900
draft_mod=draft_mod,
@@ -910,6 +907,8 @@ def test_manage_new_entrypoint_drafts(
910907
dioptra_client: DioptraClient[DioptraResponseProtocol],
911908
db: SQLAlchemy,
912909
auth_account: dict[str, Any],
910+
registered_queues: dict[str, Any],
911+
registered_plugins: dict[str, Any],
913912
) -> None:
914913
"""Test that drafts of entrypoint can be created and managed by the user
915914
@@ -939,8 +938,11 @@ def test_manage_new_entrypoint_drafts(
939938
"description": "entrypoint",
940939
"task_graph": "graph",
941940
"parameters": [],
942-
"queues": [1, 3],
943-
"plugins": [2],
941+
"queues": [
942+
registered_queues["queue1"]["id"],
943+
registered_queues["queue3"]["id"],
944+
],
945+
"plugins": [registered_plugins["plugin2"]["id"]],
944946
},
945947
}
946948
draft1_mod = {
@@ -973,8 +975,11 @@ def test_manage_new_entrypoint_drafts(
973975
"description": "entrypoint",
974976
"taskGraph": "graph",
975977
"parameters": [],
976-
"queues": [1, 3],
977-
"plugins": [2],
978+
"queues": [
979+
registered_queues["queue1"]["id"],
980+
registered_queues["queue3"]["id"],
981+
],
982+
"plugins": [registered_plugins["plugin2"]["id"]],
978983
},
979984
}
980985
draft1_mod_expected = {
@@ -992,7 +997,9 @@ def test_manage_new_entrypoint_drafts(
992997

993998
# Run routine: existing resource drafts tests
994999
routines.run_new_resource_drafts_tests(
1000+
dioptra_client.entrypoints,
9951001
dioptra_client.entrypoints.new_resource_drafts,
1002+
dioptra_client.workflows,
9961003
drafts=drafts,
9971004
draft1_mod=draft1_mod,
9981005
draft1_expected=draft1_expected,

‎tests/unit/restapi/v1/test_experiment.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,7 @@ def test_manage_existing_experiment_draft(
816816
db: SQLAlchemy,
817817
auth_account: dict[str, Any],
818818
registered_experiments: dict[str, Any],
819+
registered_entrypoints: dict[str, Any],
819820
) -> None:
820821
"""Test that a draft of an existing experiment can be created and managed by the
821822
user
@@ -840,8 +841,19 @@ def test_manage_existing_experiment_draft(
840841
description = "description"
841842

842843
# test creation
843-
draft = {"name": name, "description": description, "entrypoints": [1]}
844-
draft_mod = {"name": new_name, "description": description, "entrypoints": [3, 2]}
844+
draft = {
845+
"name": name,
846+
"description": description,
847+
"entrypoints": [registered_entrypoints["entrypoint1"]["id"]],
848+
}
849+
draft_mod = {
850+
"name": new_name,
851+
"description": description,
852+
"entrypoints": [
853+
registered_entrypoints["entrypoint3"]["id"],
854+
registered_entrypoints["entrypoint2"]["id"],
855+
],
856+
}
845857

846858
# Expected responses
847859
draft_expected = {
@@ -863,7 +875,9 @@ def test_manage_existing_experiment_draft(
863875

864876
# Run routine: existing resource drafts tests
865877
routines.run_existing_resource_drafts_tests(
878+
dioptra_client.experiments,
866879
dioptra_client.experiments.modify_resource_drafts,
880+
dioptra_client.workflows,
867881
experiment["id"],
868882
draft=draft,
869883
draft_mod=draft_mod,
@@ -876,6 +890,7 @@ def test_manage_new_experiment_drafts(
876890
dioptra_client: DioptraClient[DioptraResponseProtocol],
877891
db: SQLAlchemy,
878892
auth_account: dict[str, Any],
893+
registered_entrypoints: dict[str, Any],
879894
) -> None:
880895
"""Test that drafts of experiment can be created and managed by the user
881896
@@ -895,9 +910,16 @@ def test_manage_new_experiment_drafts(
895910
"draft1": {
896911
"name": "experiment1",
897912
"description": "my experiment",
898-
"entrypoints": [3],
913+
"entrypoints": [registered_entrypoints["entrypoint3"]["id"]],
914+
},
915+
"draft2": {
916+
"name": "experiment2",
917+
"description": None,
918+
"entrypoints": [
919+
registered_entrypoints["entrypoint2"]["id"],
920+
registered_entrypoints["entrypoint1"]["id"],
921+
],
899922
},
900-
"draft2": {"name": "experiment2", "description": None, "entrypoints": [3, 4]},
901923
}
902924
draft1_mod = {"name": "draft1", "description": "new description", "entrypoints": []}
903925

@@ -920,7 +942,9 @@ def test_manage_new_experiment_drafts(
920942

921943
# Run routine: existing resource drafts tests
922944
routines.run_new_resource_drafts_tests(
945+
dioptra_client.experiments,
923946
dioptra_client.experiments.new_resource_drafts,
947+
dioptra_client.workflows,
924948
drafts=drafts,
925949
draft1_mod=draft1_mod,
926950
draft1_expected=draft1_expected,

‎tests/unit/restapi/v1/test_model.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,9 @@ def test_manage_existing_model_draft(
679679

680680
# Run routine: existing resource drafts tests
681681
routines.run_existing_resource_drafts_tests(
682+
dioptra_client.models,
682683
dioptra_client.models.modify_resource_drafts,
684+
dioptra_client.workflows,
683685
model["id"],
684686
draft=draft,
685687
draft_mod=draft_mod,
@@ -732,7 +734,9 @@ def test_manage_new_model_drafts(
732734

733735
# Run routine: existing resource drafts tests
734736
routines.run_new_resource_drafts_tests(
735-
dioptra_client.queues.new_resource_drafts,
737+
dioptra_client.models,
738+
dioptra_client.models.new_resource_drafts,
739+
dioptra_client.workflows,
736740
drafts=drafts,
737741
draft1_mod=draft1_mod,
738742
draft1_expected=draft1_expected,

‎tests/unit/restapi/v1/test_plugin.py

+8
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,9 @@ def test_manage_existing_plugin_draft(
13261326

13271327
# Run routine: existing resource drafts tests
13281328
routines.run_existing_resource_drafts_tests(
1329+
dioptra_client.plugins,
13291330
dioptra_client.plugins.modify_resource_drafts,
1331+
dioptra_client.workflows,
13301332
plugin["id"],
13311333
draft=draft,
13321334
draft_mod=draft_mod,
@@ -1379,7 +1381,9 @@ def test_manage_new_plugin_drafts(
13791381

13801382
# Run routine: new resource drafts tests
13811383
routines.run_new_resource_drafts_tests(
1384+
dioptra_client.plugins,
13821385
dioptra_client.plugins.new_resource_drafts,
1386+
dioptra_client.workflows,
13831387
drafts=drafts,
13841388
draft1_mod=draft1_mod,
13851389
draft1_expected=draft1_expected,
@@ -1456,7 +1460,9 @@ def hello_world(name: str) -> str:
14561460

14571461
# Run routine: existing resource drafts tests
14581462
routines.run_existing_resource_drafts_tests(
1463+
dioptra_client.plugins.files,
14591464
dioptra_client.plugins.files.modify_resource_drafts,
1465+
dioptra_client.workflows,
14601466
plugin_id,
14611467
plugin_file["id"],
14621468
draft=draft,
@@ -1535,7 +1541,9 @@ def hello_world(name: str) -> str:
15351541

15361542
# Run routine: new resource drafts tests
15371543
routines.run_new_resource_drafts_tests(
1544+
dioptra_client.plugins.files,
15381545
dioptra_client.plugins.files.new_resource_drafts,
1546+
dioptra_client.workflows,
15391547
plugin_id,
15401548
drafts=drafts,
15411549
draft1_mod=draft1_mod,

‎tests/unit/restapi/v1/test_plugin_parameter_type.py

+4
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,9 @@ def test_manage_existing_plugin_parameter_type_draft(
785785

786786
# Run routine: existing resource drafts tests
787787
routines.run_existing_resource_drafts_tests(
788+
dioptra_client.plugin_parameter_types,
788789
dioptra_client.plugin_parameter_types.modify_resource_drafts,
790+
dioptra_client.workflows,
789791
plugin_param_type["id"],
790792
draft=draft,
791793
draft_mod=draft_mod,
@@ -844,7 +846,9 @@ def test_manage_new_plugin_parameter_type_drafts(
844846

845847
# Run routine: new resource drafts tests
846848
routines.run_new_resource_drafts_tests(
849+
dioptra_client.plugin_parameter_types,
847850
dioptra_client.plugin_parameter_types.new_resource_drafts,
851+
dioptra_client.workflows,
848852
drafts=drafts,
849853
draft1_mod=draft1_mod,
850854
draft1_expected=draft1_expected,

‎tests/unit/restapi/v1/test_queue.py

+4
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,9 @@ def test_manage_existing_queue_draft(
610610

611611
# Run routine: existing resource drafts tests
612612
routines.run_existing_resource_drafts_tests(
613+
dioptra_client.queues,
613614
dioptra_client.queues.modify_resource_drafts,
615+
dioptra_client.workflows,
614616
queue["id"],
615617
draft=draft,
616618
draft_mod=draft_mod,
@@ -663,7 +665,9 @@ def test_manage_new_queue_drafts(
663665

664666
# Run routine: existing resource drafts tests
665667
routines.run_new_resource_drafts_tests(
668+
dioptra_client.queues,
666669
dioptra_client.queues.new_resource_drafts,
670+
dioptra_client.workflows,
667671
drafts=drafts,
668672
draft1_mod=draft1_mod,
669673
draft1_expected=draft1_expected,

0 commit comments

Comments
 (0)
Please sign in to comment.