Skip to content

Commit 4ca4345

Browse files
committed
Better input validation for InputLayer with input_tensor provided
1 parent b6d305f commit 4ca4345

File tree

3 files changed

+118
-41
lines changed

3 files changed

+118
-41
lines changed

keras/src/layers/core/input_layer.py

+62-25
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
):
2323
# TODO: support for ragged.
2424
super().__init__(name=name)
25+
2526
if "input_shape" in kwargs:
2627
warnings.warn(
2728
"Argument `input_shape` is deprecated. Use `shape` instead."
@@ -30,40 +31,76 @@ def __init__(
3031
if "batch_input_shape" in kwargs:
3132
batch_shape = kwargs.pop("batch_input_shape")
3233

33-
if shape is not None and batch_shape is not None:
34-
raise ValueError(
35-
"You cannot pass both `shape` and `batch_shape` at the "
36-
"same time."
37-
)
38-
if batch_size is not None and batch_shape is not None:
39-
raise ValueError(
40-
"You cannot pass both `batch_size` and `batch_shape` at the "
41-
"same time."
42-
)
43-
if shape is None and batch_shape is None:
44-
raise ValueError("You must pass a `shape` argument.")
34+
if input_tensor is not None:
35+
if not isinstance(input_tensor, backend.KerasTensor):
36+
raise ValueError(
37+
"Argument `input_tensor` must be a KerasTensor. "
38+
f"Received invalid type: input_tensor={input_tensor} "
39+
f"(of type {type(input_tensor)})"
40+
)
41+
if batch_size is not None:
42+
if (
43+
len(input_tensor.shape) < 1
44+
or input_tensor.shape[0] != batch_size
45+
):
46+
raise ValueError(
47+
"When providing the `input_tensor` argument, you "
48+
"cannot provide an incompatible `batch_size` argument."
49+
)
50+
if shape is not None:
51+
if (
52+
len(shape) != len(input_tensor.shape) - 1
53+
or shape != input_tensor.shape[1:]
54+
):
55+
raise ValueError(
56+
"When providing the `input_tensor` argument, you "
57+
"cannot provide an incompatible `shape` argument."
58+
)
59+
if batch_shape is not None and batch_shape != input_tensor.shape:
60+
raise ValueError(
61+
"When providing the `input_tensor` argument, you "
62+
"cannot provide an incompatible `batch_shape` argument."
63+
)
64+
if dtype is not None and input_tensor.dtype != dtype:
65+
raise ValueError(
66+
"When providing the `input_tensor` argument, you "
67+
"cannot provide an incompatible `dtype` argument."
68+
)
69+
if sparse is not None and input_tensor.sparse != sparse:
70+
raise ValueError(
71+
"When providing the `input_tensor` argument, you "
72+
"cannot provide an incompatible `sparse` argument."
73+
)
74+
batch_shape = input_tensor.shape
75+
dtype = input_tensor.dtype
76+
sparse = input_tensor.sparse
77+
else:
78+
if shape is not None and batch_shape is not None:
79+
raise ValueError(
80+
"You cannot pass both `shape` and `batch_shape` at the "
81+
"same time."
82+
)
83+
if batch_size is not None and batch_shape is not None:
84+
raise ValueError(
85+
"You cannot pass both `batch_size` and `batch_shape` "
86+
"at the same time."
87+
)
88+
if shape is None and batch_shape is None:
89+
raise ValueError("You must pass a `shape` argument.")
90+
91+
if shape is not None:
92+
shape = backend.standardize_shape(shape)
93+
batch_shape = (batch_size,) + shape
4594

46-
if shape is not None:
47-
shape = backend.standardize_shape(shape)
48-
batch_shape = (batch_size,) + shape
4995
self._batch_shape = backend.standardize_shape(batch_shape)
5096
self._dtype = backend.standardize_dtype(dtype)
51-
5297
self.sparse = bool(sparse)
5398
if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
5499
raise ValueError(
55100
"`sparse=True` is not supported with backend: "
56101
f"{backend.backend()}"
57102
)
58-
59-
if input_tensor is not None:
60-
if not isinstance(input_tensor, backend.KerasTensor):
61-
raise ValueError(
62-
"Argument `input_tensor` must be a KerasTensor. "
63-
f"Received invalid type: input_tensor={input_tensor} "
64-
f"(of type {type(input_tensor)})"
65-
)
66-
else:
103+
if input_tensor is None:
67104
input_tensor = backend.KerasTensor(
68105
shape=batch_shape, dtype=dtype, sparse=sparse, name=name
69106
)

keras/src/layers/core/input_layer_test.py

+56-13
Original file line numberDiff line numberDiff line change
@@ -89,25 +89,20 @@ def test_input_tensor_error(self):
8989
# Testing happy path for layer with input tensor
9090
def testing_input_tensor(self):
9191
input_shape = (2, 3)
92-
batch_size = 4
9392
dtype = "float32"
9493
input_tensor = KerasTensor(shape=input_shape, dtype=dtype)
9594

96-
values = InputLayer(
97-
shape=input_shape,
98-
batch_size=batch_size,
95+
layer = InputLayer(
9996
input_tensor=input_tensor,
100-
dtype=dtype,
10197
)
10298

103-
self.assertEqual(values.dtype, dtype)
104-
self.assertEqual(values.batch_shape[0], batch_size)
105-
self.assertEqual(values.batch_shape[1:], input_shape)
106-
self.assertEqual(values.trainable, True)
107-
self.assertIsInstance(values.output, KerasTensor)
108-
self.assertEqual(values.output, input_tensor)
109-
self.assertEqual(values.output.ndim, input_tensor.ndim)
110-
self.assertEqual(values.output.dtype, dtype)
99+
self.assertEqual(layer.dtype, dtype)
100+
self.assertEqual(layer.batch_shape, (2, 3))
101+
self.assertEqual(layer.trainable, True)
102+
self.assertIsInstance(layer.output, KerasTensor)
103+
self.assertEqual(layer.output, input_tensor)
104+
self.assertEqual(layer.output.ndim, input_tensor.ndim)
105+
self.assertEqual(layer.output.dtype, dtype)
111106

112107
def test_input_shape_deprecated(self):
113108
input_shape = (2, 3)
@@ -135,3 +130,51 @@ def test_call_method(self):
135130
def test_numpy_shape(self):
136131
# non-python int type shapes should be ok
137132
InputLayer(shape=(np.int64(32),))
133+
134+
def test_invalid_arg_combinations(self):
135+
input_tensor = KerasTensor(shape=(2, 3), dtype="float32")
136+
137+
with self.assertRaisesRegex(
138+
ValueError, "cannot provide an incompatible `shape`"
139+
):
140+
_ = InputLayer(
141+
shape=(2, 4),
142+
input_tensor=input_tensor,
143+
)
144+
with self.assertRaisesRegex(
145+
ValueError, "cannot provide an incompatible `batch_shape`"
146+
):
147+
_ = InputLayer(
148+
batch_shape=(2, 4),
149+
input_tensor=input_tensor,
150+
)
151+
with self.assertRaisesRegex(
152+
ValueError, "cannot provide an incompatible `batch_size`"
153+
):
154+
_ = InputLayer(
155+
batch_size=5,
156+
input_tensor=input_tensor,
157+
)
158+
with self.assertRaisesRegex(
159+
ValueError, "cannot provide an incompatible `dtype`"
160+
):
161+
_ = InputLayer(
162+
dtype="float16",
163+
input_tensor=input_tensor,
164+
)
165+
with self.assertRaisesRegex(
166+
ValueError, "cannot provide an incompatible `sparse`"
167+
):
168+
_ = InputLayer(
169+
sparse=True,
170+
input_tensor=input_tensor,
171+
)
172+
173+
# This works
174+
_ = InputLayer(
175+
shape=(3,),
176+
batch_size=2,
177+
sparse=False,
178+
dtype="float32",
179+
input_tensor=input_tensor,
180+
)

keras/src/models/cloning.py

-3
Original file line numberDiff line numberDiff line change
@@ -312,15 +312,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
312312
)
313313
inputs = Input(
314314
tensor=input_tensors,
315-
batch_shape=input_tensors.shape,
316-
dtype=input_tensors.dtype,
317315
name=input_name,
318316
)
319317
new_layers = [inputs] + new_layers
320318
else:
321319
if input_batch_shape is not None:
322320
inputs = Input(
323-
tensor=input_tensors,
324321
batch_shape=input_batch_shape,
325322
dtype=input_dtype,
326323
name=input_name,

0 commit comments

Comments
 (0)