38 lines
1.4 KiB
Python
38 lines
1.4 KiB
Python
import re
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Directory holding the training log files. postprocess_data() opens
# PATH + <file name> using plain string concatenation (no separator is
# inserted), so the names it receives must begin with a path separator.
# NOTE(review): machine-specific absolute Windows path — consider making it
# configurable (argument / environment variable) before sharing the script.
PATH = r'C:\Users\feokt\PycharmProjects\CSI\CSI\logs'
|
|
|
|
|
|
def postprocess_data(log: list, label_offset: int = 38):
    """Plot the combined training loss (LossSim + LossShift) of each log file.

    Each entry of *log* is appended directly to ``PATH`` (plain string
    concatenation, no separator is inserted), so entries are expected to
    start with a path separator.  Only lines containing a ``[DONE]`` tag are
    used; from each such line the ``[LossSim <value>]`` and
    ``[LossShift <value>]`` fields are read and summed.  One curve is drawn
    per file, against the 1-based index of the parsed lines (labelled as
    epochs), and the figure is shown once all files are plotted.

    Example of a line that is parsed::

        [2022-01-31 20:40:23.947855] [DONE] [Time 0.179] [Data 0.583] [LossC 0.000000] [LossSim 4.024234] [LossShift 0.065126]

    Args:
        log: Log file names, relative to ``PATH``.  The list is NOT modified.
        label_offset: Number of leading characters dropped from each entry of
            *log* to form its legend label.  Defaults to 38, the value that
            was previously hard-coded.

    Raises:
        OSError: If a log file cannot be opened.
        ValueError: If a loss field holds text that is not a valid float.
    """
    # Compiled once; raw strings avoid invalid-escape warnings.  The value
    # group accepts anything up to the closing bracket so that scientific
    # notation and "nan"/"inf" are handled by float() as well.
    sim_re = re.compile(r"\[LossSim\s+([^\]\s]+)\]")
    shift_re = re.compile(r"\[LossShift\s+([^\]\s]+)\]")

    labels = []
    for pth in log:
        loss = []
        with open(PATH + pth) as f:
            for line in f:
                if "[DONE]" not in line:
                    continue
                sim = sim_re.search(line)
                shift = shift_re.search(line)
                # Pair the two values per line (instead of two separate lists
                # zipped by index later) so a line missing either field is
                # skipped rather than misaligning the series or crashing.
                if sim is None or shift is None:
                    continue
                loss.append(float(sim.group(1)) + float(shift.group(1)))

        # The x-axis follows the number of parsed lines instead of assuming
        # exactly 100 epochs, which failed for shorter or longer runs.
        plt.plot(list(range(1, len(loss) + 1)), loss)
        # Build labels separately rather than overwriting the caller's list.
        labels.append(pth[label_offset:])

    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.title("Loss over epochs")
    plt.legend(labels)
    plt.grid()
    plt.show()
|
|
|
|
|