@@ -22,6 +22,7 @@ def __init__(
22
22
):
23
23
# TODO: support for ragged.
24
24
super ().__init__ (name = name )
25
+
25
26
if "input_shape" in kwargs :
26
27
warnings .warn (
27
28
"Argument `input_shape` is deprecated. Use `shape` instead."
@@ -30,40 +31,76 @@ def __init__(
30
31
if "batch_input_shape" in kwargs :
31
32
batch_shape = kwargs .pop ("batch_input_shape" )
32
33
33
- if shape is not None and batch_shape is not None :
34
- raise ValueError (
35
- "You cannot pass both `shape` and `batch_shape` at the "
36
- "same time."
37
- )
38
- if batch_size is not None and batch_shape is not None :
39
- raise ValueError (
40
- "You cannot pass both `batch_size` and `batch_shape` at the "
41
- "same time."
42
- )
43
- if shape is None and batch_shape is None :
44
- raise ValueError ("You must pass a `shape` argument." )
34
+ if input_tensor is not None :
35
+ if not isinstance (input_tensor , backend .KerasTensor ):
36
+ raise ValueError (
37
+ "Argument `input_tensor` must be a KerasTensor. "
38
+ f"Received invalid type: input_tensor={ input_tensor } "
39
+ f"(of type { type (input_tensor )} )"
40
+ )
41
+ if batch_size is not None :
42
+ if (
43
+ len (input_tensor .shape ) < 1
44
+ or input_tensor .shape [0 ] != batch_size
45
+ ):
46
+ raise ValueError (
47
+ "When providing the `input_tensor` argument, you "
48
+ "cannot provide an incompatible `batch_size` argument."
49
+ )
50
+ if shape is not None :
51
+ if (
52
+ len (shape ) != len (input_tensor .shape ) - 1
53
+ or shape != input_tensor .shape [1 :]
54
+ ):
55
+ raise ValueError (
56
+ "When providing the `input_tensor` argument, you "
57
+ "cannot provide an incompatible `shape` argument."
58
+ )
59
+ if batch_shape is not None and batch_shape != input_tensor .shape :
60
+ raise ValueError (
61
+ "When providing the `input_tensor` argument, you "
62
+ "cannot provide an incompatible `batch_shape` argument."
63
+ )
64
+ if dtype is not None and input_tensor .dtype != dtype :
65
+ raise ValueError (
66
+ "When providing the `input_tensor` argument, you "
67
+ "cannot provide an incompatible `dtype` argument."
68
+ )
69
+ if sparse is not None and input_tensor .sparse != sparse :
70
+ raise ValueError (
71
+ "When providing the `input_tensor` argument, you "
72
+ "cannot provide an incompatible `sparse` argument."
73
+ )
74
+ batch_shape = input_tensor .shape
75
+ dtype = input_tensor .dtype
76
+ sparse = input_tensor .sparse
77
+ else :
78
+ if shape is not None and batch_shape is not None :
79
+ raise ValueError (
80
+ "You cannot pass both `shape` and `batch_shape` at the "
81
+ "same time."
82
+ )
83
+ if batch_size is not None and batch_shape is not None :
84
+ raise ValueError (
85
+ "You cannot pass both `batch_size` and `batch_shape` "
86
+ "at the same time."
87
+ )
88
+ if shape is None and batch_shape is None :
89
+ raise ValueError ("You must pass a `shape` argument." )
90
+
91
+ if shape is not None :
92
+ shape = backend .standardize_shape (shape )
93
+ batch_shape = (batch_size ,) + shape
45
94
46
- if shape is not None :
47
- shape = backend .standardize_shape (shape )
48
- batch_shape = (batch_size ,) + shape
49
95
self ._batch_shape = backend .standardize_shape (batch_shape )
50
96
self ._dtype = backend .standardize_dtype (dtype )
51
-
52
97
self .sparse = bool (sparse )
53
98
if self .sparse and not backend .SUPPORTS_SPARSE_TENSORS :
54
99
raise ValueError (
55
100
"`sparse=True` is not supported with backend: "
56
101
f"{ backend .backend ()} "
57
102
)
58
-
59
- if input_tensor is not None :
60
- if not isinstance (input_tensor , backend .KerasTensor ):
61
- raise ValueError (
62
- "Argument `input_tensor` must be a KerasTensor. "
63
- f"Received invalid type: input_tensor={ input_tensor } "
64
- f"(of type { type (input_tensor )} )"
65
- )
66
- else :
103
+ if input_tensor is None :
67
104
input_tensor = backend .KerasTensor (
68
105
shape = batch_shape , dtype = dtype , sparse = sparse , name = name
69
106
)
0 commit comments