Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unitary hack] Improve docstrings in "Task/braket.py" #975

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 118 additions & 12 deletions src/bloqade/task/braket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
"""
Module for managing Braket tasks in the bloqade framework.

This module defines the BraketTask class, which represents a task that can be submitted
to a Braket backend. It includes methods for task submission, validation, fetching results,
checking status, and cancellation. Additionally, serialization and deserialization
functions are provided for the BraketTask class.
"""

import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any

from bloqade.builder.base import ParamType
from bloqade.serialize import Serializer
from bloqade.submission.ir.parallel import ParallelDecoder
Expand All @@ -7,9 +20,6 @@

from bloqade.submission.base import ValidationError
from bloqade.submission.ir.task_results import QuEraTaskResults, QuEraTaskStatusCode
import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any


## keep the old conversion for now,
Expand All @@ -18,6 +28,18 @@
@dataclass
@Serializer.register
class BraketTask(RemoteTask):
"""
Represents a Braket Task which can be submitted to a Braket backend.

Attributes:
task_id (Optional[str]): The ID of the task.
backend (BraketBackend): The backend to which the task is submitted.
task_ir (QuEraTaskSpecification): The task specification.
metadata (Dict[str, ParamType]): Metadata associated with the task.
parallel_decoder (Optional[ParallelDecoder]): Parallel decoder for the task.
task_result_ir (QuEraTaskResults): The result of the task.
"""

task_id: Optional[str]
backend: BraketBackend
task_ir: QuEraTaskSpecification
Expand All @@ -30,18 +52,34 @@ class BraketTask(RemoteTask):
)

def submit(self, force: bool = False) -> "BraketTask":
"""
shubhusion marked this conversation as resolved.
Show resolved Hide resolved
Submits the task to the backend.

Args:
force (bool): Whether to force submission even if the task is already submitted.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task is already submitted and force is False.
"""
if not force:
if self.task_id is not None:
raise ValueError(
"the task is already submitted with %s" % (self.task_id)
)
raise ValueError(f"the task is already submitted with {self.task_id}")
self.task_id = self.backend.submit_task(self.task_ir)

self.task_result_ir = QuEraTaskResults(task_status=QuEraTaskStatusCode.Enqueued)

return self

def validate(self) -> str:
"""
Validates the task specification.

Returns:
str: An empty string if validation is successful,otherwise the validation error message.
"""
try:
self.backend.validate_task(self.task_ir)
except ValidationError as e:
Expand All @@ -50,7 +88,15 @@ def validate(self) -> str:
return ""

def fetch(self) -> "BraketTask":
# non-blocking, pull only when its completed
"""
Fetches the task results if the task is completed.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task is not yet submitted.
"""
if self.task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted:
raise ValueError("Task ID not found.")

Expand All @@ -72,7 +118,15 @@ def fetch(self) -> "BraketTask":
return self

def pull(self) -> "BraketTask":
# blocking, force pulling, even its completed
"""
Forces pulling the task results.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task ID is not found.
"""
if self.task_id is None:
raise ValueError("Task ID not found.")

Expand All @@ -81,8 +135,12 @@ def pull(self) -> "BraketTask":
return self

def result(self) -> QuEraTaskResults:
# blocking, caching
"""
Gets the task results, blocking until results are available.

Returns:
QuEraTaskResults: The task results.
"""
if self.task_result_ir is None:
pass
else:
Expand All @@ -95,12 +153,27 @@ def result(self) -> QuEraTaskResults:
return self.task_result_ir

def status(self) -> QuEraTaskStatusCode:
"""
Gets the status of the task.

Returns:
QuEraTaskStatusCode: The status of the task.
"""
if self.task_id is None:
return QuEraTaskStatusCode.Unsubmitted

return self.backend.task_status(self.task_id)

def cancel(self) -> None:
"""
Cancels the task if it is currently submitted.

Returns:
None

Raises:
Warning: If the task ID is not found.
"""
if self.task_id is None:
warnings.warn("Cannot cancel task, missing task id.")
return
Expand All @@ -109,16 +182,34 @@ def cancel(self) -> None:

@property
def nshots(self):
"""
Gets the number of shots specified for the task.

Returns:
int: The number of shots.
"""
return self.task_ir.nshots

def _geometry(self) -> Geometry:
"""
Gets the geometry of the task lattice.

Returns:
Geometry: The geometry of the task lattice.
"""
return Geometry(
sites=self.task_ir.lattice.sites,
filling=self.task_ir.lattice.filling,
parallel_decoder=self.parallel_decoder,
)

def _result_exists(self) -> bool:
"""
Checks if the task results exist.

Returns:
bool: True if the task results exist and are completed, otherwise False.
"""
if self.task_result_ir is None:
return False
else:
Expand All @@ -127,12 +218,18 @@ def _result_exists(self) -> bool:
else:
return False

# def submit_no_task_id(self) -> "HardwareTaskShotResults":
# return HardwareTaskShotResults(hardware_task=self)


@BraketTask.set_serializer
def _serialize(obj: BraketTask) -> Dict[str, Any]:
"""
shubhusion marked this conversation as resolved.
Show resolved Hide resolved
Serializes the BraketTask instance to a dictionary.

Args:
obj (BraketTask): The task instance to serialize.

Returns:
Dict[str, Any]: The serialized dictionary representation of the task.
"""
return {
"task_id": obj.task_id,
"backend": obj.backend.dict(),
Expand All @@ -147,6 +244,15 @@ def _serialize(obj: BraketTask) -> Dict[str, Any]:

@BraketTask.set_deserializer
def _deserialize(d: Dict[str, Any]) -> BraketTask:
"""
Deserializes a dictionary to a BraketTask instance.

Args:
d (Dict[str, Any]): The dictionary to deserialize.

Returns:
BraketTask: The deserialized task instance.
"""
d["backend"] = BraketBackend(**d["backend"])
d["task_ir"] = QuEraTaskSpecification(**d["task_ir"])
d["parallel_decoder"] = (
Expand Down
Loading