Computer Vision/MMdetection

[mmdetection] 모델 학습과정에서 wanDB 연동방법

dohyeon2 2022. 12. 3. 14:56

목차

    이번에 mmdetection에서 제공하는 Instance Segmentation Network들의 성능을 비교해보려 합니다.

     

    위와같이 mmdetection은 2022년 12월 03일 기준 13개의 Instance Segmentation Network를 툴박스에서 제공합니다.

    거두절미하고 학습 시각화를 위한 wegiht&bias를 mmdetection에 연동하는 과정을 설명하겠습니다.

    2022년 12월 현재 mmdetection github의 master branch가 2.26.0버전을 제공하므로 이를 기준으로 설명하겠습니다. 


    1. Prerequirement 설치

    # pytorch 설치 for cuda 11.3
    conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

    (cuda와 cudnn은 알아서 본인에게 맞는 버전을 설치해주세요)

     

    2. mmdetection 패키지 설치

    제 경우 mmdet source code를 직접 수정하기 때문에 source 설치를 진행합니다.

    git clone https://github.com/open-mmlab/mmdetection.git
    cd mmdetection
    pip install -v -e .
    # "-v" means verbose, or more output
    # "-e" means installing a project in editable mode,
    # thus any local modifications made to the code will take effect without reinstallation.

    설치 끝! 간단합니다. 

     

    3. mmdetection archive에서 본인이 학습시키고자 하는 모델 다운로드

    예를들어 본인이 mask-rcnn 모델을 학습하고자 한다면 아래 주소에 들어가서 model을 다운로드 받아주세요.

    https://github.com/open-mmlab/mmdetection/tree/v2.25.0/configs/mask_rcnn

     

    GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark

    OpenMMLab Detection Toolbox and Benchmark. Contribute to open-mmlab/mmdetection development by creating an account on GitHub.

    github.com

    4. Training 코드에 wanDB Hook 추가

    wanDB Hook

     

    5. 전체 학습코드

    from mmdet.apis import init_detector, inference_detector
    import mmcv
    import torch
    import cv2
    from mmdet.datasets.builder import DATASETS
    from mmdet.datasets.coco import CocoDataset
    from mmcv import Config
    import copy
    print(f"Setup complete. Using torch {torch.__version__} ({torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'})")
    print(cv2.__version__)
    
    # config 파일을 설정하고, 다운로드 받은 pretrained 모델을 checkpoint로 설정.  
    config_file = 'configs/mask_rcnn/mask_rcnn_r101_fpn_2x_coco.py'
    checkpoint_file = 'checkpoints/mask_rcnn_r101_fpn_2x.pth'
    
    @DATASETS.register_module(force=True)
    class VOCDataset(CocoDataset):
      CLASSES = ('chicken',)
    
    # config file 호출.
    cfg = Config.fromfile(config_file)
    
    ###############################################################################################################################
    from mmdet.apis import set_random_seed
    
    # dataset에 대한 환경 파라미터 수정. 
    cfg.dataset_type = 'VOCDataset'
    cfg.data_root = '/coco_output/'
    
    # train, val, test dataset에 대한 type, data_root, ann_file, img_prefix 환경 파라미터 수정. 
    cfg.data.train.type = 'VOCDataset'
    cfg.data.train.data_root = '/scratch/dohyeon/PREDIX/src/mmdetection/dataset/dataset_64/'
    cfg.data.train.ann_file = 'annotations/train.json'
    cfg.data.train.img_prefix = 'train'
    
    cfg.data.val.type = 'VOCDataset'
    cfg.data.val.data_root = '/scratch/dohyeon/PREDIX/src/mmdetection/dataset/dataset_64/'
    cfg.data.val.ann_file = 'annotations/val.json'
    cfg.data.val.img_prefix = 'val'
    
    
    # class의 갯수 수정. 
    cfg.model.roi_head.bbox_head.num_classes = 1
    cfg.model.roi_head.mask_head.num_classes = 1
    
    # pretrained 모델설정(Mask R-CNN + resnet101).
    cfg.load_from = 'checkpoints/mask_rcnn_r101_fpn_2x.pth'
    
    # 학습 weight 파일로 로그를 저장하기 위한 디렉토리 설정. 
    cfg.work_dir = './mask_rcnn_r101_fpn_2x_dataset_64'
    
    # 학습율 변경 환경 파라미터 설정. 
    cfg.optimizer.lr = 0.02 / 8
    cfg.lr_config.warmup = None
    cfg.log_config.interval = 32
    cfg.log_config.hooks = [
        dict(type='TextLoggerHook'),
        dict(type='MMDetWandbHook',
             init_kwargs={
            'project': 'mask_rcnn',
            'entity': 'msdl_wandb',
            'name': 'mask_rcnn_r101_2x_dataset_64'},
             interval=10,
             log_checkpoint=True,
             log_checkpoint_metadata=True,
             num_eval_images=10,
             bbox_score_thr=0.7)]
    # epoch 변경 환경 파라미터 설정
    cfg.runner.max_epochs=100
    
    # workflow val 추가
    cfg.workflow = [('train', 1), ('val', 1)]
    
    # CocoDataset의 경우 metric을 bbox로 설정해야 함.(mAP아님. bbox로 설정하면 mAP를 iou threshold를 0.5 ~ 0.95까지 변경하면서 측정)
    cfg.evaluation.metric = ['bbox', 'segm']
    cfg.evaluation.interval = 10
    cfg.checkpoint_config.interval = 10
    
    # 두번 config를 로드하면 lr_config의 policy가 사라지는 오류로 인하여 설정. 
    cfg.lr_config.policy='step'
    # Set seed thus the results are more reproducible
    cfg.seed = 0
    set_random_seed(0, deterministic=False)
    cfg.gpu_ids = range(1)
    cfg.device='cuda'
    
    ###################################################################################################################################
    from mmdet.datasets import build_dataset
    from mmdet.models import build_detector
    from mmdet.apis import train_detector
    
    # train, val Dataset 생성. 
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
      val_dataset = copy.deepcopy(cfg.data.val)
      val_dataset.pipeline = cfg.data.train.pipeline
      datasets.append(build_dataset(val_dataset))
    
    model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
    model.CLASSES = datasets[0].CLASSES
    print(model.CLASSES)
    
    
    
    ##################################################################################################
    
    #Training !! 
    import os.path as osp
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # epochs는 config의 runner 파라미터로 지정됨. 기본 12회 
    train_detector(model, datasets, cfg, distributed=False, validate=True)
    
    ###################################################################################################
    
    # inference 테스트 코드(선택사항). 
    # from mmdet.apis import show_result_pyplot
    # import cv2
    # checkpoint_file = '/scratch/dohyeon/mmdetection/tutorial_exps/epoch_12.pth'
    
    # # checkpoint 저장된 model 파일을 이용하여 모델을 생성, 이때 Config는 위에서 update된 config 사용. 
    # model_ckpt = init_detector(cfg, checkpoint_file, device='cuda:0')
    
    # # sample image에 적용.
    # img = cv2.imread('/scratch/dohyeon/mmdetection/demo/2021-11-11-07_003.jpg')
    # result = inference_detector(model_ckpt, img)
    # show_result_pyplot(model_ckpt, img, result, score_thr=0.5)

     

     

    Reference

    https://docs.wandb.ai/guides/integrations/mmdetection