-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.py
101 lines (81 loc) · 2.69 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Federated Learning server using the Flower framework.
"""
from pathlib import Path
import tensorflow as tf
import flwr as fl
import json
from anomaly_flow.data.netflow import NetFlowV2
from anomaly_flow.train.trainer_flow_nids import GANomaly
# Hyper-parameters shared by the dataset pipeline and the GANomaly trainer.
# The empty default is replaced by the contents of hps.json as soon as the
# file has been parsed; the path is relative to the current working directory.
hps = {}
with open("hps.json", mode="r", encoding="utf-8") as file:
    hps = json.load(file)

# Number of rounds handed to the Flower server configuration in main().
NUM_ROUNDS = 10
def average_metrics(metrics):
    """
    Compute the unweighted mean of the evaluation metrics reported by clients.

    Args:
        metrics: Iterable of ``(first, client_metrics)`` pairs, where
            ``client_metrics`` is a mapping holding the keys ``auc_rc``,
            ``auc_roc``, ``f1_value``, ``f2_value`` and ``acc_value``.
            The first element of every pair is ignored, so each client
            contributes equally to the mean.
            NOTE(review): Flower presumably passes the client's number of
            examples in that position - confirm whether a weighted mean
            is wanted instead.

    Returns:
        A dict with the keys ``auc_rc``, ``auc_roc``, ``f1_score``,
        ``f2_score`` and ``accuracy`` holding the per-key mean, or an empty
        dict when ``metrics`` contains no entries.

    Raises:
        KeyError: If a client's metrics lack one of the expected keys.
    """
    # Key read from each client's metrics -> key used in the aggregated result.
    key_map = {
        "auc_rc": "auc_rc",
        "auc_roc": "auc_roc",
        "f1_value": "f1_score",
        "f2_value": "f2_score",
        "acc_value": "accuracy",
    }

    # Materialize once so that one-shot iterables can be averaged as well.
    client_metrics = [metric for _, metric in metrics]
    if not client_metrics:
        # Nothing to aggregate; avoid dividing by zero.
        return {}

    count = len(client_metrics)
    return {
        out_key: sum(metric[in_key] for metric in client_metrics) / count
        for in_key, out_key in key_map.items()
    }
def create_model_trainer():
    """
    Build the server-side GANomaly trainer.

    Loads the "NF-CSE-CIC-IDS2018-v2-DDoS-downsample" NetFlowV2 dataset with
    ``train_size=5000``, configures it from the module-level ``hps``
    hyper-parameters and wraps it in a ``GANomaly`` instance.

    Returns:
        The ``GANomaly`` trainer built from the configured dataset.
    """
    dataset = NetFlowV2(
        "NF-CSE-CIC-IDS2018-v2-DDoS-downsample",
        train_size=5000,
    )
    # NOTE(review): the literals 52, 1, True, True are forwarded positionally;
    # their meaning is defined by NetFlowV2.configure - confirm against it.
    dataset.configure(
        hps["batch_size"],
        52,
        1,
        hps["shuffle_buffer_size"],
        True,
        True,
    )

    # NOTE(review): the summary writer targets "logs" while the path given to
    # the trainer is "log" - confirm that the two different names are intended.
    summary_writer = tf.summary.create_file_writer("logs")
    log_path = Path("log")
    return GANomaly(dataset, hps, summary_writer, log_path)
def pretrain():
    """
    Return the weights used to seed the federated model.

    Builds a trainer with ``create_model_trainer()`` and returns whatever
    ``get_weights()`` yields for its model. No fitting step is run in this
    function itself.
    """
    return create_model_trainer().get_model().get_weights()
def main():
    """
    Start the Flower server for the federated learning run.

    Prints the installed Flower version, builds a FedAvg strategy seeded with
    the weights from ``pretrain()`` and aggregating evaluation metrics with
    ``average_metrics``, then serves ``NUM_ROUNDS`` rounds on 0.0.0.0:8081.
    """
    print(">>> Flower version:", fl.__version__)

    # Initial global parameters come from the locally built model.
    initial_parameters = fl.common.ndarrays_to_parameters(pretrain())

    # An earlier revision also carried a disabled fl.server.strategy.QFedAvg
    # variant constructed with these same keyword arguments.
    strategy_1 = fl.server.strategy.FedAvg(
        evaluate_metrics_aggregation_fn=average_metrics,
        min_fit_clients=3,
        min_available_clients=3,
        initial_parameters=initial_parameters,
    )

    server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)
    fl.server.start_server(
        server_address="0.0.0.0:8081",
        strategy=strategy_1,
        config=server_config,
    )


# Run the server only when the module is executed as a script.
if __name__ == "__main__":
    main()