From 912c4418139f43fe7e3205cd737617b65eae4eb0 Mon Sep 17 00:00:00 2001 From: Philippe Remy Date: Sat, 19 Jun 2021 18:00:57 +0200 Subject: [PATCH] fixing tests --- keract/keract.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/keract/keract.py b/keract/keract.py index db87d75..c7e1c36 100644 --- a/keract/keract.py +++ b/keract/keract.py @@ -3,11 +3,10 @@ from collections import OrderedDict import numpy as np +import tensorflow as tf import tensorflow.keras.backend as K from tensorflow.keras import Sequential from tensorflow.keras.models import Model -import tensorflow as tf - if tf.__version__ == '2.5.0': tf.compat.v1.experimental.output_all_intermediates(True) @@ -212,16 +211,14 @@ def update_node(n): except AttributeError: pass - if hasattr(module, '_layers'): - for layer in module._layers: - update_node(layer) - if nested: - _get_nodes(layer, nodes, output_format, nested, layer_names, depth + 1) - else: - for layer in module.layers: - update_node(layer) - if nested: - _get_nodes(layer, nodes, output_format, nested, layer_names, depth + 1) + try: + layers = module._layers if hasattr(module, '_layers') else module.layers + except AttributeError: + return + for layer in layers: + update_node(layer) + if nested: + _get_nodes(layer, nodes, output_format, nested, layer_names, depth + 1) # def _get_nodes(module, output_format, nested=False, layer_names=[]):