Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a host promise rejection tracker with Python callback #80

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ The `Function` class has, apart from being a callable, additional methods:
- `memory` – returns a dict with information about memory usage.
- `add_callable` – adds a Python function and makes it callable from JS.
- `execute_pending_job` – executes a pending job (such as a async function or Promise).
- `set_promise_rejection_tracker` - sets a callback receiving (promise, reason, is_handled) when a promise is rejected. Pass None to disable.

## Documentation
For full functionality, please see `test_quickjs.py`
Expand Down
67 changes: 67 additions & 0 deletions module.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ typedef struct {
JSContext *context;
int has_time_limit;
clock_t time_limit;
PyObject *promise_rejection_tracker_callback;
// Used when releasing the GIL.
PyThreadState *thread_state;
InterruptData interrupt_data;
Expand Down Expand Up @@ -102,6 +103,45 @@ static void end_call_python(ContextData *context) {
context->thread_state = PyEval_SaveThread();
}

static void js_python_promise_rejection_tracker(
JSContext *ctx, JSValueConst promise, JSValueConst reason, int is_handled, void *opaque) {
ContextData *context = (ContextData *)JS_GetContextOpaque(ctx);
PyObject *callback = (PyObject *)opaque;
// Cannot call into Python with a time limit set.
if (context->has_time_limit) {
return;
}
prepare_call_python(context);
PyObject *py_promise = quickjs_to_python(context, JS_DupValue(ctx, promise));
if (py_promise == NULL) {
PyErr_WriteUnraisable(callback);
PyErr_Clear();
end_call_python(context);
return;
}
PyObject *py_reason = quickjs_to_python(context, JS_DupValue(ctx, reason));
if (py_reason == NULL) {
PyErr_WriteUnraisable(callback);
PyErr_Clear();
Py_DECREF(py_promise);
end_call_python(context);
return;
}
PyObject *py_is_handled = is_handled ? Py_True : Py_False;
Py_INCREF(py_is_handled);
PyObject *ret = PyObject_CallFunctionObjArgs(callback, py_promise, py_reason, py_is_handled, NULL);
if (ret == NULL) {
PyErr_WriteUnraisable(callback);
PyErr_Clear();
} else {
Py_DECREF(ret);
}
Py_DECREF(py_is_handled);
Py_DECREF(py_reason);
Py_DECREF(py_promise);
end_call_python(context);
}

// GC traversal.
static int object_traverse(ObjectData *self, visitproc visit, void *arg) {
Py_VISIT(self->context);
Expand Down Expand Up @@ -373,6 +413,7 @@ static PyObject *context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
self->context = JS_NewContext(self->runtime);
self->has_time_limit = 0;
self->time_limit = 0;
self->promise_rejection_tracker_callback = NULL;
self->thread_state = NULL;
self->python_callables = NULL;
JS_SetContextOpaque(self->context, self);
Expand All @@ -386,6 +427,7 @@ static void context_dealloc(ContextData *self) {
JS_FreeContext(self->context);
JS_FreeRuntime(self->runtime);
PyObject_GC_UnTrack(self);
Py_XDECREF(self->promise_rejection_tracker_callback);
PythonCallableNode *node = self->python_callables;
self->python_callables = NULL;
while (node) {
Expand Down Expand Up @@ -541,6 +583,27 @@ static PyObject *context_set_max_stack_size(ContextData *self, PyObject *args) {
Py_RETURN_NONE;
}

// _quickjs.Context.set_promise_rejection_tracker
//
// Sets a callback receiving (promise, reason, is_handled) when a promise is rejected.
static PyObject *context_set_promise_rejection_tracker(ContextData *self, PyObject *args) {
PyObject *callback = NULL;
if (!PyArg_ParseTuple(args, "|O", &callback)) {
return NULL;
}
Py_XDECREF(self->promise_rejection_tracker_callback);
if (callback == NULL || callback == Py_None) {
self->promise_rejection_tracker_callback = NULL;
JS_SetHostPromiseRejectionTracker(self->runtime, NULL, NULL);
} else {
Py_INCREF(callback);
self->promise_rejection_tracker_callback = callback;
JS_SetHostPromiseRejectionTracker(self->runtime, js_python_promise_rejection_tracker,
callback);
}
Py_RETURN_NONE;
}

// _quickjs.Context.memory
//
// Sets the CPU time limit of the context. This will be used in an interrupt handler.
Expand Down Expand Up @@ -716,6 +779,10 @@ static PyMethodDef context_methods[] = {
(PyCFunction)context_set_max_stack_size,
METH_VARARGS,
"Sets the maximum stack size in bytes. Default is 256kB."},
{"set_promise_rejection_tracker",
(PyCFunction)context_set_promise_rejection_tracker,
METH_VARARGS,
"Sets a callback receiving (promise, reason, is_handled) when a promise is rejected. Pass None to disable."},
{"memory", (PyCFunction)context_memory, METH_NOARGS, "Returns the memory usage as a dict."},
{"gc", (PyCFunction)context_gc, METH_NOARGS, "Runs garbage collection."},
{"add_callable", (PyCFunction)context_add_callable, METH_VARARGS, "Wraps a Python callable."},
Expand Down
4 changes: 4 additions & 0 deletions quickjs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def set_max_stack_size(self, limit):
with self._lock:
return self._context.set_max_stack_size(limit)

def set_promise_rejection_tracker(self, tracker):
with self._lock:
return self._context.set_promise_rejection_tracker(tracker)

def memory(self):
with self._lock:
return self._context.memory()
Expand Down
69 changes: 69 additions & 0 deletions test_quickjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,60 @@ def test_list():
# instead of a JS exception.
self.context.eval("test_list()")

def test_promise_rejection_tracker(self):
called = [0]
def tracker(promise, reason, is_handled):
called[0] += 1
self.assertFalse(is_handled)
self.context.set("reason", reason)
self.assertTrue(self.context.eval("reason.message === 'x';"))
def run_async_error():
self.context.eval("function f() {throw Error('x');}")
self.context.eval("async function g() {await f();}")
self.context.eval("g()")
self.context.set_promise_rejection_tracker(tracker)
run_async_error()
self.context.set_promise_rejection_tracker(None)
run_async_error()
self.context.set_promise_rejection_tracker()
run_async_error()
self.assertEqual(called[0], 1)

def test_promise_rejection_tracker_promise(self):
called = [0]
def tracker_false(promise, reason, is_handled):
called[0] += 1
self.assertFalse(is_handled)
def tracker_true(promise, reason, is_handled):
called[0] += 1
self.assertTrue(is_handled)
self.context.eval("Promise.reject().then(() => {}, () => {return Promise.reject();})")
self.context.set_promise_rejection_tracker(tracker_false)
self.context.execute_pending_job()
self.context.set_promise_rejection_tracker(tracker_true)
self.assertTrue(self.context.execute_pending_job())
self.context.set_promise_rejection_tracker(tracker_false)
self.assertTrue(self.context.execute_pending_job())
self.assertFalse(self.context.execute_pending_job())
self.assertEqual(called[0], 3)

def test_promise_rejection_tracker_unraisable(self):
import sys
def unraisablehook(u):
self.assertTrue(isinstance(u.exc_value, ZeroDivisionError))
unraisablehook_orig = sys.unraisablehook
sys.unraisablehook = unraisablehook
called = [0]
def tracker(promise, reason, is_handled):
called[0] += 1
raise ZeroDivisionError
self.context.set_promise_rejection_tracker(tracker)
self.context.eval("function f() {throw Error;}")
self.context.eval("async function g() {await f();}")
self.context.eval("g()")
self.assertEqual(called[0], 1)
sys.unraisablehook = unraisablehook_orig


class Object(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -493,6 +547,21 @@ def test_execute_pending_job(self):
self.assertEqual(f(), 2)
self.assertEqual(f.execute_pending_job(), False)

def test_promise_rejection_tracker(self):
called = [0]
def tracker(promise, reason, is_handled):
called[0] += 1
self.assertFalse(is_handled)
f = quickjs.Function(
"f", """
function g() {throw Error('x');}
async function h() {await g();}
function f() {h();}
""")
f.set_promise_rejection_tracker(tracker)
f()
self.assertEqual(called[0], 1)


class JavascriptFeatures(unittest.TestCase):
def test_unicode_strings(self):
Expand Down