From 0c7976c417eca5d69b49d695a3f3b5f9bd142ed1 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <44680962+15r10nk@users.noreply.github.com> Date: Wed, 12 Jul 2023 18:34:06 +0200 Subject: [PATCH] fix: snapshot the current value of mutable objects (#12) BREAKING CHANGE: values have to be copyable with `copy.deepcopy` This is a behaviour which is already expected from other libraries. https://github.com/syrusakbary/snapshottest/issues/99 --- docs/index.md | 9 +-- inline_snapshot/_inline_snapshot.py | 7 +++ tests/test_inline_snapshot.py | 92 ++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/docs/index.md b/docs/index.md index 436dad12..862ab23b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -210,10 +210,11 @@ def test_something(): The code is generated in the following way: -1. The code is generated with `repr(value)` -2. Strings which contain newlines are converted to triple quoted strings. -3. The code is formatted with black. -4. The whole file is formatted with black if it was formatted before. +1. The value is copied with `value = copy.deepcopy(value)` +2. The code is generated with `repr(value)` +3. Strings which contain newlines are converted to triple quoted strings. +4. The code is formatted with black. +5. The whole file is formatted with black if it was formatted before. !!! note The black formatting of the whole file could not work for the following reasons: diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py index 3d781dd9..85a0285c 100644 --- a/inline_snapshot/_inline_snapshot.py +++ b/inline_snapshot/_inline_snapshot.py @@ -1,5 +1,6 @@ import ast import contextlib +import copy import inspect import io import token @@ -151,6 +152,8 @@ def __getitem__(self, item): class FixValue(GenericValue): def __eq__(self, other): + other = copy.deepcopy(other) + if self._new_value is undefined: self._new_value = other @@ -171,6 +174,8 @@ def cmp(a, b): raise NotImplemented def _generic_cmp(self, other): + other = copy.deepcopy(other) + if self._new_value is undefined: self._new_value = other else: @@ -258,6 +263,8 @@ def cmp(a, b): class CollectionValue(GenericValue): def __contains__(self, item): + item = copy.deepcopy(item) + if self._new_value is undefined: self._new_value = [item] else: diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index 3e79acab..f133c1a6 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -34,7 +34,7 @@ def test_disabled(): def check_update(tmp_path): filecount = 1 - def w(source, *, flags="", reported_flags=None): + def w(source, *, flags="", reported_flags=None, number=1): flags = Flags({*flags.split(",")}) if reported_flags is None: reported_flags = flags @@ -65,7 +65,7 @@ def w(source, *, flags="", reported_flags=None): finally: _inline_snapshot._active = False - assert len(_inline_snapshot.snapshots) == 1 + assert len(_inline_snapshot.snapshots) == number snapshot_flags = set() @@ -77,7 +77,7 @@ def w(source, *, flags="", reported_flags=None): changes = recorder.changes() - assert len(changes) == 1 + assert len(changes) == number print("changes:") recorder.dump() @@ -88,6 +88,92 @@ def w(source, *, flags="", reported_flags=None): return w +def test_mutable_values(check_update): + assert ( + check_update( + """ +l=[1,2] +assert l==snapshot() +l.append(3) +assert l==snapshot() + """, + flags="create", + number=2, + ) + == snapshot( + """ +l=[1,2] +assert l==snapshot([1, 2]) +l.append(3) +assert l==snapshot([1, 2, 3]) +""" + ) + ) + + assert ( + check_update( + """ +l=[1,2] +assert l<=snapshot() +l.append(3) +assert l<=snapshot() + """, + flags="create", + number=2, + ) + == snapshot( + """ +l=[1,2] +assert l<=snapshot([1, 2]) +l.append(3) +assert l<=snapshot([1, 2, 3]) +""" + ) + ) + + assert ( + check_update( + """ +l=[1,2] +assert l>=snapshot() +l.append(3) +assert l>=snapshot() + """, + flags="create", + number=2, + ) + == snapshot( + """ +l=[1,2] +assert l>=snapshot([1, 2]) +l.append(3) +assert l>=snapshot([1, 2, 3]) +""" + ) + ) + + assert ( + check_update( + """ +l=[1,2] +assert l in snapshot() +l.append(3) +assert l in snapshot() + """, + flags="create", + number=2, + ) + == snapshot( + """ +l=[1,2] +assert l in snapshot([[1, 2]]) +l.append(3) +assert l in snapshot([[1, 2, 3]]) +""" + ) + ) + + def test_comparison(check_update): assert check_update("assert 5==snapshot()", flags="create") == snapshot( "assert 5==snapshot(5)"