ML-Computer Vision/MMdetection

[mmdetection] custom dataset 학습방법

dohyeon2 2022. 8. 20. 13:18

오늘은 custom dataset(육계 데이터셋)을 이용해 mmdetection model을 training하는 방법을 정리하고자 한다. 


1. training.py file 생성

from mmdet.apis import init_detector, inference_detector
import mmcv
import torch
import cv2
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
from mmcv import Config
print(f"Setup complete. Using torch {torch.__version__} ({torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'})")
print(cv2.__version__)

# config 파일을 설정하고, 다운로드 받은 pretrained 모델을 checkpoint로 설정. 
config_file = 'configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'

@DATASETS.register_module(force=True)
class VOCDataset(CocoDataset):
  CLASSES = ('chicken',)

# config file 호출.
cfg = Config.fromfile(config_file)

###############################################################################################################################
from mmdet.apis import set_random_seed

# dataset에 대한 환경 파라미터 수정. 
cfg.dataset_type = 'VOCDataset'
cfg.data_root = '/coco_output/'

# train, val, test dataset에 대한 type, data_root, ann_file, img_prefix 환경 파라미터 수정. 
cfg.data.train.type = 'VOCDataset'
cfg.data.train.data_root = '/scratch/dohyeon/mmdetection/coco_output/'
cfg.data.train.ann_file = 'annotations/instances_Train.json'
cfg.data.train.img_prefix = 'train'

cfg.data.val.type = 'VOCDataset'
cfg.data.val.data_root = '/scratch/dohyeon/mmdetection/coco_output/'
cfg.data.val.ann_file = 'annotations/instances_Validation.json'
cfg.data.val.img_prefix = 'val'


# class의 갯수를 pascal voc로 설정.  수정. 
cfg.model.roi_head.bbox_head.num_classes = 1
cfg.model.roi_head.mask_head.num_classes = 1

# pretrained 모델
cfg.load_from = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'

# 학습 weight 파일로 로그를 저장하기 위한 디렉토리 설정. 
cfg.work_dir = './tutorial_exps_swin'

# 학습율 변경 환경 파라미터 설정. 
cfg.optimizer.lr = 0.0001 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10

# epoch 변경 환경 파라미터 설정
cfg.runner.max_epochs=300

# CocoDataset의 경우 metric을 bbox로 설정해야 함.(mAP아님. bbox로 설정하면 mAP를 iou threshold를 0.5 ~ 0.95까지 변경하면서 측정)
cfg.evaluation.metric = ['bbox', 'segm']
cfg.evaluation.interval = 5
cfg.checkpoint_config.interval = 5

# 두번 config를 로드하면 lr_config의 policy가 사라지는 오류로 인하여 설정. 
cfg.lr_config.policy='step'
# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

###################################################################################################################################
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector

# train용 Dataset 생성. 
datasets = [build_dataset(cfg.data.train)]

model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
model.CLASSES = datasets[0].CLASSES
print(model.CLASSES)

# ##################################################################################################
#Training !! 
import os.path as osp
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# epochs는 config의 runner 파라미터로 지정됨. 기본 12회 
train_detector(model, datasets, cfg, distributed=False, validate=True)
###################################################################################################

# inference 테스트 코드(선택사항). 
# from mmdet.apis import show_result_pyplot
# import cv2
# checkpoint_file = '/scratch/dohyeon/mmdetection/tutorial_exps/epoch_12.pth'

# # checkpoint 저장된 model 파일을 이용하여 모델을 생성, 이때 Config는 위에서 update된 config 사용. 
# model_ckpt = init_detector(cfg, checkpoint_file, device='cuda:0')

# # sample image에 적용.
# img = cv2.imread('/scratch/dohyeon/mmdetection/demo/2021-11-11-07_003.jpg')
# result = inference_detector(model_ckpt, img)
# show_result_pyplot(model_ckpt, img, result, score_thr=0.5)

핵심은 

# config 파일을 설정하고, 다운로드 받은 pretrained 모델을 checkpoint로 설정. 
config_file = 'configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'

 

config_file과 checkpoint_file을 지정해주는 것이다. 

기존 mmdetection code는 

/mmdetection/config/swin/ 안에 내가 사용하고자하는 model의 config file이 존재한다. 

따라서 해당 경로에서 본인이 사용하고자 하는 model의 config file의 경로를 가져오면 되고 

 

checkpoint_file은  코드안에 존재하지 않기 때문에 

mmdetection github에서 다운로드한 뒤 checkpoints 폴더를 생성한 뒤 해당 폴더안에 넣어주면 된다. 

다운로드 받는법

https://github.com/open-mmlab/mmdetection/tree/master/configs/swin

 

GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark

OpenMMLab Detection Toolbox and Benchmark. Contribute to open-mmlab/mmdetection development by creating an account on GitHub.

github.com

 

2. 오류 디버깅 

오류1. 학습이 진행은 되지만 mAP가 모두 0으로 나오는 에러 

: learning rate가 너무 커서 발생하는 문제이므로 learning rate를 줄여서 다시 학습하면 해결할 수 있음 

 

오류2. assert len(data_loaders) == len(workflow) assertionerror 에러 

: 해당 에러는 아래 사진에서 workflow = [('train', 1), ('val', 1)] 

과 같이 설정되어 있을 때 발생하는 에러이다. 이때 아래 사진처럼 수정하면 해결된다.