목차
오늘은 custom dataset(육계 데이터셋)을 이용해 mmdetection model을 training하는 방법을 정리하고자 한다.
1. training.py file 생성
from mmdet.apis import init_detector, inference_detector
import mmcv
import torch
import cv2
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
from mmcv import Config
print(f"Setup complete. Using torch {torch.__version__} ({torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'})")
print(cv2.__version__)
# config 파일을 설정하고, 다운로드 받은 pretrained 모델을 checkpoint로 설정.
config_file = 'configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
@DATASETS.register_module(force=True)
class VOCDataset(CocoDataset):
CLASSES = ('chicken',)
# config file 호출.
cfg = Config.fromfile(config_file)
###############################################################################################################################
from mmdet.apis import set_random_seed
# dataset에 대한 환경 파라미터 수정.
cfg.dataset_type = 'VOCDataset'
cfg.data_root = '/coco_output/'
# train, val, test dataset에 대한 type, data_root, ann_file, img_prefix 환경 파라미터 수정.
cfg.data.train.type = 'VOCDataset'
cfg.data.train.data_root = '/scratch/dohyeon/mmdetection/coco_output/'
cfg.data.train.ann_file = 'annotations/instances_Train.json'
cfg.data.train.img_prefix = 'train'
cfg.data.val.type = 'VOCDataset'
cfg.data.val.data_root = '/scratch/dohyeon/mmdetection/coco_output/'
cfg.data.val.ann_file = 'annotations/instances_Validation.json'
cfg.data.val.img_prefix = 'val'
# class의 갯수를 pascal voc로 설정. 수정.
cfg.model.roi_head.bbox_head.num_classes = 1
cfg.model.roi_head.mask_head.num_classes = 1
# pretrained 모델
cfg.load_from = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
# 학습 weight 파일로 로그를 저장하기 위한 디렉토리 설정.
cfg.work_dir = './tutorial_exps_swin'
# 학습율 변경 환경 파라미터 설정.
cfg.optimizer.lr = 0.0001 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10
# epoch 변경 환경 파라미터 설정
cfg.runner.max_epochs=300
# CocoDataset의 경우 metric을 bbox로 설정해야 함.(mAP아님. bbox로 설정하면 mAP를 iou threshold를 0.5 ~ 0.95까지 변경하면서 측정)
cfg.evaluation.metric = ['bbox', 'segm']
cfg.evaluation.interval = 5
cfg.checkpoint_config.interval = 5
# 두번 config를 로드하면 lr_config의 policy가 사라지는 오류로 인하여 설정.
cfg.lr_config.policy='step'
# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
###################################################################################################################################
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
# train용 Dataset 생성.
datasets = [build_dataset(cfg.data.train)]
model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
model.CLASSES = datasets[0].CLASSES
print(model.CLASSES)
# ##################################################################################################
#Training !!
import os.path as osp
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# epochs는 config의 runner 파라미터로 지정됨. 기본 12회
train_detector(model, datasets, cfg, distributed=False, validate=True)
###################################################################################################
# inference 테스트 코드(선택사항).
# from mmdet.apis import show_result_pyplot
# import cv2
# checkpoint_file = '/scratch/dohyeon/mmdetection/tutorial_exps/epoch_12.pth'
# # checkpoint 저장된 model 파일을 이용하여 모델을 생성, 이때 Config는 위에서 update된 config 사용.
# model_ckpt = init_detector(cfg, checkpoint_file, device='cuda:0')
# # sample image에 적용.
# img = cv2.imread('/scratch/dohyeon/mmdetection/demo/2021-11-11-07_003.jpg')
# result = inference_detector(model_ckpt, img)
# show_result_pyplot(model_ckpt, img, result, score_thr=0.5)
핵심은
# config 파일을 설정하고, 다운로드 받은 pretrained 모델을 checkpoint로 설정.
config_file = 'configs/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
config_file과 checkpoint_file을 지정해주는 것이다.
기존 mmdetection code는
/mmdetection/config/swin/ 안에 내가 사용하고자하는 model의 config file이 존재한다.
따라서 해당 경로에서 본인이 사용하고자 하는 model의 config file의 경로를 가져오면 되고
checkpoint_file은 코드안에 존재하지 않기 때문에
mmdetection github에서 다운로드한 뒤 checkpoints 폴더를 생성한 뒤 해당 폴더안에 넣어주면 된다.
https://github.com/open-mmlab/mmdetection/tree/master/configs/swin
2. 오류 디버깅
오류1. 학습이 진행은 되지만 mAP가 모두 0으로 나오는 에러
: learning rate가 너무 커서 발생하는 문제이므로 learning rate를 줄여서 다시 학습하면 해결할 수 있음
오류2. assert len(data_loaders) == len(workflow) assertionerror 에러
: 해당 에러는 아래 사진에서 workflow = [('train', 1), ('val', 1)]
과 같이 설정되어 있을 때 발생하는 에러이다. 이때 아래 사진처럼 수정하면 해결된다.
'Computer Vision > MMdetection' 카테고리의 다른 글
[mmdetection] roi_head 변경방법 (0) | 2022.09.08 |
---|---|
[mmdetection] custom training 방법 (개선) (0) | 2022.08.22 |
[mmdetection] 탐지 개체수 설정방법 (0) | 2022.08.09 |
[mmdetection] bbox title 변경방법 (0) | 2022.07.20 |
[mmdetection] 초기 환경세팅 정보 (0) | 2022.05.11 |