Skip to content

Commit 2ca51cf

Browse files
committed
attempt to make things common
1 parent 74f6694 commit 2ca51cf

File tree

6 files changed

+111
-112
lines changed

6 files changed

+111
-112
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
always_run: true
2828
- id: name-tests-test
2929
always_run: true
30-
exclude: ^(.*/tests/utils/)|^(.*fixtures.py)|^(tests/scenarios/bigquery/helpers)|^(tests/scenariosv2/sim)
30+
exclude: ^(.*/tests/utils/)|^(.*fixtures.py)|^(tests/scenariosv2/(sim|flows))
3131
- id: requirements-txt-fixer
3232
always_run: true
3333
- id: mixed-line-ending

tests/scenariosv2/__init__.py

Whitespace-only changes.

tests/scenariosv2/flows/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# stdlib
2+
import random
3+
4+
# syft absolute
5+
import syft as sy
6+
from syft import test_settings
7+
from syft.service.request.request import RequestStatus
8+
9+
# relative
10+
from ..sim.core import SimulatorContext
11+
12+
__all__ = ["bq_test_query", "bq_submit_query", "bq_submit_query_results"]
13+
14+
15+
def query_sql():
16+
dataset_2 = test_settings.get("dataset_2", default="dataset_2")
17+
table_2 = test_settings.get("table_2", default="table_2")
18+
table_2_col_id = test_settings.get("table_2_col_id", default="table_id")
19+
table_2_col_score = test_settings.get("table_2_col_score", default="colname")
20+
21+
query = f"SELECT {table_2_col_id}, AVG({table_2_col_score}) AS average_score \
22+
FROM {dataset_2}.{table_2} \
23+
GROUP BY {table_2_col_id} \
24+
LIMIT 10000"
25+
return query
26+
27+
28+
def bq_test_query(ctx: SimulatorContext, client: sy.DatasiteClient):
29+
ctx.logger.info(
30+
f"User: {client.logged_in_user} - Calling client.api.bigquery.test_query (mock)"
31+
)
32+
res = client.api.bigquery.test_query(sql_query=query_sql())
33+
assert len(res) == 10000
34+
ctx.logger.info(f"User: {client.logged_in_user} - Received {len(res)} rows")
35+
return res
36+
37+
38+
def bq_submit_query(ctx: SimulatorContext, client: sy.DatasiteClient):
39+
# Randomly define a func_name a function to call
40+
func_name = "invalid_func" if random.random() < 0.5 else "test_query"
41+
42+
ctx.logger.info(
43+
f"User: {client.logged_in_user} - Calling client.api.services.bigquery.submit_query func_name={func_name}"
44+
)
45+
res = client.api.bigquery.submit_query(
46+
func_name=func_name,
47+
query=query_sql(),
48+
)
49+
ctx.logger.info(f"User: {client.logged_in_user} - Received {res}")
50+
return res
51+
52+
53+
def bq_submit_query_results(ctx: SimulatorContext, client: sy.DatasiteClient):
54+
for request in client.requests:
55+
if request.get_status() == RequestStatus.APPROVED:
56+
job = request.code(blocking=False)
57+
result = job.wait()
58+
assert len(result) == 10000
59+
if request.get_status() == RequestStatus.REJECTED:
60+
ctx.logger.info(
61+
f"User: {client.logged_in_user} - Request rejected {request.code.service_func_name}"
62+
)
63+
64+
return True

tests/scenariosv2/l0_test.py

+31-77
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
# third party
77
from faker import Faker
88
import pytest
9-
from sim.core import BaseEvent
10-
from sim.core import Simulator
11-
from sim.core import SimulatorContext
12-
from sim.core import sim_activity
13-
from sim.core import sim_entrypoint
149

1510
# syft absolute
1611
import syft as sy
17-
from syft import test_settings
18-
from syft.client.client import SyftClient
1912
from syft.service.request.request import RequestStatus
2013
from syft.util.test_helpers.apis import make_schema
2114
from syft.util.test_helpers.apis import make_test_query
2215
from syft.util.test_helpers.worker_helpers import (
2316
build_and_launch_worker_pool_from_docker_str,
2417
)
2518

