Skip to content

Commit

Permalink
Merge pull request ispras#26 from ispras/QAttack_fix
Browse files Browse the repository at this point in the history
Fix QAttack
  • Loading branch information
LukyanovKirillML authored Oct 17, 2024
2 parents c9a1cc3 + 95e9dba commit c7e42d2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
1 change: 0 additions & 1 deletion .github/workflows/linter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name: Pylint

on:
push:
pull_request:

jobs:
lint:
Expand Down
12 changes: 7 additions & 5 deletions src/attacks/QAttack/qattack.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def fitness(self, model, gen_dataset):

# Get labels from black-box
labels = model.gnn.get_answer(dataset.x, dataset.edge_index)
labeled_nodes = {n: labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency
labeled_nodes = dict(enumerate(labels.tolist()))
# labeled_nodes = {n: labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency

# Calculate modularity
Q = self.modularity(adj_list, labeled_nodes)
Expand All @@ -76,7 +77,7 @@ def fitness_individual(self, model, gen_dataset, gene):

# Get labels from black-box
labels = model.gnn.get_answer(dataset.x, dataset.edge_index)
labeled_nodes = {n: labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency
labeled_nodes = dict(enumerate(labels.tolist()))

# Calculate modularity
Q = self.modularity(adj_list, labeled_nodes)
Expand Down Expand Up @@ -191,6 +192,7 @@ def mutation(self, gen_dataset):
dataset.edge_index = from_adj_list(adj_list)
non_isolated_nodes = set(gen_dataset.dataset.edge_index[0].tolist()).union(
set(gen_dataset.dataset.edge_index[1].tolist()))
non_drain_nodes = set(gen_dataset.dataset.edge_index[0].tolist())
if mut_type == 0:
# add mutation
connected_nodes = set(self.adj_list[n])
Expand All @@ -202,8 +204,8 @@ def mutation(self, gen_dataset):
self.population[i][n]['del'] = np.random.choice(list(adj_list[n]), 1)
else:
selected_nodes = set(self.population[i].keys())
non_selected_nodes = non_isolated_nodes.difference(selected_nodes)
new_node = np.random.choice(list(non_selected_nodes), size=1, replace=False)[0]
non_drain_nodes = non_drain_nodes.difference(selected_nodes)
new_node = np.random.choice(list(non_drain_nodes), size=1, replace=False)[0]
self.population[i].pop(n)
addition_nodes = non_isolated_nodes.difference(set(self.adj_list[new_node]))
self.population[i][new_node] = {}
Expand Down Expand Up @@ -236,4 +238,4 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
set(adj_list[n]).union(set([int(rewiring[n]['add'])])).difference(set([int(rewiring[n]['del'])])))

gen_dataset.dataset.data.edge_index = from_adj_list(adj_list)
return gen_dataset
return gen_dataset

0 comments on commit c7e42d2

Please sign in to comment.