forked from optuna/optuna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_convert_positional_args.py
127 lines (96 loc) · 4.21 KB
/
test_convert_positional_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import re
from typing import List
import pytest
from optuna._convert_positional_args import convert_positional_args
def _sample_func(*, a: int, b: int, c: int) -> int:
return a + b + c
class _SimpleClass:
    """Fixture class exposing one method wrapped by ``convert_positional_args``."""

    @convert_positional_args(previous_positional_arg_names=["self", "a", "b"])
    def simple_method(self, a: int, *, b: int, c: int = 1) -> None:
        """Do nothing.

        ``b`` is keyword-only in this signature but is listed among the
        previously positional names, alongside ``self`` and ``a``.
        """
        return None
def test_convert_positional_args_decorator() -> None:
    """The decorated function reports the same ``__name__`` as the wrapped one."""
    # No argument is declared as previously positional in this test.
    no_positional_names: List[str] = []
    wrap = convert_positional_args(previous_positional_arg_names=no_positional_names)

    wrapped = wrap(_sample_func)

    assert _sample_func.__name__ == wrapped.__name__
def test_convert_positional_args_future_warning_for_methods() -> None:
    """Exactly one of three method calls is expected to emit a ``FutureWarning``.

    The warning text is expected to mention ``simple_method``.
    """
    instance = _SimpleClass()

    with pytest.warns(FutureWarning) as caught:
        instance.simple_method(1, 2, c=3)  # type: ignore
        instance.simple_method(1, b=2, c=3)  # No warning.
        instance.simple_method(a=1, b=2, c=3)  # No warning.

    assert len(caught) == 1
    for entry in caught.list:
        message = entry.message
        assert isinstance(message, FutureWarning)
        assert "simple_method" in str(message)
def test_convert_positional_args_future_warning() -> None:
    """Two of three calls to the decorated function are expected to warn.

    ``a`` and ``b`` are declared as previously positional; each recorded
    ``FutureWarning`` is expected to mention the wrapped function's name.
    """
    legacy_names: List[str] = ["a", "b"]
    wrap = convert_positional_args(previous_positional_arg_names=legacy_names)
    assert callable(wrap)

    wrapped = wrap(_sample_func)

    with pytest.warns(FutureWarning) as caught:
        wrapped(1, 2, c=3)  # type: ignore
        wrapped(1, b=2, c=3)  # type: ignore
        wrapped(a=1, b=2, c=3)  # No warning.

    assert len(caught) == 2
    expected_name = _sample_func.__name__
    for entry in caught.list:
        message = entry.message
        assert isinstance(message, FutureWarning)
        assert expected_name in str(message)
def test_convert_positional_args_mypy_type_inference() -> None:
    """Return values pass through the decorator unchanged.

    Covers a function returning ``None`` and one returning a class instance
    whose method is then called on the result.
    """
    no_positional_names: List[str] = []
    decorator_converter = convert_positional_args(
        previous_positional_arg_names=no_positional_names
    )
    assert callable(decorator_converter)

    class _Sample:
        def __init__(self) -> None:
            pass

        def method(self) -> bool:
            return True

    def _func_sample() -> _Sample:
        return _Sample()

    def _func_none() -> None:
        pass

    # Decorate first, then call, so each step is visible on its own line.
    wrapped_none = decorator_converter(_func_none)
    none_result = wrapped_none()
    assert none_result is None

    wrapped_sample = decorator_converter(_func_sample)
    sample_result = wrapped_sample()
    assert isinstance(sample_result, _Sample)
    assert sample_result.method()
@pytest.mark.parametrize(
    "previous_positional_arg_names, raise_error",
    [(["a", "b", "c", "d"], True), (["a", "d"], True), (["b", "a"], False)],
)
def test_convert_positional_args_invalid_previous_positional_arg_names(
    previous_positional_arg_names: List[str], raise_error: bool
) -> None:
    """Decorating ``_sample_func`` is checked against each parametrized name list.

    When ``raise_error`` is true, decorating is expected to raise an
    ``AssertionError`` whose message contains two set reprs: first the given
    names, then ``{"a", "b", "c"}``. Otherwise decorating must succeed.
    """
    wrap = convert_positional_args(
        previous_positional_arg_names=previous_positional_arg_names
    )
    assert callable(wrap)

    if not raise_error:
        wrap(_sample_func)
        return

    with pytest.raises(AssertionError) as exc_info:
        wrap(_sample_func)

    # Pull every ``{...}`` or ``set()`` repr out of the assertion message.
    reported_sets = re.findall(r"({.+?}|set\(\))", str(exc_info.value))
    assert len(reported_sets) == 2
    # NOTE: ``eval`` only parses set reprs from the library's own message here.
    assert eval(reported_sets[0]) == set(previous_positional_arg_names)
    assert eval(reported_sets[1]) == {"a", "b", "c"}
def test_convert_positional_args_invalid_positional_args() -> None:
    """Invalid positional calls raise ``TypeError`` with the expected messages.

    Two cases are checked: more positional arguments than the declared
    names, and a value for ``b`` given both positionally and by keyword.
    """
    legacy_names: List[str] = ["a", "b"]
    wrap = convert_positional_args(previous_positional_arg_names=legacy_names)
    assert callable(wrap)

    wrapped = wrap(_sample_func)

    too_many_msg = "_sample_func() takes 2 positional arguments but 3 were given."
    duplicate_msg = "_sample_func() got multiple values for arguments {'b'}."

    # A single ``warns`` context wraps both failing calls.
    with pytest.warns(FutureWarning):
        with pytest.raises(TypeError) as too_many:
            wrapped(1, 2, 3)  # type: ignore
        assert str(too_many.value) == too_many_msg

        with pytest.raises(TypeError) as duplicate:
            wrapped(1, 3, b=2)  # type: ignore
        assert str(duplicate.value) == duplicate_msg