ML-Computer Vision/MMdetection

[mmdetection] roi_head 변경방법

dohyeon2 2022. 9. 8. 17:38

기존 mask rcnn model에서 pointRend model로 mask_head를 변경하려고 하였는데 다음과 같은 오류가 발생하였다.

 

AssertionError: The `num_classes` (80) in MaskPointHead of MMDataParallel does not matches the length of `CLASSES` 1) in VOCDataset

해당 오류는 2가지의 이유로 발생할 수 있는데 먼저, 

해결방법 1 

mmdetection-master 디렉토리에는 python 파일이 일부만 존재하는데, 실제로 프로그램이 실행될 때 환경에서 소스 파일을 직접 수정하기 때문에 환경의 소스 파일은 계속 실행됩니다.


내 conda 환경의 이름이 conda_env_name이라고 가정하고 다음 디렉토리로 이동하여 각각 두 개의 파일을 수정합니다.


\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\core\evaluation\class_names.py
\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\datasets\coco.py

conda 환경에서 이 두 파일의 classes name을 원하는 class name으로 변경하면 됩니다.

해결방법 2

mmdetection/config/원하는 model (내 경우 point_rend)/ point_rend_r50_caffe_fpn_mstrain_3x_coco.py 파일에서 

주석처리한 부분을 내가 원하는 class 개수로 변경

_base_ = '../mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py'
# model settings
model = dict(
    type='PointRend',
    roi_head=dict(
        type='PointRendRoIHead',
        mask_roi_extractor=dict(
            type='GenericRoIExtractor',
            aggregation='concat',
            roi_layer=dict(
                _delete_=True, type='SimpleRoIAlign', output_size=14),
            out_channels=256,
            featmap_strides=[4]),
        mask_head=dict(
            _delete_=True,
            type='CoarseMaskHead',
            num_fcs=2,
            in_channels=256,
            conv_out_channels=256,
            fc_out_channels=1024,
            num_classes=1,    # 내가 원하는 클래스 숫자로 변경
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
        point_head=dict(
            type='MaskPointHead',
            num_fcs=3,
            in_channels=256,
            fc_channels=256,
            num_classes=1,  # 내가 원하는 클래스 숫자로 변경
            coarse_pred_each_layer=True,
            loss_point=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rcnn=dict(
            mask_size=7,
            num_points=14 * 14,
            oversample_ratio=3,
            importance_sample_ratio=0.75)),
    test_cfg=dict(
        rcnn=dict(
            subdivision_steps=5,
            subdivision_num_points=28 * 28,
            scale_factor=2)))