Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unitary hack] Improve docstrings in "Task/braket.py" #975

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 153 additions & 26 deletions src/bloqade/task/braket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any

from bloqade.builder.base import ParamType
from bloqade.serialize import Serializer
from bloqade.submission.ir.parallel import ParallelDecoder
Expand All @@ -7,17 +11,23 @@

from bloqade.submission.base import ValidationError
from bloqade.submission.ir.task_results import QuEraTaskResults, QuEraTaskStatusCode
import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any


## keep the old conversion for now,
## we will remove conversion btwn QuEraTask <-> BraketTask,
## and specialize/dispatching here.
@dataclass
@Serializer.register
class BraketTask(RemoteTask):
"""
Represents a Braket Task which can be submitted to a Braket backend.

Attributes:
task_id (Optional[str]): The ID of the task.
backend (BraketBackend): The backend to which the task is submitted.
task_ir (QuEraTaskSpecification): The task specification.
metadata (Dict[str, ParamType]): Metadata associated with the task.
parallel_decoder (Optional[ParallelDecoder]): Parallel decoder for the task.
task_result_ir (QuEraTaskResults): The result of the task.
"""

task_id: Optional[str]
backend: BraketBackend
task_ir: QuEraTaskSpecification
Expand All @@ -30,30 +40,78 @@ class BraketTask(RemoteTask):
)

def submit(self, force: bool = False) -> "BraketTask":
"""
shubhusion marked this conversation as resolved.
Show resolved Hide resolved
Submits the task to the backend.

Args:
force (bool): Whether to force submission even if the task is already submitted.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task is already submitted and force is False.

Example:
>>> backend = BraketBackend(...) # Create a backend instance
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be rendered correctly by mkdocs markdown. You need to wrap them with markdown code blocks.

>>> task_ir = QuEraTaskSpecification(...) # Create a task specification
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like my comment from the other PR, this is too vague on how you construct a specification, the task will be constructed from a submission, so you should first create a submission (a batch object) then return the task from the batch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ref: https://github.com/QuEraComputing/bloqade-python/blob/main/src/bloqade/task/batch.py

and you can also check the corresponding tests to understand how this works

>>> metadata = {"key": "value"} # Metadata for the task
>>> braket_task = BraketTask(
... task_id=None,
... backend=backend,
... task_ir=task_ir,
... metadata=metadata,
... )
>>> braket_task.submit()
>>> print(f"Task submitted with ID: {braket_task.task_id}")
"""
if not force:
if self.task_id is not None:
raise ValueError(
"the task is already submitted with %s" % (self.task_id)
)
raise ValueError(f"the task is already submitted with {self.task_id}")
self.task_id = self.backend.submit_task(self.task_ir)

self.task_result_ir = QuEraTaskResults(task_status=QuEraTaskStatusCode.Enqueued)

return self

def validate(self) -> str:
"""
Validates the task specification.

Returns:
str: An empty string if validation is successful, otherwise the validation error message.

Example:
>>> validation_error = braket_task.validate()
>>> if validation_error:
... print(f"Validation Error: {validation_error}")
... else:
... print("Task is valid.")
"""
try:
self.backend.validate_task(self.task_ir)
except ValidationError as e:
return str(e)

return ""

def fetch(self) -> "BraketTask":
# non-blocking, pull only when its completed
"""
Fetches the task results if the task is completed.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task is not yet submitted.

Example:
>>> try:
... braket_task.fetch()
... results = braket_task.result()
... print(f"Task results: {results}")
... except ValueError as e:
... print(e)
"""
if self.task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted:
raise ValueError("Task ID not found.")

if self.task_result_ir.task_status in [
QuEraTaskStatusCode.Completed,
QuEraTaskStatusCode.Partial,
Expand All @@ -62,27 +120,47 @@ def fetch(self) -> "BraketTask":
QuEraTaskStatusCode.Cancelled,
]:
return self

status = self.status()
if status in [QuEraTaskStatusCode.Completed, QuEraTaskStatusCode.Partial]:
self.task_result_ir = self.backend.task_results(self.task_id)
else:
self.task_result_ir = QuEraTaskResults(task_status=status)

return self

def pull(self) -> "BraketTask":
# blocking, force pulling, even its completed
"""
Forces pulling the task results.

Returns:
BraketTask: The current task instance.

Raises:
ValueError: If the task ID is not found.

Example:
>>> try:
... braket_task.pull()
... results = braket_task.result()
... print(f"Task results: {results}")
... except ValueError as e:
... print(e)
"""
if self.task_id is None:
raise ValueError("Task ID not found.")

self.task_result_ir = self.backend.task_results(self.task_id)

return self

def result(self) -> QuEraTaskResults:
# blocking, caching
"""
Gets the task results, blocking until results are available.

Returns:
QuEraTaskResults: The task results.

Example:
>>> results = braket_task.result()
>>> print(f"Task results: {results}")
"""
if self.task_result_ir is None:
pass
else:
Expand All @@ -91,34 +169,86 @@ def result(self) -> QuEraTaskResults:
and self.task_result_ir.task_status != QuEraTaskStatusCode.Completed
):
self.pull()

return self.task_result_ir

def status(self) -> QuEraTaskStatusCode:
"""
Gets the status of the task.

Returns:
QuEraTaskStatusCode: The status of the task.

Example:
>>> status = braket_task.status()
>>> print(f"Task status: {status.name}")
"""
if self.task_id is None:
return QuEraTaskStatusCode.Unsubmitted

return self.backend.task_status(self.task_id)

def cancel(self) -> None:
"""
Cancels the task if it is currently submitted.

Returns:
None

Raises:
Warning: If the task ID is not found.

Example:
>>> try:
... braket_task.cancel()
... print("Task cancelled.")
... except Warning as w:
... print(w)
"""
if self.task_id is None:
warnings.warn("Cannot cancel task, missing task id.")
return

self.backend.cancel_task(self.task_id)

@property
def nshots(self):
"""
Gets the number of shots specified for the task.

Returns:
int: The number of shots.

Example:
>>> print(f"Number of shots: {braket_task.nshots}")
"""
return self.task_ir.nshots

def _geometry(self) -> Geometry:
"""
Gets the geometry of the task lattice.

Returns:
Geometry: The geometry of the task lattice.

Example:
>>> geometry = braket_task._geometry()
>>> print(f"Task geometry: {geometry}")
"""
return Geometry(
sites=self.task_ir.lattice.sites,
filling=self.task_ir.lattice.filling,
parallel_decoder=self.parallel_decoder,
)

def _result_exists(self) -> bool:
"""
Checks if the task results exist.

Returns:
bool: True if the task results exist and are completed, otherwise False.

Example:
>>> result_exists = braket_task._result_exists()
>>> print(f"Result exists: {result_exists}")
"""
if self.task_result_ir is None:
return False
else:
Expand All @@ -127,9 +257,6 @@ def _result_exists(self) -> bool:
else:
return False

# def submit_no_task_id(self) -> "HardwareTaskShotResults":
# return HardwareTaskShotResults(hardware_task=self)


@BraketTask.set_serializer
def _serialize(obj: BraketTask) -> Dict[str, Any]:
Expand Down
Loading