Commit 723ca30d by 李宇轩

成了

parent 85de8963
......@@ -95,7 +95,7 @@ class FedAvg(Server):
for item in topKdict:
topK.append(int(item))
self.selected_users = self.select_users(glob_iter, self.num_users)
self.selected_users = self.select_users_(glob_iter, self.num_users,topK)
self.aggregate_parameters()
# for param in self.model.parameters():
......
......@@ -59,6 +59,18 @@ class UserAVG(User):
# print(loss)
sys.stdout.flush()
return LOSS
def test_others(self):
self.mount.eval()
test_acc = 0
for x, y in self.testloaderfull:
x, y = x.to(self.device), y.to(self.device)
output = self.model(x)
test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item()
#@loss += self.loss(output, y)
#print(self.id + ", Test Accuracy:", test_acc / y.shape[0] )
#print(self.id + ", Test Loss:", loss)
return test_acc, y.shape[0]
#训练完之后跑
def anothertest(self, gradients,glob_iter):
......@@ -70,9 +82,9 @@ class UserAVG(User):
scores = []
for gradient,id in gradients:
# self.set_parameters(gradient)
model = self.model
self.model = copy.deepcopy(gradient)
acc, n = self.test()
# model = self.model
self.mount = copy.deepcopy(gradient)
acc, n = self.test_others()
score = acc/n
r=requests.get(url="http://10.134.153.83:8899/score",params={
"cid":self.id,
......@@ -83,5 +95,5 @@ class UserAVG(User):
# print(r)
sys.stdout.flush()
# scores.append(acc)
self.model = copy.deepcopy(model)
# self.model = copy.deepcopy(model)
# return scores
......@@ -34,6 +34,8 @@ class User:
self.local_model = copy.deepcopy(list(self.model.parameters()))
self.persionalized_model = copy.deepcopy(list(self.model.parameters()))
self.persionalized_model_bar = copy.deepcopy(list(self.model.parameters()))
self.mount = 0
def set_parameters(self, model):
for old_param, new_param, local_param in zip(self.model.parameters(), model.parameters(), self.local_model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment