CSI/datasets/postprocess_data.py
2022-04-29 19:26:47 +02:00

38 lines
1.4 KiB
Python

import re
import matplotlib.pyplot as plt
# Directory prefix for the training logs. postprocess_data() concatenates each
# entry of its ``log`` argument directly onto this string (no separator is
# inserted), so entries are expected to start with a path separator.
# NOTE(review): machine-specific absolute path — confirm before running elsewhere.
PATH = r'C:\Users\feokt\PycharmProjects\CSI\CSI\logs'
# Lines carrying this tag are the ones that contribute a point to the plot.
_DONE_TAG = '[DONE]'
# Capture the raw value of a loss field, e.g. "[LossSim 4.024234]" -> "4.024234".
# The value is handed to float(), so scientific notation also parses.
_LOSS_SIM_RE = re.compile(r'\[LossSim\s+([^\]\s]+)\]')
_LOSS_SHIFT_RE = re.compile(r'\[LossShift\s+([^\]\s]+)\]')


def _extract_loss(pattern, line):
    """Return the float captured by *pattern* in *line*, or None if absent or unparseable."""
    match = pattern.search(line)
    if match is None:
        return None
    try:
        return float(match.group(1))
    except ValueError:
        return None


def _read_epoch_losses(path):
    """Return the ``LossSim + LossShift`` total of every ``[DONE]`` line in the file at *path*."""
    totals = []
    with open(path) as f:
        # Iterate lazily; the whole file never needs to be held in memory.
        for line in f:
            if _DONE_TAG not in line:
                continue
            sim = _extract_loss(_LOSS_SIM_RE, line)
            shift = _extract_loss(_LOSS_SHIFT_RE, line)
            # Skip lines missing either field so the two series cannot drift
            # out of step (the sum is only meaningful when both are present).
            if sim is None or shift is None:
                continue
            totals.append(sim + shift)
    return totals


def postprocess_data(log: list, label_offset: int = 38):
    """Plot the combined loss curve of one or more training logs.

    For every entry of *log*, the file ``PATH + entry`` is read and each line
    containing ``[DONE]`` contributes one point, ``LossSim + LossShift``.
    Example of such a line::

        [2022-01-31 20:40:23.947855] [DONE] [Time 0.179] [Data 0.583] [LossC 0.000000] [LossSim 4.024234] [LossShift 0.065126]

    All curves are drawn on one figure, which is then shown with ``plt.show()``.

    Args:
        log: Path suffixes that are appended verbatim to the module-level
            ``PATH``. The list is not modified.
        label_offset: Number of leading characters dropped from each entry to
            form its legend label. Defaults to 38, the value that was
            previously hard-coded.

    Raises:
        OSError: If a log file cannot be opened.
    """
    for pth in log:
        loss = _read_epoch_losses(PATH + pth)
        # Size the x axis from the data instead of assuming exactly 100 epochs.
        plt.plot(range(1, len(loss) + 1), loss)

    # Figure-wide decorations only need to be set once, not once per curve.
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.title("Loss over epochs")

    # Build the labels in a new list: rewriting the caller's list in place would
    # make a second call with the same list open truncated (non-existent) paths.
    plt.legend([pth[label_offset:] for pth in log])
    plt.grid()
    plt.show()