19+
# relative
20+
from .flows.user_bigquery_api import bq_submit_query
21+
from .flows.user_bigquery_api import bq_submit_query_results
22+
from .flows.user_bigquery_api import bq_test_query
23+
from .sim.core import BaseEvent
24+
from .sim.core import Simulator
25+
from .sim.core import SimulatorContext
26+
from .sim.core import sim_activity
27+
from .sim.core import sim_entrypoint
28+
2629
fake = Faker()
2730
NUM_USERS = 3
2831
NUM_ENDPOINTS = 3 # test_query, submit_query, schema_query
@@ -59,42 +62,13 @@ class Event(BaseEvent):
5962
# ------------------------------------------------------------------------------------------------
6063

6164

62-
def query_sql():
63-
dataset_2 = test_settings.get("dataset_2", default="dataset_2")
64-
table_2 = test_settings.get("table_2", default="table_2")
65-
table_2_col_id = test_settings.get("table_2_col_id", default="table_id")
66-
table_2_col_score = test_settings.get("table_2_col_score", default="colname")
67-
68-
query = f"SELECT {table_2_col_id}, AVG({table_2_col_score}) AS average_score \
69-
FROM {dataset_2}.{table_2} \
70-
GROUP BY {table_2_col_id} \
71-
LIMIT 10000"
72-
return query
73-
74-
75-
def get_code_from_msg(msg: str):
76-
return str(msg.split("`")[1].replace("()", "").replace("client.", ""))
77-
78-
79-
# ------------------------------------------------------------------------------------------------
80-
81-
8265
@sim_activity(
8366
wait_for=Event.ADMIN_LOW_SIDE_ENDPOINTS_AVAILABLE,
8467
trigger=Event.USER_CAN_QUERY_TEST_ENDPOINT,
8568
)
8669
async def user_query_test_endpoint(ctx: SimulatorContext, client: sy.DatasiteClient):
8770
"""Run query on test endpoint"""
88-
89-
user = client.logged_in_user
90-
91-
def _query_endpoint():
92-
ctx.logger.info(f"User: {user} - Calling client.api.bigquery.test_query (mock)")
93-
res = client.api.bigquery.test_query(sql_query=query_sql())
94-
assert len(res) == 10000
95-
ctx.logger.info(f"User: {user} - Received {len(res)} rows")
96-
97-
await asyncio.to_thread(_query_endpoint)
71+
await asyncio.to_thread(bq_test_query, ctx, client)
9872

9973

10074
@sim_activity(
@@ -103,40 +77,15 @@ def _query_endpoint():
10377
)
10478
async def user_bq_submit(ctx: SimulatorContext, client: sy.DatasiteClient):
10579
"""Submit query to be run on private data"""
106-
user = client.logged_in_user
107-
108-
def _submit_endpoint():
109-
func_name = "invalid_func" if random.random() < 0.5 else "test_query"
110-
ctx.logger.info(
111-
f"User: {user} - Calling client.api.services.bigquery.submit_query func_name={func_name}"
112-
)
113-
114-
res = client.api.bigquery.submit_query(
115-
func_name=func_name,
116-
query=query_sql(),
117-
)
118-
ctx.logger.info(f"User: {user} - Received {res}")
119-
120-
await asyncio.to_thread(_submit_endpoint)
80+
await asyncio.to_thread(bq_submit_query, ctx, client)
12181

12282

12383
@sim_activity(
12484
wait_for=Event.ADMIN_LOW_ALL_RESULTS_AVAILABLE,
12585
trigger=Event.USER_CHECKED_RESULTS,
12686
)
12787
async def user_checks_results(ctx: SimulatorContext, client: sy.DatasiteClient):
128-
def _check_results():
129-
for request in client.requests:
130-
if request.get_status() == RequestStatus.APPROVED:
131-
job = request.code(blocking=False)
132-
result = job.wait()
133-
assert len(result) == 10000
134-
if request.get_status() == RequestStatus.REJECTED:
135-
ctx.logger.info(
136-
f"User: Request with function named {request.code.service_func_name} was rejected"
137-
)
138-
139-
await asyncio.to_thread(_check_results)
88+
await asyncio.to_thread(bq_submit_query_results, ctx, client)
14089

14190

14291
@sim_activity(wait_for=Event.GUEST_USERS_CREATED, trigger=Event.USER_FLOW_COMPLETED)
@@ -148,6 +97,7 @@ async def user_flow(ctx: SimulatorContext, server_url_low: str, user: dict):
14897
)
14998
ctx.logger.info(f"User: {client.logged_in_user} - logged in")
15099

100+
# this must be executed sequentially.
151101
await user_query_test_endpoint(ctx, client)
152102
await user_bq_submit(ctx, client)
153103
await user_checks_results(ctx, client)
@@ -158,7 +108,7 @@ async def user_flow(ctx: SimulatorContext, server_url_low: str, user: dict):
158108

