23
23
import keras .engine as KE
24
24
import keras .models as KM
25
25
26
- config = tf .compat .v1 .ConfigProto ()
27
- config .gpu_options .allow_growth = True
28
- tf .compat .v1 .keras .backend .set_session (tf .compat .v1 .Session (config = config ))
29
-
30
26
from mrcnn import utils
31
27
32
28
# Requires TensorFlow 1.3+ and Keras 2.0.8+.
@@ -2169,14 +2165,16 @@ def compile(self, learning_rate, momentum):
2169
2165
loss_names = [
2170
2166
"rpn_class_loss" , "rpn_bbox_loss" ,
2171
2167
"mrcnn_class_loss" , "mrcnn_bbox_loss" , "mrcnn_mask_loss" ]
2168
+ output_names = []
2172
2169
for name in loss_names :
2173
2170
layer = self .keras_model .get_layer (name )
2174
- if layer .output in self . keras_model . losses :
2171
+ if layer .output . name in output_names :
2175
2172
continue
2176
2173
loss = (
2177
- tf .reduce_mean (layer .output , keepdims = True )
2174
+ tf .reduce_mean (input_tensor = layer .output , keepdims = True )
2178
2175
* self .config .LOSS_WEIGHTS .get (name , 1. ))
2179
2176
self .keras_model .add_loss (loss )
2177
+ output_names .append (layer .output .name )
2180
2178
2181
2179
# Add L2 Regularization
2182
2180
# Skip gamma and beta weights of batch normalization layers.
@@ -2200,7 +2198,7 @@ def compile(self, learning_rate, momentum):
2200
2198
loss = (
2201
2199
tf .reduce_mean (layer .output , keepdims = True )
2202
2200
* self .config .LOSS_WEIGHTS .get (name , 1. ))
2203
- self .keras_model .metrics_tensors . append (loss )
2201
+ self .keras_model .add_metric (loss , name )
2204
2202
2205
2203
def set_trainable (self , layer_regex , keras_model = None , indent = 0 , verbose = 1 ):
2206
2204
"""Sets model layers as trainable if their names match
@@ -2374,8 +2372,8 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
2374
2372
validation_data = val_generator ,
2375
2373
validation_steps = self .config .VALIDATION_STEPS ,
2376
2374
max_queue_size = 100 ,
2377
- workers = 1 ,
2378
- use_multiprocessing = False ,
2375
+ workers = workers ,
2376
+ use_multiprocessing = True ,
2379
2377
)
2380
2378
self .epoch = max (self .epoch , epochs )
2381
2379
0 commit comments