Development/Python

Python - 자주 사용하는 Utils

투푸월드 2023. 7. 30. 18:14

서버에서 돌리기 위해서는 터미널 출력뿐만 아니라 더 안전하게 출력물들을 기록 해야한다. 또한 검색등을 활용 할 수 있으므로 텍스트 파일에 자세하고 알기 쉽게 많은것을 저장해 두자. 아니면 그냥 쉽게 print 함수를 쓰지 말고 의미 있는 성능들은 모두 logger로 보내자.

LoggerPermalink

import sys

class Logger(object):
    def __init__(self, local_rank=0, no_save=False):
        self.terminal = sys.stdout
        self.file = None
        self.local_rank = local_rank
        self.no_save = no_save
    def open(self, fp, mode=None):
        if mode is None: mode = 'w'
        if self.local_rank and not self.no_save == 0: self.file = open(fp, mode)
    def write(self, msg, is_terminal=1, is_file=1):
        if msg[-1] != "\n": msg = msg + "\n"
        if self.local_rank == 0:
            if '\r' in msg: is_file = 0
            if is_terminal == 1:
                self.terminal.write(msg)
                self.terminal.flush()
            if is_file == 1 and not self.no_save:
                self.file.write(msg)
                self.file.flush()
    def flush(self): 
        pass

읽어보면 쉽다. self.terminal은 print처럼 화면에 sysout하는 것이고, (print보다 빠름), self.file 은 메모장 txt켜서 거기에 저장. 반드시 logger = Logger()선언을 하고 logger.open(“asdasd.txt”)를 해줘야 한다.

  • 2021-07-23: 파이참에서는
    print("asdasd \r")
    
    과같이 \r을 맨 마지막에 사용하면 콘솔창에 출력자체가 안되므로
print("\r asdasd")

처럼 \r을 처음에 쓰자. 추가적으로 어차피 기능은 똑같으므로 이렇게 쓰는 습관을 들이자.

예시: 따라서 \r을 사용할때

for iter, data in enumerate(train_loader):
    if iter % 10 == 0:
        logger.write(f"\r[{iter}/{len(train_loader)}]")
    if iter % 1000 == 0:
        logger.write("error msg write")

와 같이 진행상황은 무시되지만 콘솔에서는 볼 수 있는 코드를 사용할 수 있다.

Print_argsPermalink

  • 2021-07-27: print_args도 많이 써서 기록해 둔다.
    def print_args(args, logger=None):
      if logger is not None:
          logger.write("#### configurations ####")
      for k, v in vars(args).items():
          if logger is not None:
              logger.write('{}: {}\n'.format(k, v))
          else:
              print('{}: {}'.format(k, v))
      if logger is not None:
          logger.write("########################")
    

save_args & load_argsPermalink

  • 2022-04-30: 보통 모델 빌드 때 파라미터를 args에 다 저장하는데 나중에 모델 불러올때 같은 구조를 맞추기 위해 args까지 저장해 두고 불러온다.
    import argparse
    import json
    def save_args(args, to_path):
      with open(to_path, "w") as f:
          json.dump(args.__dict__, f, indent=2)
    def load_args(from_path):
      parser = argparse.ArgumentParser()
      args = parser.parse_args()
      with open(from_path, "r") as f:
          args.__dict__ = json.load(f)
      return args    
    

AverageMeterPermalink

AverageMeter도 기록해두자. 이거 되게 편하다

class AverageMeter (object):
    def __init__(self):
        self.reset ()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

AccuracyPermalink

def Accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
import random
import torch
import numpy as np

def fix_seed(seed: int) -> None:
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True # this can slow down speed
  torch.backends.cudnn.benchmark = False
  np.random.seed(seed)
  random.seed(seed)

torch.manual_seed : torch.~~ 로 시작하는 모든 코드의 난수를 고정 시킬 수 있다.

torch.cuda.manual_seed : torch.cuda~~로 시작하는 모든 코드의 나수를 고정 시킬 수 있다.

torch.cuda.manual_seed_all: multi_gpu를 사용할때 난수를 고정 시켜준다.

torch.backends.cudnn.deterministic: 파이토치는 cudnn을 백엔드로 사용하기 때문에 이것도 설정한다. 하지만 속도가 느려질 수 있다.

torch.backends.cudnn.benchmark: True이면 convolution 연산을 할 때 입력 사이즈에 맞게 최적화된 알고리즘을 쓴다. 단점은 입력 이미지 사이즈가 너무 다르면 오히려 성능이 저하될 수 있다.

np.random.seed: 파이토치에서 많은 코드가 넘파이로 데이터를 받아오기 때문에 넘파이 시드도 고정 시켜야 한다.

random.seed: torchvision의 transform에서 RandomCrop, RandomHorizontalFlip등은 python의 random을 사용한다. 따라서 이것도 필요함