159109
@sim_activity(trigger=Event.GUEST_USERS_CREATED)
160110
async def admin_signup_users(
161-
ctx: SimulatorContext, admin_client: SyftClient, users: list[dict]
111+
ctx: SimulatorContext, admin_client: sy.DatasiteClient, users: list[dict]
162112
):
163113
for user in users:
164114
ctx.logger.info(f"Admin low: Creating guest user {user['email']}")
@@ -173,7 +123,7 @@ async def admin_signup_users(
173123
@sim_activity(trigger=Event.ADMIN_BQ_SCHEMA_ENDPOINT_CREATED)
174124
async def admin_endpoint_bq_schema(
175125
ctx: SimulatorContext,
176-
admin_client: SyftClient,
126+
admin_client: sy.DatasiteClient,
177127
worker_pool: str | None = None,
178128
):
179129
path = "bigquery.schema"
@@ -195,7 +145,7 @@ async def admin_endpoint_bq_schema(
195145
@sim_activity(trigger=Event.ADMIN_BQ_TEST_ENDPOINT_CREATED)
196146
async def admin_endpoint_bq_test(
197147
ctx: SimulatorContext,
198-
admin_client: SyftClient,
148+
admin_client: sy.DatasiteClient,
199149
worker_pool: str | None = None,
200150
):
201151
path = "bigquery.test_query"
@@ -290,7 +240,7 @@ def execute_query(query: str, endpoint):
290240

291241

292242
@sim_activity(trigger=Event.ADMIN_ALL_ENDPOINTS_CREATED)
293-
async def admin_create_endpoint(ctx: SimulatorContext, admin_client: SyftClient):
243+
async def admin_create_endpoint(ctx: SimulatorContext, admin_client: sy.DatasiteClient):
294244
worker_pool = "biquery-pool"
295245

296246
await asyncio.gather(
@@ -308,7 +258,7 @@ async def admin_create_endpoint(ctx: SimulatorContext, admin_client: SyftClient)
308258
Event.ADMIN_LOWSIDE_WORKER_POOL_CREATED,
309259
]
310260
)
311-
async def admin_watch_sync(ctx: SimulatorContext, admin_client: SyftClient):
261+
async def admin_watch_sync(ctx: SimulatorContext, admin_client: sy.DatasiteClient):
312262
while True:
313263
await asyncio.sleep(random.uniform(5, 10))
314264

@@ -340,7 +290,7 @@ async def admin_watch_sync(ctx: SimulatorContext, admin_client: SyftClient):
340290

341291

342292
# @sim_activity(trigger=Event.ADMIN_WORKER_POOL_CREATED)
343-
async def admin_create_bq_pool(ctx: SimulatorContext, admin_client: SyftClient):
293+
async def admin_create_bq_pool(ctx: SimulatorContext, admin_client: sy.DatasiteClient):
344294
worker_pool = "biquery-pool"
345295

346296
base_image = admin_client.images.get_all()[0]
@@ -375,12 +325,16 @@ async def admin_create_bq_pool(ctx: SimulatorContext, admin_client: SyftClient):
375325

376326

377327
@sim_activity(trigger=Event.ADMIN_HIGHSIDE_WORKER_POOL_CREATED)
378-
async def admin_create_bq_pool_high(ctx: SimulatorContext, admin_client: SyftClient):
328+
async def admin_create_bq_pool_high(
329+
ctx: SimulatorContext, admin_client: sy.DatasiteClient
330+
):
379331
await admin_create_bq_pool(ctx, admin_client)
380332

381333

382334
@sim_activity(trigger=Event.ADMIN_LOWSIDE_WORKER_POOL_CREATED)
383-
async def admin_create_bq_pool_low(ctx: SimulatorContext, admin_client: SyftClient):
335+
async def admin_create_bq_pool_low(
336+
ctx: SimulatorContext, admin_client: sy.DatasiteClient
337+
):
384338
await admin_create_bq_pool(ctx, admin_client)
385339

386340

@@ -391,7 +345,9 @@ async def admin_create_bq_pool_low(ctx: SimulatorContext, admin_client: SyftClie
391345
],
392346
trigger=Event.ADMIN_HIGHSIDE_FLOW_COMPLETED,
393347
)
394-
async def admin_triage_requests_high(ctx: SimulatorContext, admin_client: SyftClient):
348+
async def admin_triage_requests_high(
349+
ctx: SimulatorContext, admin_client: sy.DatasiteClient
350+
):
395351
while True:
396352
await asyncio.sleep(random.uniform(5, 10))
397353

