Skip to content

Commit

Permalink
add pluck and use zstreamz in score
Browse files Browse the repository at this point in the history
  • Loading branch information
CJ-Wright committed Sep 27, 2018
1 parent d5b5e65 commit e781312
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 19 deletions.
6 changes: 3 additions & 3 deletions score.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ build:
default:
conda: python
run:
streamz:
zstreamz:
default:
conda: streamz
conda: zstreamz
bleeding:
pip: 'git+https://github.com/mrocklin/streamz.git#egg=streamz'
pip: 'git+https://github.com/xpdAcq/zstreamz.git#egg=zstreamz'
toolz:
default:
conda: toolz
Expand Down
9 changes: 5 additions & 4 deletions streamz_ext/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@
from .core import identity


def result_maybe(future_maybe, top=False):
def result_maybe(future_maybe):
try:
return future_maybe.result()
except AttributeError:
if isinstance(future_maybe, Sequence):
if isinstance(future_maybe, Sequence) and not isinstance(
future_maybe, str
):
aa = []
for a in future_maybe:
aa.append(result_maybe(a, top=False))
aa.append(result_maybe(a))
if isinstance(future_maybe, tuple):
aa = tuple(aa)
return aa
return future_maybe



def delayed_execution(func):
@wraps(func)
def inner(*args, **kwargs):
Expand Down
42 changes: 31 additions & 11 deletions streamz_ext/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .core import Stream

from collections import Sequence
from toolz import pluck as _pluck

NULL_COMPUTE = "~~NULL_COMPUTE~~"

Expand Down Expand Up @@ -142,12 +143,12 @@ class gather(ParallelStream):
def update(self, x, who=None):
client = self.default_client()
result = yield client.gather(x, asynchronous=True)
if not (
(
if (
not (
isinstance(result, Sequence)
and any(r == NULL_COMPUTE for r in result)
)
or result == NULL_COMPUTE
and result != NULL_COMPUTE
):
result2 = yield self._emit(result)
raise gen.Return(result2)
Expand Down Expand Up @@ -208,7 +209,7 @@ def update(self, x, who=None):
@ParallelStream.register_api()
class starmap(ParallelStream):
def __init__(self, upstream, func, *args, **kwargs):
self.func = filter_null_wrapper(func)
self.func = func
stream_name = kwargs.pop("stream_name", None)
self.kwargs = kwargs
self.args = args
Expand All @@ -217,7 +218,13 @@ def __init__(self, upstream, func, *args, **kwargs):

def update(self, x: Future, who=None):
client = self.default_client()
result = client.submit(apply, self.func, x, self.args, self.kwargs)
result = client.submit(
filter_null_wrapper(apply),
filter_null_wrapper(self.func),
x,
self.args,
self.kwargs,
)
return self._emit(result)


Expand All @@ -240,6 +247,25 @@ def update(self, x, who=None):
return self._emit(result)


@args_kwargs
@ParallelStream.register_api()
class pluck(ParallelStream):
def __init__(self, upstream, pick, **kwargs):
self.pick = pick
super().__init__(upstream, **kwargs)

def update(self, x, who=None):
client = self.default_client()
if isinstance(self.pick, Sequence):
return self._emit(
client.submit(filter_null_wrapper(_pluck), self.pick, x)
)
else:
return self._emit(
client.submit(filter_null_wrapper(getitem), x, self.pick)
)


@args_kwargs
@ParallelStream.register_api()
class buffer(ParallelStream, core.buffer):
Expand Down Expand Up @@ -316,9 +342,3 @@ class filenames(ParallelStream, sources.filenames):
@ParallelStream.register_api(staticmethod)
class from_textfile(ParallelStream, sources.from_textfile):
pass


@args_kwargs
@ParallelStream.register_api()
class pluck(ParallelStream, core.pluck):
pass
90 changes: 89 additions & 1 deletion test/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def test_buffer(backend):
def test_filter(backend):
source = Stream(asynchronous=True)
futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0)
print(type(futures))
futures_L = futures.sink_to_list()
L = futures.gather().sink_to_list()

Expand All @@ -200,6 +199,23 @@ def test_filter(backend):
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_buffer(backend):
source = Stream(asynchronous=True)
futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0)
futures_L = futures.sink_to_list()
L = futures.buffer(10).gather().sink_to_list()

for i in range(5):
yield source.emit(i)
while len(L) < 3:
yield gen.sleep(.01)

assert L == [0, 2, 4]
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_map(backend):
Expand All @@ -217,6 +233,38 @@ def test_filter_map(backend):
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_starmap(backend):
source = Stream(asynchronous=True)
futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0)
futures = futures1.starmap(add)
futures_L = futures.sink_to_list()
L = futures.gather().sink_to_list()

for i in range(5):
yield source.emit((i, i))

assert L == [0, 4, 8]
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_pluck(backend):
source = Stream(asynchronous=True)
futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0)
futures = futures1.pluck(0)
futures_L = futures.sink_to_list()
L = futures.gather().sink_to_list()

for i in range(5):
yield source.emit((i, i))

assert L == [0, 2, 4]
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_zip(backend):
Expand Down Expand Up @@ -256,3 +304,43 @@ def test_double_scatter(backend):

assert L == [i + i for i in range(5)]
assert all(isinstance(f, Future) for f in futures_L)


@pytest.mark.parametrize("backend", test_params)
@gen_test
def test_pluck(backend):
source = Stream(asynchronous=True)
L = source.scatter(backend=backend).pluck(0).gather().sink_to_list()

for i in range(5):
yield source.emit((i, i))

assert L == list(range(5))


@pytest.mark.parametrize("backend", test_params)
@gen_test
def test_combined_latest(backend):
def delay(x):
time.sleep(.5)
return x
source = Stream(asynchronous=True)
source2 = Stream(asynchronous=True)
futures = source.scatter(backend=backend).map(delay).combine_latest(
source2.scatter(backend=backend), emit_on=1)
futures_L = futures.sink_to_list()
L = (
futures
.buffer(10)
.gather()
.sink_to_list()
)

for i in range(5):
yield source.emit(i)
yield source.emit(i)
yield source2.emit(i)

while len(L) < len(futures_L):
yield gen.sleep(.01)
assert L == [(i, i) for i in range(5)]

0 comments on commit e781312

Please sign in to comment.