This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

How to visualize the network structure #320

Closed
JiYuanFeng opened this issue Mar 27, 2018 · 19 comments

Comments

@JiYuanFeng

Hi, is there an easy way to visualize the network, like Netscope for Caffe?
I can use net_drawer in Caffe2, but I find the resulting graph very hard to read.

@taoari

taoari commented Mar 29, 2018

The following snippet dumps the network to a .dot file for visualization:

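# Imports assume the later `detectron` package layout; early-2018 checkouts
# put lib/ on sys.path instead (e.g. `from core.config import cfg`).
from caffe2.python import net_drawer
from detectron.core.config import assert_and_infer_cfg, cfg, merge_cfg_from_file
import detectron.core.test_engine as infer_engine
import detectron.utils.c2 as c2_utils

c2_utils.import_detectron_ops()  # load Detectron's custom Caffe2 ops, as tools/infer_simple.py does

DETECTRON_ROOT = '/path/to/Detectron'  # placeholder: set to your Detectron checkout

def get_model(cfg_file, weights_file):
    merge_cfg_from_file(cfg_file)
    cfg.TRAIN.WEIGHTS = ''  # NOTE: do not download pretrained model weights
    cfg.TEST.WEIGHTS = weights_file
    cfg.NUM_GPUS = 1
    assert_and_infer_cfg()  # caches URL weights locally and rewrites cfg.TEST.WEIGHTS
    # Newer Detectron versions take the weights path as the first argument;
    # the early-2018 version took no weights argument and read cfg.TEST.WEIGHTS.
    model = infer_engine.initialize_model_from_cfg(cfg.TEST.WEIGHTS)
    return model

cfg_file = '{}/configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml'.format(DETECTRON_ROOT)
weights_file = 'https://s3-us-west-2.amazonaws.com/detectron/' + \
    '35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/' + \
    'output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl'
model = get_model(cfg_file, weights_file)

# Write the inference net as a .dot file named after the net
g = net_drawer.GetPydotGraph(model, rankdir="TB")
g.write_dot(model.Proto().name + '.dot')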

Then open the resulting .dot file in a dot viewer (e.g. xdot).
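For readers unfamiliar with what ends up in the .dot file: below is a minimal, stdlib-only sketch of the same operator/blob graph idea, assuming each operator is given as a hypothetical `(type, inputs, outputs)` tuple. Operators become boxes, blobs become ellipses, and edges run from input blobs to the operator and from the operator to its outputs. This is not net_drawer's code, just an illustration of the structure.

```python
def ops_to_dot(ops, name="net", rankdir="TB"):
    """Render (op_type, inputs, outputs) tuples as DOT text.

    Operators become boxes and blobs become ellipses; an edge runs from each
    input blob to its operator and from the operator to each output blob.
    """
    lines = ['digraph "%s" {' % name, "  rankdir=%s;" % rankdir]
    blobs = set()
    for i, (op_type, inputs, outputs) in enumerate(ops):
        op_id = "op_%d" % i  # unique id, since many ops share a type
        lines.append('  %s [label="%s", shape=box];' % (op_id, op_type))
        for blob in inputs:
            blobs.add(blob)
            lines.append('  "%s" -> %s;' % (blob, op_id))
        for blob in outputs:
            blobs.add(blob)
            lines.append('  %s -> "%s";' % (op_id, blob))
    for blob in sorted(blobs):
        lines.append('  "%s" [shape=ellipse];' % blob)
    lines.append("}")
    return "\n".join(lines)


toy_net = [
    ("Conv", ["data", "conv1_w"], ["conv1"]),
    ("Relu", ["conv1"], ["conv1"]),  # in-place op: output overwrites input
]
print(ops_to_dot(toy_net, name="toy"))
```

With a real model, the tuples could be built from the NetDef, e.g. `[(op.type, list(op.input), list(op.output)) for op in model.net.Proto().op]`. net_drawer adds its own styling and handling of rewritten blobs, which this sketch skips.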

@JiYuanFeng
Author

Thank you, taoari!
I tested the code in my Caffe2 environment and ran into another problem:

  File "/home/yuanfeng/Project/Detectron/lib/modeling/ResNet.py", line 95, in add_ResNet_convX_body
    p = model.AffineChannel(p, 'res_conv1_bn', dim=64, inplace=True)
  File "/home/yuanfeng/Project/Detectron/lib/modeling/detector.py", line 102, in AffineChannel
    return self.net.AffineChannel([blob_in, scale, bias], blob_in)
  File "/home/yuanfeng/Project/caffe2-master/build/caffe2/python/core.py", line 2046, in __getattr__
    ",".join(workspace.C.nearby_opnames(op_type)) + ']'
AttributeError: Method AffineChannel is not a registered operator. Did you mean: []

I think my Caffe2 build doesn't register the AffineChannel op, which I believe belongs to Detectron.
Should I add this op to Caffe2?
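One way to check which operators a Caffe2 build is missing before building a Detectron model is to compare the needed op types against the registered ones. A minimal sketch; the helper is hypothetical, and the commented-out usage assumes `workspace.RegisteredOperators()` from `caffe2.python.workspace` as the source of registered names.

```python
def missing_operators(required_types, registered_types):
    """Return the required operator types absent from registered_types, sorted."""
    registered = set(registered_types)
    return sorted(set(required_types) - registered)


# Usage with a real Caffe2 build (requires caffe2):
# from caffe2.python import workspace
# print(missing_operators(["AffineChannel"], workspace.RegisteredOperators()))
print(missing_operators(["Conv", "AffineChannel", "Relu"], ["Conv", "Relu", "FC"]))
```

An empty result means the build provides every listed op; a non-empty one suggests updating or rebuilding Caffe2 (or loading the library that defines the op) before constructing the model.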

@taoari

taoari commented Mar 30, 2018

@JiYuanFeng
Author

@taoari Thank you for your help!
The problem is solved.

@gadcam
Contributor

gadcam commented Mar 30, 2018

@taoari @JiYuanFeng Maybe you could open a pull request with your script so everyone can benefit from it? :)

@JiYuanFeng
Author

@gadcam I will do it~

@liujing1995

nice

@PacteraOliver

@taoari

Hi, thank you for your code.

I used part of the code and turned it into a tool.
I just opened a pull request at #508.

@drcege

drcege commented Jun 28, 2018

@PacteraOliver @taoari
Hi, were you actually able to visualize the graph?

I tried GetPydotGraph and GetPydotGraphMinimal, but they hang for a very long time.

@PacteraOliver

Yes. The program needs to load both the model (.pkl file) and the config file.
The output is a .dot file. Opened with xdot, the sample looks like this:

screenshot from 2018-06-28 11-19-20

A close-up of part of the graph looks like this:

screenshot from 2018-06-28 11-19-44

It does not take long.

@WuZhuoran

@drcege

Hi, if you like, you could try the code in pull request #508.

Could you provide more details about the model you want to visualize? You used a GPU to load the model, right?

@drcege

drcege commented Jun 29, 2018

@WuZhuoran I tried your pull request, and it works well. One question: the visualization is of the model definition, so why must we provide a weights_file? Can we get rid of it?

Your code draws the inference net. However, I want to visualize the training net of e2e_faster_rcnn_R-50-FPN_1x.yaml (which I changed to use one GPU). I saved the dot file right after the default proto dump:

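# Existing lines from the training script
nu.broadcast_parameters(model)
workspace.CreateNet(model.net)
logger.info('Outputs saved to: {:s}'.format(os.path.abspath(output_dir)))
dump_proto_files(model, output_dir)

# Added: dump the training net as a .dot file
from caffe2.python import net_drawer
g = net_drawer.GetPydotGraph(model.net, rankdir='TB')
g.write('train_net.dot')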

I also tried GetPydotGraphMinimal, but even with the minimal dot file it takes more than 2 hours for xdot to open the graph window! Do you have any experience visualizing training nets?

@drcege

drcege commented Jun 29, 2018

I think this may be because there are too many gradient and learning-rate related nodes. What I actually want is to include the external input nodes (e.g., labels); the gradient nodes are unnecessary.
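A minimal sketch of pruning a training net down to its forward operators before drawing it, assuming caffe2's usual conventions that gradient operators have types ending in `Gradient` and gradient blobs end in `_grad`. The `Op` namedtuple stands in for `caffe2_pb2.OperatorDef` (same `type`/`input`/`output` field names), and `TRAINING_ONLY_TYPES` is a hypothetical list of optimizer/LR op types to adjust for your own net.

```python
from collections import namedtuple

# Stand-in for caffe2_pb2.OperatorDef, which uses the same field names.
Op = namedtuple("Op", ["type", "input", "output"])

# Hypothetical default: optimizer / learning-rate op types to hide.
TRAINING_ONLY_TYPES = frozenset(["LearningRate", "MomentumSGDUpdate", "WeightedSum", "Iter"])


def forward_ops(ops, skip_types=TRAINING_ONLY_TYPES):
    """Keep forward operators; drop gradient and parameter-update ops."""
    kept = []
    for op in ops:
        if op.type.endswith("Gradient") or op.type in skip_types:
            continue
        if any(name.endswith("_grad") for name in op.output):
            continue  # e.g. the op that seeds the loss gradient
        kept.append(op)
    return kept


# Made-up example ops, loosely shaped like a Detectron training net
example = [
    Op("Conv", ["gpu_0/data", "gpu_0/conv1_w"], ["gpu_0/conv1"]),
    Op("SoftmaxWithLoss", ["gpu_0/cls_score", "gpu_0/labels_int32"],
       ["gpu_0/cls_prob", "gpu_0/loss_cls"]),
    Op("ConstantFill", ["gpu_0/loss_cls"], ["gpu_0/loss_cls_grad"]),
    Op("ConvGradient", ["gpu_0/data", "gpu_0/conv1_w", "gpu_0/conv1_grad"],
       ["gpu_0/conv1_w_grad"]),
]
print([op.type for op in forward_ops(example)])
```

Because `net_drawer.GetPydotGraph` also accepts a plain list of operators instead of a whole net, something like `net_drawer.GetPydotGraph(forward_ops(model.net.Proto().op), name='train_fwd', rankdir='TB')` should give a much smaller graph that still shows external inputs such as labels. This is untested against a real Detectron training net.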

@tonbing

tonbing commented Sep 22, 2018

I tried the same process to visualize the network, but I get the following error. Does anyone know how to resolve it?

Traceback (most recent call last):
  File "visual.py", line 18, in <module>
    model = get_model(cfg_file, weights_file)
  File "visual.py", line 12, in get_model
    model = infer_engine.initialize_model_from_cfg(1)
  File "/detectron/detectron/core/test_engine.py", line 330, in initialize_model_from_cfg
    model, weights_file, gpu_id=gpu_id,
  File "/detectron/detectron/utils/net.py", line 61, in initialize_gpu_from_weights_file
    with open(weights_file, 'r') as f:
TypeError: coercing to Unicode: need string or buffer, int found

@gadcam
Contributor

gadcam commented Sep 23, 2018

@tonbing It looks like weights_file is an int at this point in the code. I would first double-check that this variable holds what you expect.
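The traceback points to a signature mismatch: in the Detectron version used there, `initialize_model_from_cfg` takes the weights path as its first parameter (named `weights_file`, per the traceback), so `initialize_model_from_cfg(1)` binds the integer 1 to `weights_file`. The early-2018 version that the first snippet in this thread targeted appears to have taken no weights argument and read `cfg.TEST.WEIGHTS` instead. A hypothetical Python 3 wrapper (`inspect.signature` does not exist in Python 2) that calls either form:

```python
import inspect


def init_model(initialize_model_from_cfg, weights_file, gpu_id=0):
    """Call initialize_model_from_cfg under either assumed Detectron signature."""
    params = inspect.signature(initialize_model_from_cfg).parameters
    if "weights_file" in params:
        # Newer layout (assumed): initialize_model_from_cfg(weights_file, gpu_id=0)
        return initialize_model_from_cfg(weights_file, gpu_id=gpu_id)
    # Older layout (assumed): initialize_model_from_cfg(gpu_id=0); the caller
    # must have set cfg.TEST.WEIGHTS, since weights_file is not passed here.
    return initialize_model_from_cfg(gpu_id=gpu_id)


# Stub functions standing in for the two Detectron versions
def new_style(weights_file, gpu_id=0):
    return ("new", weights_file, gpu_id)


def old_style(gpu_id=0):
    return ("old", gpu_id)


print(init_model(new_style, "model_final.pkl"))
print(init_model(old_style, "model_final.pkl", gpu_id=1))
```

In practice, simply passing the weights path (`cfg.TEST.WEIGHTS`) rather than a GPU index to the newer function avoids the error; the wrapper only matters if a script must run against both versions.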

@tonbing

tonbing commented Sep 23, 2018

Hey @gadcam, thanks for the reply; I figured out the problem. The code above had an issue, which is why it gave me that error. The code from #508, however, worked fine after some tweaks for my model. I believe the code above should have worked too, but it did not.

@jiec-msft

I use Anaconda with Python 2.7.
I installed:

  • graphviz-2.40.1-h21bd128_1.tar.bz2 from https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/graphviz-2.40.1-h21bd128_1.tar.bz2
  • pydot-1.2.3-py27h55c791f_0.tar.bz2 from https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/pydot-1.2.3-py27h55c791f_0.tar.bz2

Here is my code; it works fine.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import defaultdict
import argparse
import cv2  # NOQA (Must import before importing caffe2 due to bug in cv2)
import glob
import logging
import os
import sys
reload(sys)  # Python 2 only
sys.setdefaultencoding('utf8')  # set the default encoding to 'utf-8'

import time


from caffe2.python import workspace

from detectron.core.config import assert_and_infer_cfg
from detectron.core.config import cfg
from detectron.core.config import merge_cfg_from_file
from detectron.utils.io import cache_url
from detectron.utils.logging import setup_logging
from detectron.utils.timer import Timer
import detectron.core.test_engine as infer_engine
import detectron.datasets.dummy_datasets as dummy_datasets
import detectron.utils.c2 as c2_utils
import detectron.utils.vis as vis_utils

c2_utils.import_detectron_ops()

# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
cv2.ocl.setUseOpenCL(False)

def get_model(cfg_file, weights_file):
    merge_cfg_from_file(cfg_file)
    cfg.TRAIN.WEIGHTS = '' # NOTE: do not download pretrained model weights
    cfg.TEST.WEIGHTS = weights_file
    cfg.NUM_GPUS = 1
    assert_and_infer_cfg()
    model = infer_engine.initialize_model_from_cfg(cfg.TEST.WEIGHTS)
    return model

# cfg_file = '{}/configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml'.format(DETECTRON_ROOT)
cfg_file = '/home/ceej/DetectronRepo/detectron/configs/12_2017_baselines/defects_X-101.yaml'
# weights_file = 'https://s3-us-west-2.amazonaws.com/detectron/' + \
#     '35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/' + \
#     'output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl'
weights_file = '/home/ceej/aluminium_surface_defects_detection/m.pkl'
model = get_model(cfg_file, weights_file)

from caffe2.python import net_drawer
g = net_drawer.GetPydotGraph(model, rankdir="TB")
print(model.Proto().name + '.dot')
g.write_dot(model.Proto().name + '.dot')

@jiec-msft

By the way, I installed xdot on my Ubuntu 16.04 LTS machine with:

sudo apt-get update
sudo apt-get install xdot
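If the interactive xdot window is too slow on a large graph, Graphviz's `dot` command can instead render the file to a static image (`dot -Tsvg in.dot -o out.svg`), although laying out a very large net may still take a while. A small Python 3 sketch that builds that command and runs it only when `dot` is on PATH; the helper names are hypothetical.

```python
import os
import shutil
import subprocess


def dot_command(dot_file, fmt="svg"):
    """Build the Graphviz command that renders dot_file next to itself."""
    out_file = os.path.splitext(dot_file)[0] + "." + fmt
    return ["dot", "-T" + fmt, dot_file, "-o", out_file]


def render(dot_file, fmt="svg"):
    """Render dot_file with Graphviz; return the output path, or None if dot is missing."""
    if shutil.which("dot") is None:
        return None
    cmd = dot_command(dot_file, fmt)
    subprocess.check_call(cmd)  # raises CalledProcessError if dot fails
    return cmd[-1]


print(dot_command("train_net.dot"))
```

SVG output can be opened in a web browser, which may be more responsive than xdot for big graphs; PNG (`fmt="png"`) is easier to share but loses searchable text.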

@jiec-msft

jiec-msft commented Mar 16, 2019

The result looks like this:

https://note.youdao.com/yws/api/personal/file/WEB86b1c61048404045fa8a514b209dfe06?method=getImage&version=22336&cstk=vij4EFEr

(The URL above does not open in Firefox.)


9 participants