Skip to content

Commit

Permalink
allow non tuples in starmap fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CJ-Wright committed Sep 28, 2018
1 parent eed0b74 commit 912da0f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 23 deletions.
2 changes: 2 additions & 0 deletions streamz_ext/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@


def apply(func, args, args2=None, kwargs=None):
if not isinstance(args, Sequence) or isinstance(args, str):
args = (args,)
if args2:
args = args + args2
if kwargs:
Expand Down
2 changes: 1 addition & 1 deletion test/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_filter_zip(backend):
@gen_test()
def test_double_scatter(backend):
source1 = Stream(asynchronous=True)
source2 = Stream()
source2 = Stream(asynchronous=True)
sm = (
source1.scatter(backend=backend)
.zip(source2.scatter(backend=backend))
Expand Down
29 changes: 7 additions & 22 deletions test/test_parallel_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def add(x, y, z=0):


@gen_cluster(client=True)
def test_starmap_args():
def test_starmap_args(c, s, a, b):
def add(x, y, z=0):
return x + y + z

Expand Down Expand Up @@ -257,37 +257,22 @@ def test_filter_zip(c, s, a, b):
@gen_cluster(client=True)
def test_double_scatter(c, s, a, b):
source1 = Stream(asynchronous=True)
source2 = Stream()
sm = source1.scatter().zip(source2.scatter()).starmap(add)
futures_L = sm.sink_to_list()
r = sm.buffer(10).gather()
L = r.sink_to_list()

for i in range(5):
yield source1.emit(i)
yield source2.emit(i)

while len(L) < len(futures_L):
yield gen.sleep(.01)

assert L == [i + i for i in range(5)]
assert all(isinstance(f, Future) for f in futures_L)


@gen_cluster(client=True)
def test_double_scatter(c, s, a, b):
source1 = Stream(asynchronous=True)
source2 = Stream()
source1.sink(print)
source2 = Stream(asynchronous=True)
source2.sink(print)
sm = source1.scatter().zip(source2.scatter()).starmap(add)
futures_L = sm.sink_to_list()
r = sm.buffer(10).gather()
L = r.sink_to_list()

print('hi')
for i in range(5):
print('hi')
yield source1.emit(i)
yield source2.emit(i)

while len(L) < len(futures_L):
print(len(L), print(len(futures_L)))
yield gen.sleep(.01)

assert L == [i + i for i in range(5)]
Expand Down

0 comments on commit 912da0f

Please sign in to comment.