Skip to content

Commit

Permalink
Merge pull request #22 from Kasliwal17/main
Browse files Browse the repository at this point in the history
Fixed Verification Module Functionality and FedDyn aggregation
  • Loading branch information
Kasliwal17 authored Dec 14, 2023
2 parents 1447edc + 86d9e20 commit 8db9bbc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
8 changes: 5 additions & 3 deletions federa/server/src/algorithms/feddyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class feddyn():

def __init__(self, config):
self.algorithm = "Mime"
self.algorithm = "FedDyn"
self.lr = 1.0
self.momentum = 0.9
self.h = None
Expand All @@ -28,14 +28,16 @@ def aggregate(self, server_model_state_dict, state_dicts):

delta_x = [torch.zeros_like(server_model_state_dict[key]) for key in server_model_state_dict.keys()]
for d_x, key in zip(delta_x, keys):
d_x.data = sum_y[key] - server_model_state_dict[key]
d_x.data = sum_y[key]/len(state_dicts) - server_model_state_dict[key].to(sum_y[key].device)

#Update h
for h, d_x in zip(self.h, delta_x):
h.data = h.data.to(d_x.data.device)
h.data -= (self.alpha/len(state_dicts)) * d_x.data


#Update x
for key, h in zip(keys, self.h):
server_model_state_dict[key] = (sum_y[key]/len(state_dicts)) - (h.data/self.alpha)

return server_model_state_dict
return server_model_state_dict
33 changes: 24 additions & 9 deletions federa/server/src/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,48 @@
from random import randint
from collections import OrderedDict
from concurrent import futures
import copy

def verify(clients, trained_model_state_dicts, save_dir_path, threshold = 0):
##modify the verify function to consider the updated control variates also
def verify(clients, trained_model_state_dicts, save_dir_path, threshold = 0, updated_control_variates = None, server_model_state_dict = None):
verification_dict = OrderedDict()
config_dict = {"message": "verify"}

##if server_model_state_dict is not None then to each trained_model_state_dict, we need to add the server_model_state_dict
if server_model_state_dict is not None:
for i in range(len(trained_model_state_dicts)):
for key in server_model_state_dict.keys():
trained_model_state_dicts[i][key] += server_model_state_dict[key]

for i, client in zip( range(len(clients)), clients):
verification_dict[client.client_id] = {"client_wrapper_object": client, "model": trained_model_state_dicts[i]}
verification_dict[client.client_id] = {"client_wrapper_object": client, "model": trained_model_state_dicts[i], "control_variates": updated_control_variates[i]}
client_ids = list(verification_dict.keys())
client_ids_shuffled = random_derangement(client_ids)
for i, client_id in zip( range(len(verification_dict)), verification_dict.keys() ):
verification_dict[client_id]["assigned_client_id"] = client_ids_shuffled[i]

with futures.ThreadPoolExecutor(max_workers = 20) as executor:
result_futures = []

for client_id, client_info in verification_dict.items():
assigned_client_id = client_info["assigned_client_id"]
assigned_client = verification_dict[assigned_client_id]["client_wrapper_object"]
model_to_verify = client_info["model"]
result_futures.append(executor.submit(assigned_client.evaluate, model_to_verify, config_dict))
config_dict['client_id'] = client_id
config_dict_s = copy.deepcopy(config_dict)
result_futures.append(executor.submit(assigned_client.evaluate, model_to_verify, config_dict_s))


verification_results = [result_future.result() for result_future in futures.as_completed(result_futures)]

for client_id, index in zip(verification_dict.keys(), range(len(verification_results))):
verification_dict[client_id]["score"] = verification_results[index]["eval_accuracy"]
for index in range(len(verification_results)):
verification_dict[verification_results[index]["client_id"]]["score"] = verification_results[index]["eval_accuracy"]


selected_client_models, ignored_client_models = [], []
selected_client_models, ignored_client_models, selected_control_variates = [], [], []
for client_id, client_info in verification_dict.items():
if client_info["score"] >= threshold:
selected_client_models.append(client_info["model"])
selected_control_variates.append(client_info["control_variates"])
client_info["selected"] = True

else:
Expand Down Expand Up @@ -71,12 +82,16 @@ def verify(clients, trained_model_state_dicts, save_dir_path, threshold = 0):
with open(f"{save_dir_path}/verification_ignored_stats.txt", "a", encoding='UTF-8') as file:
file.write( f"{ignored_info_dict}\n" )

if server_model_state_dict is not None:
for i in range(len(selected_client_models)):
for key in server_model_state_dict.keys():
selected_client_models[i][key] -= server_model_state_dict[key]

return selected_client_models
return selected_client_models, selected_control_variates


def random_derangement(list_to_shuffle):
for index1 in range(1, len(list_to_shuffle)):
index2 = randint(0, index1 - 1) # nosec
list_to_shuffle[index1], list_to_shuffle[index2] = list_to_shuffle[index2], list_to_shuffle[index1]
return list_to_shuffle
return list_to_shuffle

0 comments on commit 8db9bbc

Please sign in to comment.