|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import pickle
- import random
- import string
- from datetime import datetime
-
- import torch
- import torch.nn as nn
-
-
class IncrementalAverage:
    """Running arithmetic mean that is updated one observation at a time.

    Attributes:
        value: mean of all observations passed to ``update`` so far
            (``0`` before the first update).
        counter: number of observations folded in so far.
    """

    def __init__(self):
        self.counter = 0
        self.value = 0

    def update(self, x):
        """Fold the observation *x* into the running mean.

        Uses the incremental form ``mean += (x - mean) / n``, so no running
        sum has to be kept.
        """
        self.counter += 1
        delta = x - self.value
        self.value += delta / self.counter
-
-
class Flatten(nn.Module):
    """Collapse every dimension after the first one: ``(N, *rest) -> (N, -1)``.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. the output of ``transpose``/``permute``) work. ``view`` raises a
    ``RuntimeError`` on those. ``reshape`` still returns a view without
    copying whenever the memory layout allows it. A 1-D input of shape
    ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, x):
        return x.reshape(x.size(0), -1)
-
-
class SizePrinter(nn.Module):
    """Identity module that prints the size of whatever flows through it.

    Handy as a temporary layer inside ``nn.Sequential`` when debugging shapes.
    The input is returned unchanged (the very same object).
    """

    def forward(self, x):
        shape = x.size()
        print(shape)
        return x
-
-
def count_parameters(model, grad_only=True):
    """Return the total number of scalar elements in *model*'s parameters.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).
        grad_only: when true (the default), only parameters with
            ``requires_grad`` set are counted; when false, all are counted.
    """
    total = 0
    for param in model.parameters():
        # Skip frozen parameters unless the caller asked for all of them.
        if grad_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
-
-
def to_device(device, *tensors):
    """Move each of *tensors* to *device* and return the results as a tuple.

    The order of the results matches the order of the arguments. With no
    tensors, an empty tuple is returned.
    """
    moved = [tensor.to(device) for tensor in tensors]
    return tuple(moved)
-
-
def loop_iter(iter):
    """Yield the items of *iter* over and over, re-iterating it on each pass.

    *iter* is iterated afresh on every pass rather than cached, so a
    re-iterable source (a list, or a data loader that reshuffles per pass)
    produces a new pass each time.

    The generator stops as soon as one complete pass yields nothing. That
    happens when *iter* is empty or is a one-shot iterator (e.g. a generator)
    that is already exhausted. Without this check, ``while True`` would spin
    forever without ever yielding, hanging the consumer.

    The parameter is named ``iter`` (shadowing the builtin) for backward
    compatibility with keyword callers.
    """
    while True:
        produced = False
        for item in iter:
            produced = True
            yield item
        if not produced:
            return
-
-
def unique_string():
    """Return a timestamped identifier such as ``20240131T235959Z.QWER``.

    The first part is the current UTC time formatted as
    ``%Y%m%dT%H%M%SZ``. The trailing ``Z`` is the ISO 8601 UTC designator,
    so the timestamp must be taken in UTC rather than local time. The second
    part is four random uppercase ASCII letters.
    """
    from datetime import timezone

    stamp = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')
    # Draw the suffix from the OS entropy source instead of the module-level
    # ``random`` generator, for two reasons:
    # - Once the global generator is seeded (e.g. by set_seeds), two runs with
    #   the same seed started in the same second would get identical names.
    # - Consuming the global generator here would shift the random stream seen
    #   by the rest of the program.
    rng = random.SystemRandom()
    suffix = ''.join(rng.choice(string.ascii_uppercase) for _ in range(4))
    return '{}.{}'.format(stamp, suffix)
-
-
def set_seeds(seed):
    """Seed Python's ``random`` module and PyTorch's generators with *seed*.

    Seeds, in order: the stdlib ``random`` module, torch's default generator
    (``torch.manual_seed``), and every CUDA device
    (``torch.cuda.manual_seed_all``). NumPy's global RNG is not seeded here.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
-
-
def pickle_dump(obj, file):
    """Serialize *obj* with :mod:`pickle` into the file at path *file*.

    The file is opened in binary write mode, so any existing content is
    overwritten. The default pickle protocol is used. Only ever unpickle the
    result from a trusted source, because loading a pickle can execute
    arbitrary code.
    """
    with open(file, mode='wb') as out:
        pickle.dump(obj, out)
|