@@ -452,9 +408,7 @@ async def admin_low_side(ctx: SimulatorContext, admin_auth, users):
452408

453409
@sim_activity(trigger=Event.ADMIN_SYNC_COMPLETED)
454410
async def admin_sync_to_low_flow(
455-
ctx: SimulatorContext,
456-
admin_auth_high,
457-
admin_auth_low,
411+
ctx: SimulatorContext, admin_auth_high: dict, admin_auth_low: dict
458412
):
459413
high_client = sy.login(**admin_auth_high)
460414
ctx.logger.info("Admin: logged in to high-side")
@@ -485,7 +439,7 @@ async def admin_sync_to_low_flow(
485439

486440
@sim_activity(trigger=Event.ADMIN_SYNC_COMPLETED)
487441
async def admin_sync_to_high_flow(
488-
ctx: SimulatorContext, admin_auth_high, admin_auth_low
442+
ctx: SimulatorContext, admin_auth_high: dict, admin_auth_low: dict
489443
):
490444
high_client = sy.login(**admin_auth_high)
491445
ctx.logger.info("Admin low: logged in to high-side")

tests/scenariosv2/l2_test.py

+15-34
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,22 @@
77

88
# third party
99
from faker import Faker
10-
from l0_test import Event
11-
from l0_test import admin_create_bq_pool_high
12-
from l0_test import admin_create_endpoint
13-
from l0_test import admin_signup_users
14-
from l0_test import query_sql
1510
import pytest
16-
from sim.core import Simulator
17-
from sim.core import SimulatorContext
18-
from sim.core import sim_activity
19-
from sim.core import sim_entrypoint
2011

2112
# syft absolute
2213
import syft as sy
23-
from syft.client.client import SyftClient
14+
15+
# relative
16+
from .flows.user_bigquery_api import bq_submit_query
17+
from .flows.user_bigquery_api import bq_test_query
18+
from .l0_test import Event
19+
from .l0_test import admin_create_bq_pool_high
20+
from .l0_test import admin_create_endpoint
21+
from .l0_test import admin_signup_users
22+
from .sim.core import Simulator
23+
from .sim.core import SimulatorContext
24+
from .sim.core import sim_activity
25+
from .sim.core import sim_entrypoint
2426

2527
fake = Faker()
2628

@@ -31,7 +33,7 @@
3133
Event.USER_CAN_SUBMIT_QUERY,
3234
]
3335
)
34-
async def admin_triage_requests(ctx: SimulatorContext, admin_client: SyftClient):
36+
async def admin_triage_requests(ctx: SimulatorContext, admin_client: sy.DatasiteClient):
3537
while True:
3638
await asyncio.sleep(random.uniform(3, 5))
3739
ctx.logger.info("Admin: Triaging requests")
@@ -72,16 +74,7 @@ async def admin_flow(
7274
)
7375
async def user_query_test_endpoint(ctx: SimulatorContext, client: sy.DatasiteClient):
7476
"""Run query on test endpoint"""
75-
76-
user = client.logged_in_user
77-
78-
def _query_endpoint():
79-
ctx.logger.info(f"User: {user} - Calling client.api.bigquery.test_query (mock)")
80-
res = client.api.bigquery.test_query(sql_query=query_sql())
81-
assert len(res) == 10000
82-
ctx.logger.info(f"User: {user} - Received {len(res)} rows")
83-
84-
await asyncio.to_thread(_query_endpoint)
77+
await asyncio.to_thread(bq_test_query, ctx, client)
8578

8679

8780
@sim_activity(
@@ -93,19 +86,7 @@ def _query_endpoint():
9386
)
9487
async def user_bq_submit(ctx: SimulatorContext, client: sy.DatasiteClient):
9588
"""Submit query to be run on private data"""
96-
user = client.logged_in_user
97-
98-
def _submit_endpoint():
99-
ctx.logger.info(
100-
f"User: {user} - Calling client.api.services.bigquery.submit_query"
101-
)
102-
res = client.api.bigquery.submit_query(
103-
func_name="invalid_func",
104-
query=query_sql(),
105-
)
106-
ctx.logger.info(f"User: {user} - Received {res}")
107-
108-
await asyncio.to_thread(_submit_endpoint)
89+
await asyncio.to_thread(bq_submit_query, ctx, client)
10990

11091

11192
@sim_activity(

0 commit comments

Comments
 (0)