Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copy edata in to_simple #104

Open
wants to merge 3 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions dgl_ptm/dgl_ptm/network/global_attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ def global_attachment(agent_graph, device, ratio: float):
'''
global_attachment - randomly connects different agents globally based on a ratio

Args:
Args:
agent_graph: DGLGraph with agent nodes and edges connecting agents
ratio: ratio of number of new edges to add to total existing edges in graph

Output:
Modified agent_graph with new edges introduced
Modified agent_graph with new edges introduced
'''

# Add edges based on ratio
agent_graph = AddEdge(ratio=ratio)(agent_graph)

Expand All @@ -23,7 +23,8 @@ def global_attachment(agent_graph, device, ratio: float):

# Remove duplicate edges
# dgl.to_simple works only on device=cpu hence we move the graph to cpu:
agent_graph = dgl.to_simple(agent_graph.to('cpu'), return_counts='cnt')
# to_simple by default copies ndata but not edata, hence we need to copy edata explicitly.
agent_graph = dgl.to_simple(agent_graph.to('cpu'), return_counts='cnt', copy_edata=True)
# move the graph back to user choice of device.
# This is necessary for running on cuda or other hardware.
agent_graph = agent_graph.to(device)
30 changes: 30 additions & 0 deletions dgl_ptm/tests/test_network.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import pytest
import dgl_ptm
import os
import dgl
from dgl import AddEdge, AddReverse

from dgl_ptm.network.global_attachment import global_attachment
from dgl_ptm.network.link_deletion import link_deletion
from dgl_ptm.network.local_attachment import local_attachment
from dgl_ptm.network.network_creation import network_creation

from dgl_ptm.agentInteraction.trade_money import trade_money
from dgl_ptm.network.local_attachment import local_attachment
from dgl_ptm.network.link_deletion import link_deletion


os.environ["DGLBACKEND"] = "pytorch"

Expand All @@ -29,6 +35,30 @@ def test_global_attachment(self, model):
assert updated_number_of_edges > ratio * current_number_of_edges
assert updated_number_of_edges < (1 + ratio) * current_number_of_edges

def test_global_attachment_to_simple(self, model):
agent_graph = model.model_graph
params = model.steering_parameters

# step operations on agent_graph
trade_money(agent_graph, 'cpu', method = params['wealth_method'])
local_attachment(agent_graph, n_FoF_links = 1, edge_prop='weight', p_attach=1. )
link_deletion(agent_graph, method=params['del_method'], threshold=params['del_threshold'])

# global attachment operations on agent_graph
agent_graph = AddEdge(ratio=params['noise_ratio'])(agent_graph)
agent_graph = AddReverse()(agent_graph)

# edata are not copied by default
simple_agent_graph = dgl.to_simple(agent_graph, return_counts='cnt')
assert 'weight' not in simple_agent_graph.edata
assert 'wealth_diff' not in simple_agent_graph.edata
assert 'theta' in simple_agent_graph.ndata # check ndata

# copy edata explicitly
simple_agent_graph = dgl.to_simple(agent_graph, return_counts='cnt', copy_edata=True)
assert 'weight' in simple_agent_graph.edata
assert 'wealth_diff' in simple_agent_graph.edata


class TestLinkDeletion:
def test_link_deletion(self, model):
Expand Down
Loading