Skip to content

Commit 04ebd5d

Browse files
authored — addresses matterport/MaskRCNN#1775
1 parent 3e5358d · commit 04ebd5d

File tree

1 file changed: +7 −9 lines changed

mrcnn/model.py

+7-9
@@ -23,10 +23,6 @@
 import keras.engine as KE
 import keras.models as KM
 
-config = tf.compat.v1.ConfigProto()
-config.gpu_options.allow_growth = True
-tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
-
 from mrcnn import utils
 
 # Requires TensorFlow 1.3+ and Keras 2.0.8+.
@@ -2169,14 +2165,16 @@ def compile(self, learning_rate, momentum):
         loss_names = [
             "rpn_class_loss", "rpn_bbox_loss",
             "mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"]
+        output_names = []
         for name in loss_names:
             layer = self.keras_model.get_layer(name)
-            if layer.output in self.keras_model.losses:
+            if layer.output.name in output_names:
                 continue
             loss = (
-                tf.reduce_mean(layer.output, keepdims=True)
+                tf.reduce_mean(input_tensor=layer.output, keepdims=True)
                 * self.config.LOSS_WEIGHTS.get(name, 1.))
             self.keras_model.add_loss(loss)
+            output_names.append(layer.output.name)
 
         # Add L2 Regularization
         # Skip gamma and beta weights of batch normalization layers.
@@ -2200,7 +2198,7 @@ def compile(self, learning_rate, momentum):
             loss = (
                 tf.reduce_mean(layer.output, keepdims=True)
                 * self.config.LOSS_WEIGHTS.get(name, 1.))
-            self.keras_model.metrics_tensors.append(loss)
+            self.keras_model.add_metric(loss, name)
 
     def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
         """Sets model layers as trainable if their names match
@@ -2374,8 +2372,8 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
             validation_data=val_generator,
             validation_steps=self.config.VALIDATION_STEPS,
             max_queue_size=100,
-            workers=1,
-            use_multiprocessing=False,
+            workers=workers,
+            use_multiprocessing=True,
         )
         self.epoch = max(self.epoch, epochs)

Comments (0)