-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraphing_class.py
34 lines (24 loc) · 1.01 KB
/
graphing_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import matplotlib.pyplot as plt
class CreateGraph:
    """Track per-epoch arithmetic-mean losses of a generator/discriminator pair and plot them.

    The class never increments ``num_for_G`` / ``num_for_D`` itself; presumably
    the training loop adds each batch's loss to them during an epoch (verify
    against the caller). At the end of the epoch the caller invokes
    :meth:`count`, which divides the running sums by ``batch_count``, records
    the averages, resets the sums, and redraws the loss curves.

    Attributes:
        batch_count: Divisor used to turn the running sums into averages.
            If it is 0, :meth:`count` raises ZeroDivisionError.
        num_for_G: Running generator-loss sum for the current epoch.
        num_for_D: Running discriminator-loss sum for the current epoch.
        array_epoch: Epoch values recorded so far (x axis of the plot).
        array_loss_G: Average generator loss per recorded epoch.
        array_loss_D: Average discriminator loss per recorded epoch.
        name: Used as the plot title, the pyplot figure label, and the
            stem of the saved ``<name>.png`` file.
    """

    def __init__(self, batch_count, name):
        self.batch_count = batch_count
        self.num_for_G = 0
        self.num_for_D = 0
        self.array_epoch, self.array_loss_G, self.array_loss_D = [], [], []
        self.name = name

    def count(self, epoch, show=True):
        """Close out an epoch: record average losses, reset the sums, redraw the plot.

        Args:
            epoch: x-axis value recorded for this point.
            show: When True (the default, matching the original behavior),
                call ``plt.show()`` after saving. Pass False in headless runs
                where a blocking GUI window is unwanted.

        Side effects:
            Writes ``self.name + ".png"`` (a relative name resolves against
            the current working directory).

        Raises:
            ZeroDivisionError: If ``batch_count`` is 0.
        """
        self.array_epoch.append(epoch)
        self.array_loss_G.append(self.num_for_G / self.batch_count)
        self.array_loss_D.append(self.num_for_D / self.batch_count)
        # Start the next epoch's accumulation from zero.
        self.num_for_G = 0
        self.num_for_D = 0

        # Draw on a figure owned by this tracker (looked up by its label) and
        # clear it first. Drawing on pyplot's implicit current figure without
        # clearing stacked a fresh copy of the whole history on every call
        # whenever show() did not close the figure (e.g. the Agg backend), and
        # could scribble on an unrelated figure the caller had open.
        fig = plt.figure(self.name)
        fig.clf()
        ax = fig.add_subplot(111)
        ax.plot(self.array_epoch, self.array_loss_G, label="Generator loss")
        ax.plot(self.array_epoch, self.array_loss_D, label="Discriminator loss")
        ax.set_title(self.name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        fig.savefig(self.name + ".png")
        if show:
            plt.show()