diff --git a/score.yaml b/score.yaml index e1a1d72..6a19400 100644 --- a/score.yaml +++ b/score.yaml @@ -12,11 +12,11 @@ build: default: conda: python run: - streamz: + zstreamz: default: - conda: streamz + conda: zstreamz bleeding: - pip: 'git+https://github.com/mrocklin/streamz.git#egg=streamz' + pip: 'git+https://github.com/xpdAcq/zstreamz.git#egg=zstreamz' toolz: default: conda: toolz diff --git a/streamz_ext/clients.py b/streamz_ext/clients.py index 25b60e2..5ca35a1 100644 --- a/streamz_ext/clients.py +++ b/streamz_ext/clients.py @@ -9,21 +9,22 @@ from .core import identity -def result_maybe(future_maybe, top=False): +def result_maybe(future_maybe): try: return future_maybe.result() except AttributeError: - if isinstance(future_maybe, Sequence): + if isinstance(future_maybe, Sequence) and not isinstance( + future_maybe, str + ): aa = [] for a in future_maybe: - aa.append(result_maybe(a, top=False)) + aa.append(result_maybe(a)) if isinstance(future_maybe, tuple): aa = tuple(aa) return aa return future_maybe - def delayed_execution(func): @wraps(func) def inner(*args, **kwargs): diff --git a/streamz_ext/parallel.py b/streamz_ext/parallel.py index 7f671ba..d6fca62 100644 --- a/streamz_ext/parallel.py +++ b/streamz_ext/parallel.py @@ -13,6 +13,7 @@ from .core import Stream from collections import Sequence +from toolz import pluck as _pluck NULL_COMPUTE = "~~NULL_COMPUTE~~" @@ -142,12 +143,12 @@ class gather(ParallelStream): def update(self, x, who=None): client = self.default_client() result = yield client.gather(x, asynchronous=True) - if not ( - ( + if ( + not ( isinstance(result, Sequence) and any(r == NULL_COMPUTE for r in result) ) - or result == NULL_COMPUTE + and result != NULL_COMPUTE ): result2 = yield self._emit(result) raise gen.Return(result2) @@ -208,7 +209,7 @@ def update(self, x, who=None): @ParallelStream.register_api() class starmap(ParallelStream): def __init__(self, upstream, func, *args, **kwargs): - self.func = filter_null_wrapper(func) + self.func = func stream_name = kwargs.pop("stream_name", None) self.kwargs = kwargs self.args = args @@ -217,7 +218,13 @@ def __init__(self, upstream, func, *args, **kwargs): def update(self, x: Future, who=None): client = self.default_client() - result = client.submit(apply, self.func, x, self.args, self.kwargs) + result = client.submit( + filter_null_wrapper(apply), + filter_null_wrapper(self.func), + x, + self.args, + self.kwargs, + ) return self._emit(result) @@ -240,6 +247,25 @@ def update(self, x, who=None): return self._emit(result) +@args_kwargs +@ParallelStream.register_api() +class pluck(ParallelStream): + def __init__(self, upstream, pick, **kwargs): + self.pick = pick + super().__init__(upstream, **kwargs) + + def update(self, x, who=None): + client = self.default_client() + if isinstance(self.pick, Sequence): + return self._emit( + client.submit(filter_null_wrapper(_pluck), self.pick, x) + ) + else: + return self._emit( + client.submit(filter_null_wrapper(getitem), x, self.pick) + ) + + @args_kwargs @ParallelStream.register_api() class buffer(ParallelStream, core.buffer): @@ -316,9 +342,3 @@ class filenames(ParallelStream, sources.filenames): @ParallelStream.register_api(staticmethod) class from_textfile(ParallelStream, sources.from_textfile): pass - - -@args_kwargs -@ParallelStream.register_api() -class pluck(ParallelStream, core.pluck): - pass diff --git a/test/test_parallel.py b/test/test_parallel.py index 9d0d01a..e6999ef 100644 --- a/test/test_parallel.py +++ b/test/test_parallel.py @@ -189,7 +189,6 @@ def test_buffer(backend): def test_filter(backend): source = Stream(asynchronous=True) futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) - print(type(futures)) futures_L = futures.sink_to_list() L = futures.gather().sink_to_list() @@ -200,6 +199,23 @@ def test_filter(backend): assert all(isinstance(f, Future) for f in futures_L) +@pytest.mark.parametrize("backend", test_params) +@gen_test() +def test_filter_buffer(backend): + source = Stream(asynchronous=True) + futures = scatter(source, backend=backend).filter(lambda x: x % 2 == 0) + futures_L = futures.sink_to_list() + L = futures.buffer(10).gather().sink_to_list() + + for i in range(5): + yield source.emit(i) + while len(L) < 3: + yield gen.sleep(.01) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + @pytest.mark.parametrize("backend", test_params) @gen_test() def test_filter_map(backend): @@ -217,6 +233,38 @@ def test_filter_map(backend): assert all(isinstance(f, Future) for f in futures_L) +@pytest.mark.parametrize("backend", test_params) +@gen_test() +def test_filter_starmap(backend): + source = Stream(asynchronous=True) + futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0) + futures = futures1.starmap(add) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit((i, i)) + + assert L == [0, 4, 8] + assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.parametrize("backend", test_params) +@gen_test() +def test_filter_pluck(backend): + source = Stream(asynchronous=True) + futures1 = scatter(source, backend=backend).filter(lambda x: x[1] % 2 == 0) + futures = futures1.pluck(0) + futures_L = futures.sink_to_list() + L = futures.gather().sink_to_list() + + for i in range(5): + yield source.emit((i, i)) + + assert L == [0, 2, 4] + assert all(isinstance(f, Future) for f in futures_L) + + @pytest.mark.parametrize("backend", test_params) @gen_test() def test_filter_zip(backend): @@ -256,3 +304,43 @@ def test_double_scatter(backend): assert L == [i + i for i in range(5)] assert all(isinstance(f, Future) for f in futures_L) + + +@pytest.mark.parametrize("backend", test_params) +@gen_test +def test_pluck(backend): + source = Stream(asynchronous=True) + L = source.scatter(backend=backend).pluck(0).gather().sink_to_list() + + for i in range(5): + yield source.emit((i, i)) + + assert L == list(range(5)) + + +@pytest.mark.parametrize("backend", test_params) +@gen_test +def test_combined_latest(backend): + def delay(x): + time.sleep(.5) + return x + source = Stream(asynchronous=True) + source2 = Stream(asynchronous=True) + futures = source.scatter(backend=backend).map(delay).combine_latest( + source2.scatter(backend=backend), emit_on=1) + futures_L = futures.sink_to_list() + L = ( + futures + .buffer(10) + .gather() + .sink_to_list() + ) + + for i in range(5): + yield source.emit(i) + yield source.emit(i) + yield source2.emit(i) + + while len(L) < len(futures_L): + yield gen.sleep(.01) + assert L == [(i, i) for i in range(5)]