Skip to content

Commit

Permalink
fix: snapshot the current value of mutable objects (#12)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: values have to be copyable with `copy.deepcopy`

This is a behaviour which is already expected from other libraries.

syrusakbary/snapshottest#99
  • Loading branch information
15r10nk authored Jul 12, 2023
1 parent 6ca0e21 commit 0c7976c
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 7 deletions.
9 changes: 5 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ def test_something():

The code is generated in the following way:

1. The code is generated with `repr(value)`
2. Strings which contain newlines are converted to triple quoted strings.
3. The code is formatted with black.
4. The whole file is formatted with black if it was formatted before.
1. The value is copied with `value = copy.deepcopy(value)`
2. The code is generated with `repr(value)`
3. Strings which contain newlines are converted to triple quoted strings.
4. The code is formatted with black.
5. The whole file is formatted with black if it was formatted before.

!!! note
The black formatting of the whole file could not work for the following reasons:
Expand Down
7 changes: 7 additions & 0 deletions inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import contextlib
import copy
import inspect
import io
import token
Expand Down Expand Up @@ -151,6 +152,8 @@ def __getitem__(self, item):

class FixValue(GenericValue):
def __eq__(self, other):
other = copy.deepcopy(other)

if self._new_value is undefined:
self._new_value = other

Expand All @@ -171,6 +174,8 @@ def cmp(a, b):
raise NotImplemented

def _generic_cmp(self, other):
other = copy.deepcopy(other)

if self._new_value is undefined:
self._new_value = other
else:
Expand Down Expand Up @@ -258,6 +263,8 @@ def cmp(a, b):

class CollectionValue(GenericValue):
def __contains__(self, item):
item = copy.deepcopy(item)

if self._new_value is undefined:
self._new_value = [item]
else:
Expand Down
92 changes: 89 additions & 3 deletions tests/test_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_disabled():
def check_update(tmp_path):
filecount = 1

def w(source, *, flags="", reported_flags=None):
def w(source, *, flags="", reported_flags=None, number=1):
flags = Flags({*flags.split(",")})
if reported_flags is None:
reported_flags = flags
Expand Down Expand Up @@ -65,7 +65,7 @@ def w(source, *, flags="", reported_flags=None):
finally:
_inline_snapshot._active = False

assert len(_inline_snapshot.snapshots) == 1
assert len(_inline_snapshot.snapshots) == number

snapshot_flags = set()

Expand All @@ -77,7 +77,7 @@ def w(source, *, flags="", reported_flags=None):

changes = recorder.changes()

assert len(changes) == 1
assert len(changes) == number

print("changes:")
recorder.dump()
Expand All @@ -88,6 +88,92 @@ def w(source, *, flags="", reported_flags=None):
return w


def test_mutable_values(check_update):
assert (
check_update(
"""
l=[1,2]
assert l==snapshot()
l.append(3)
assert l==snapshot()
""",
flags="create",
number=2,
)
== snapshot(
"""
l=[1,2]
assert l==snapshot([1, 2])
l.append(3)
assert l==snapshot([1, 2, 3])
"""
)
)

assert (
check_update(
"""
l=[1,2]
assert l<=snapshot()
l.append(3)
assert l<=snapshot()
""",
flags="create",
number=2,
)
== snapshot(
"""
l=[1,2]
assert l<=snapshot([1, 2])
l.append(3)
assert l<=snapshot([1, 2, 3])
"""
)
)

assert (
check_update(
"""
l=[1,2]
assert l>=snapshot()
l.append(3)
assert l>=snapshot()
""",
flags="create",
number=2,
)
== snapshot(
"""
l=[1,2]
assert l>=snapshot([1, 2])
l.append(3)
assert l>=snapshot([1, 2, 3])
"""
)
)

assert (
check_update(
"""
l=[1,2]
assert l in snapshot()
l.append(3)
assert l in snapshot()
""",
flags="create",
number=2,
)
== snapshot(
"""
l=[1,2]
assert l in snapshot([[1, 2]])
l.append(3)
assert l in snapshot([[1, 2, 3]])
"""
)
)


def test_comparison(check_update):
assert check_update("assert 5==snapshot()", flags="create") == snapshot(
"assert 5==snapshot(5)"
Expand Down

0 comments on commit 0c7976c

Please sign in to comment.