Skip to content

Commit

Permalink
Move AcquisitionReport work into asynch thread (#210)
Browse files Browse the repository at this point in the history
* Move AcquisitionReport work into asynch thread

* Remove unecessary comment
  • Loading branch information
amalnanavati authored Feb 5, 2025
1 parent 03c5043 commit cdd0bb9
Showing 1 changed file with 55 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import copy
import errno
import os
import threading
import time
from typing import Dict
import uuid
Expand Down Expand Up @@ -292,12 +293,14 @@ def __init__(self):

# Create AcquisitionSelect cache
# UUID -> {context, request, response}
self.cache_lock = threading.Lock()
self.cache = {}

# Init Checkpoints / Data Record
self._init_checkpoints_record(context_cls, posthoc_cls)

# Start ROS services
self.acquisition_report_threads = []
self.ros_objs = []
self.ros_objs.append(
self.create_service(
Expand Down Expand Up @@ -339,11 +342,12 @@ def select_callback(
response.probabilities = list(res[0])
response.actions = list(res[1])
select_id = str(uuid.uuid4())
self.cache[select_id] = {
"context": np.copy(context),
"request": copy.deepcopy(request),
"response": copy.deepcopy(response),
}
with self.cache_lock:
self.cache[select_id] = {
"context": np.copy(context),
"request": copy.deepcopy(request),
"response": copy.deepcopy(response),
}
response.id = select_id

if response.status != "Success":
Expand All @@ -365,13 +369,49 @@ def report_callback(
f"AcquisitionReport Request with ID: '{request.id}' and loss '{request.loss}'"
)

# Collect cached context
if request.id not in self.cache:
response.status = "id does not map to previous select call"
self.get_logger().error(f"AcquistionReport: {response.status}")
response.success = False
return response
cache = self.cache[request.id]
# Remove any completed threads
i = 0
while i < len(self.acquisition_report_threads):
if not self.acquisition_report_threads[i].is_alive():
self.get_logger().info("Removing completed acquisition report thread")
self.acquisition_report_threads.pop(i)
else:
i += 1

# Start the asynch thread
request_copy = copy.deepcopy(request)
response_copy = copy.deepcopy(response)
thread = threading.Thread(
target=self.report_callback_work, args=(request_copy, response_copy)
)
self.acquisition_report_threads.append(thread)
self.get_logger().info("Starting new acquisition report thread")
thread.start()

# Return success immediately
response.status = "Success"
response.success = True
return response

# pylint: disable=too-many-statements
# One over is fine for this function.
def report_callback_work(
self, request: AcquisitionReport.Request, response: AcquisitionReport.Response
) -> AcquisitionReport.Response:
"""
Perform the work of updating the policy based on the acquisition. This is a workaround
to the fact that either ROSLib or rosbridge (likely the latter) cannot process a service
and action at the same time, so in practice the next motion waits until after the policy
has been updated, which adds a few seconds of unnecessary latency.
"""
with self.cache_lock:
# Collect cached context
if request.id not in self.cache:
response.status = "id does not map to previous select call"
self.get_logger().error(f"AcquistionReport: {response.status}")
response.success = False
return response
cache = copy.deepcopy(self.cache[request.id])
context = cache["context"]

# Collect executed action
Expand Down Expand Up @@ -409,7 +449,9 @@ def report_callback(

# Report completed
self.n_successful_reports += 1
del self.cache[request.id]
with self.cache_lock:
if request.id in self.cache:
del self.cache[request.id]

# Save checkpoint if requested
if (
Expand Down

0 comments on commit cdd0bb9

Please sign in to comment.