
[mmdetection] How to change the bbox title

dohyeon2 2022. 7. 20. 18:07

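Today in the lab, while running inference with the Mask R-CNN provided by the mmdetection toolbox, I was given the task of changing the bbox title shown in the inference results. This post summarizes how I solved it.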

 

The inference result produced by the original code is shown above.


The mmdetection Mask R-CNN code runs in the following order.

1. Load the config and checkpoint files.

2. Build the detector model with init_detector.

3. Load the image to run inference on and call inference_detector(model, image_name).

4. Use the show_result function to draw the inference results on the original image and save it.

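import cv2
import matplotlib.pyplot as plt
from mmdet.apis import init_detector, inference_detector

# Set the config file and load the checkpoint trained with the Mask R-CNN model.
config_file = '/scratch/dohyeon/mmdetection/custom_config.py'
checkpoint_file = '/scratch/dohyeon/mmdetection/tutorial_exps/epoch_12.pth'

# Build the detector model from the config file and the checkpoint.
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# path_dir (input image directory), file_list (image file names) and save_dir1
# (output directory) are defined elsewhere in the original script.
num_chicken = 0  # running total of predicted instances

for i in range(1):
    img_name = path_dir + '/' + file_list[i]
    img_arr = cv2.imread(img_name, cv2.IMREAD_COLOR)
    img_arr_rgb = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
    # cv2.imshow('img', img_arr)
    fig = plt.figure(figsize=(12, 12))
    plt.imshow(img_arr_rgb)

    # inference_detector accepts a file path (str) or an ndarray, either alone or as a list.
    results = inference_detector(model, img_arr)

    # Save the drawn result to the output directory (only instances with confidence score > 0.7 are drawn).
    model.show_result(img_arr, results, score_thr=0.7, bbox_color=(0, 0, 0),
                      thickness=0.5, font_size=6, out_file=f'{save_dir1}{file_list[i]}')

    # Add up the number of predicted instances per image (only one image here).
    # results[1][0] is the mask list for class 0, so this counts every prediction,
    # including those below score_thr.
    number = len(results[1][0])
    num_chicken = num_chicken + number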
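The count above uses `len(results[1][0])`, which includes instances below the 0.7 threshold used for drawing. If the count should match what is drawn, a minimal sketch could filter by score first. It assumes a single-class model and the mmdetection 2.x output format `(bbox_result, segm_result)`, where each `bbox_result[c]` is an `(n, 5)` array whose last column is the score; `count_instances` is a hypothetical helper, not an mmdetection API.

```python
import numpy as np

def count_instances(results, score_thr=0.7, class_id=0):
    """Count predicted instances of one class whose score exceeds score_thr.

    Assumes mmdetection 2.x-style output: results = (bbox_result, segm_result),
    where bbox_result[class_id] is an (n, 5) array [x1, y1, x2, y2, score].
    """
    bbox_result = results[0] if isinstance(results, tuple) else results
    scores = bbox_result[class_id][:, -1]
    # Same strict '>' comparison that imshow_det_bboxes uses for score_thr
    return int(np.sum(scores > score_thr))

# Made-up single-class result: three boxes, two of them above 0.7
fake_bboxes = np.array([[0, 0, 10, 10, 0.95],
                        [5, 5, 20, 20, 0.80],
                        [1, 1, 4, 4, 0.30]], dtype=np.float32)
fake_results = ([fake_bboxes], [[None, None, None]])
print(count_instances(fake_results))  # 2
```

In the loop above this would replace `number = len(results[1][0])` with `number = count_instances(results)`.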

 

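By default, the inference result is labeled class_name | confidence_score, as in the picture below.

However, I wanted to display my own values, Instance_id | weight, instead, so I modified the show_result function.

The show_result function is defined in mmdetection/mmdet/models/detectors/base.py, and it looks like this: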

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (Tensor or tuple): The results to draw over `img`
                bbox_result or (bbox_result, segm_result).
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
               The tuple of color should be in BGR order. Default: 'green'
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
               The tuple of color should be in BGR order. Default: 'green'
            mask_color (None or str or tuple(int) or :obj:`Color`):
               Color of masks. The tuple of color should be in BGR order.
               Default: None
            thickness (int): Thickness of lines. Default: 2
            font_size (int): Font size of texts. Default: 13
            win_name (str): The window name. Default: ''
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            img (Tensor): Only if not `show` or `out_file`
        """
        img = mmcv.imread(img)
        img = img.copy()
        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        else:
            bbox_result, segm_result = result, None
        bboxes = np.vstack(bbox_result)
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        # draw segmentation masks
        segms = None
        if segm_result is not None and len(labels) > 0:  # non empty
            segms = mmcv.concat_list(segm_result)
            if isinstance(segms[0], torch.Tensor):
                segms = torch.stack(segms, dim=0).detach().cpu().numpy()
            else:
                segms = np.stack(segms, axis=0)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        img = imshow_det_bboxes(
            img,
            bboxes,
            labels,
            segms,
            class_names=self.CLASSES,
            score_thr=score_thr,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
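The first half of show_result turns the per-class bbox_result list into a single bboxes array plus a parallel labels array; imshow_det_bboxes only ever sees these flattened arrays. The standalone sketch below reproduces just that step with NumPy on made-up data (two classes), to show which label each row ends up with.

```python
import numpy as np

# bbox_result: one (n_c, 5) array per class, rows are [x1, y1, x2, y2, score]
bbox_result = [
    np.array([[0, 0, 10, 10, 0.9],
              [2, 2, 8, 8, 0.6]], dtype=np.float32),   # class 0: two boxes
    np.array([[5, 5, 15, 15, 0.8]], dtype=np.float32),  # class 1: one box
]

# Same two steps as show_result: stack all boxes, then build a parallel label array
bboxes = np.vstack(bbox_result)
labels = np.concatenate([
    np.full(bbox.shape[0], i, dtype=np.int32)
    for i, bbox in enumerate(bbox_result)
])

print(bboxes.shape)  # (3, 5)
print(labels)        # [0 0 1]
```

This is why the index i in the later drawing loop is a position in the flattened (and score-filtered) array, not an index within one class.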

Under the # draw bounding boxes comment, the drawing is handed off to another function: img = imshow_det_bboxes(...).

(That function lives in the mmdetection/mmdet/core/visualization/image.py module.)

 

Here I need to pass in the weight (broiler body weight) produced by a module I wrote myself, so I changed the call as follows:

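        # draw bounding boxes
        # predict_weight_list: per-instance weights from my own weight-prediction module.
        # It is passed positionally as the 4th argument to match the modified
        # imshow_det_bboxes signature, so it has to be available inside show_result.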
        img = imshow_det_bboxes(
            img,
            bboxes,
            labels,
            predict_weight_list,
            segms,
            class_names=self.CLASSES,
            score_thr=score_thr,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

As shown above, I added predict_weight_list as an argument.

 

<Caution>

When defining a Python function, there is an order the parameters must follow:

parameters with default values must always come after parameters without them.

def order(no1="breakfast", no2, no3):
    print(f"I eat {no1} at 9, {no2} at 12, and {no3} at 6.")

order("lunch", "dinner")

# Output
SyntaxError: non-default argument follows default argument

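In the example above, no1 has a default value but comes before no2 and no3, which do not,

so Python raises SyntaxError: non-default argument follows default argument. The error is raised as soon as the function definition is parsed, before order() is ever called.

In short, parameters with default values must always come after parameters without them.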
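Applied to the order example, the fix is to move the defaulted parameter to the end. The same rule is why predict_weight_list, a required parameter, has to sit before segms=None in imshow_det_bboxes. An alternative, sketched below as an assumption rather than what this post does, is to give the new parameter a default of None and put it last, so any existing caller that passes segms as the 4th positional argument keeps working. draw_titles is a hypothetical stand-in for imshow_det_bboxes, and the weights are made-up numbers.

```python
def order(no2, no3, no1="breakfast"):
    # Non-default parameters first, defaulted parameter last
    return f"I eat {no1} at 9, {no2} at 12, and {no3} at 6."

print(order("lunch", "dinner"))
# I eat breakfast at 9, lunch at 12, and dinner at 6.


def draw_titles(labels, segms=None, predict_weight_list=None):
    """Hypothetical stand-in for imshow_det_bboxes with an optional weight list.

    segms is kept only to mirror the original positional order.
    """
    titles = []
    for i, _ in enumerate(labels):
        text = f'id:{i + 1}'
        if predict_weight_list is not None:
            text += f'|{predict_weight_list[i]:.02f}g'
        titles.append(text)
    return titles

print(draw_titles([0, 0]))                          # ['id:1', 'id:2']
print(draw_titles([0, 0], None, [1520.3, 1611.0]))  # ['id:1|1520.30g', 'id:2|1611.00g']
```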

 

Continuing, if you open the imshow_det_bboxes function in mmdetection/mmdet/core/visualization/image.py,

it looks like this:

def imshow_det_bboxes(img,
                      bboxes,
                      labels,
                      segms=None,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      mask_color=None,
                      thickness=2,
                      font_size=13,
                      win_name='',
                      show=True,
                      wait_time=0,
                      out_file=None):
    """Draw bboxes and class labels (with scores) on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5).
        labels (ndarray): Labels of bboxes.
        segms (ndarray or None): Masks, shaped (n,h,w) or None
        class_names (list[str]): Names of each classes.
        score_thr (float): Minimum score of bboxes to be shown.  Default: 0
        bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
           The tuple of color should be in BGR order. Default: 'green'
        text_color (str or tuple(int) or :obj:`Color`):Color of texts.
           The tuple of color should be in BGR order. Default: 'green'
        mask_color (str or tuple(int) or :obj:`Color`, optional):
           Color of masks. The tuple of color should be in BGR order.
           Default: None
        thickness (int): Thickness of lines. Default: 2
        font_size (int): Font size of texts. Default: 13
        show (bool): Whether to show the image. Default: True
        win_name (str): The window name. Default: ''
        wait_time (float): Value of waitKey param. Default: 0.
        out_file (str, optional): The filename to write the image.
            Default: None

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    assert bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert labels.ndim == 1, \
        f' labels ndim should be 1, but its ndim is {labels.ndim}.'
    assert bboxes.shape[0] == labels.shape[0], \
        'bboxes.shape[0] and labels.shape[0] should have the same length.'
    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
        f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
    img = mmcv.imread(img).astype(np.uint8)

    if score_thr > 0:
        assert bboxes.shape[1] == 5
        scores = bboxes[:, -1]
        inds = scores > score_thr
        bboxes = bboxes[inds, :]
        labels = labels[inds]
        if segms is not None:
            segms = segms[inds, ...]

    mask_colors = []
    if labels.shape[0] > 0:
        if mask_color is None:
            # Get random state before set seed, and restore random state later.
            # Prevent loss of randomness.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            # random color
            np.random.seed(42)
            mask_colors = [
                np.random.randint(0, 256, (1, 3), dtype=np.uint8)
                for _ in range(max(labels) + 1)
            ]
            np.random.set_state(state)
        else:
            # specify  color
            mask_colors = [
                np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
            ] * (
                max(labels) + 1)

    bbox_color = color_val_matplotlib(bbox_color)
    text_color = color_val_matplotlib(text_color)

    img = mmcv.bgr2rgb(img)
    width, height = img.shape[1], img.shape[0]
    img = np.ascontiguousarray(img)

    fig = plt.figure(win_name, frameon=False)
    plt.title(win_name)
    canvas = fig.canvas
    dpi = fig.get_dpi()
    # add a small EPS to avoid precision lost due to matplotlib's truncation
    # (https://github.com/matplotlib/matplotlib/issues/15363)
    fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)

    # remove white edges by set subplot margin
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    ax = plt.gca()
    ax.axis('off')

    polygons = []
    color = []
    for i, (bbox, label) in enumerate(zip(bboxes, labels)):  # i: index, bbox: one row of bboxes, label: its entry in labels
        bbox_int = bbox.astype(np.int32)
        poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
                [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
        np_poly = np.array(poly).reshape((4, 2))
        polygons.append(Polygon(np_poly))
        color.append(bbox_color)
        label_text = class_names[
            label] if class_names is not None else f'class {label}'
        if len(bbox) > 4:
            label_text += f'|{bbox[-1]:.02f}'
        ax.text(
            bbox_int[0],
            bbox_int[1],
            f'{label_text}',
            bbox={
                'facecolor': 'black',
                'alpha': 0.8,
                'pad': 0.7,
                'edgecolor': 'none'
            },
            color=text_color,
            fontsize=font_size,
            verticalalignment='top',
            horizontalalignment='left')
        if segms is not None:
            color_mask = mask_colors[labels[i]]
            mask = segms[i].astype(bool)
            img[mask] = img[mask] * 0.5 + color_mask * 0.5

    plt.imshow(img)

    p = PatchCollection(
        polygons, facecolor='none', edgecolors=color, linewidths=thickness)
    ax.add_collection(p)

    stream, _ = canvas.print_to_buffer()
    buffer = np.frombuffer(stream, dtype='uint8')
    img_rgba = buffer.reshape(height, width, 4)
    rgb, alpha = np.split(img_rgba, [3], axis=2)
    img = rgb.astype('uint8')
    img = mmcv.rgb2bgr(img)

    if show:
        # We do not use cv2 for display because in some cases, opencv will
        # conflict with Qt, it will output a warning: Current thread
        # is not the object's thread. You can refer to
        # https://github.com/opencv/opencv-python/issues/46 for details
        if wait_time == 0:
            plt.show()
        else:
            plt.show(block=False)
            plt.pause(wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    plt.close()

    return img

Because the function now has to receive the weight values as an argument, I modified its signature as follows.

def imshow_det_bboxes(img,
                      bboxes,
                      labels,
                      predict_weight_list,
                      segms=None,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      mask_color=None,
                      thickness=2,
                      font_size=13,
                      win_name='',
                      show=True,
                      wait_time=0,
                      out_file=None):

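I added predict_weight_list as the 4th parameter of imshow_det_bboxes(). Since it has no default value, it has to come before segms=None.

After that, I changed the drawing loop as follows.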

# List of instance numbers (1, 2, ..., n), one per box that survived score_thr.
# Named ids to avoid shadowing Python's built-in id().
ids = [i + 1 for i in range(len(labels))]

# Weight list passed in from my weight-prediction module.
weight = predict_weight_list

polygons = []
color = []
for i, (bbox, label) in enumerate(zip(bboxes, labels)):  # i: index, bbox: one row of bboxes, label: its entry in labels
    bbox_int = bbox.astype(np.int32)
    poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
            [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
    np_poly = np.array(poly).reshape((4, 2))
    polygons.append(Polygon(np_poly))
    color.append(bbox_color)
    # The title starts with the instance number instead of the class name.
    label_text = f'id:{ids[i]}'
    if len(bbox) > 4:
        # Append the predicted weight in grams.
        label_text += f'|{weight[i][0]:.02f}g'
    # ... the rest of the loop (ax.text and mask drawing) is unchanged.

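The text shown in the bbox title is whatever label_text holds,

so I assigned the id number to label_text, and

if len(bbox) > 4:
    label_text += f'|{weight[i][0]:.02f}g'

appends weight[i] after the id number. len(bbox) > 4 holds when the box carries a 5th (score) column, as detector output does, and weight[i][0] expects each weight entry to be a one-element sequence (for example a row of an (n, 1) array).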
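One thing to check: by the time the drawing loop in imshow_det_bboxes runs, every box with score <= score_thr has already been dropped, so i indexes the filtered boxes. If predict_weight_list holds one entry per predicted instance before filtering (an assumption; the post does not say how the list is built), it needs the same boolean mask, otherwise weight[i] gets attached to the wrong box. A minimal sketch with made-up data and a hypothetical build_label_texts helper:

```python
import numpy as np

def build_label_texts(bboxes, predict_weight_list, score_thr=0.7):
    """Build 'id:N|W g' titles, keeping weights aligned with score-filtered boxes.

    bboxes: (n, 5) array [x1, y1, x2, y2, score] for all predictions.
    predict_weight_list: (n, 1) array-like, one weight per prediction (assumed).
    """
    weight = np.asarray(predict_weight_list, dtype=np.float32)
    if score_thr > 0:
        inds = bboxes[:, -1] > score_thr  # same mask imshow_det_bboxes builds
        bboxes = bboxes[inds, :]
        weight = weight[inds]             # apply the mask to the weights too
    texts = []
    for i, bbox in enumerate(bboxes):
        label_text = f'id:{i + 1}'
        if len(bbox) > 4:
            label_text += f'|{weight[i][0]:.02f}g'
        texts.append(label_text)
    return texts

bboxes = np.array([[0, 0, 10, 10, 0.95],
                   [5, 5, 20, 20, 0.40],
                   [1, 1, 4, 4, 0.85]], dtype=np.float32)
weights = [[1500.0], [1700.0], [1620.5]]
print(build_label_texts(bboxes, weights))  # ['id:1|1500.00g', 'id:2|1620.50g']
```

Without the mask, the second box drawn (score 0.85) would be labeled with 1700.00g, the weight of the discarded 0.40 box.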

 

As a result,

the title now shows id_num | weight, as in the image above.
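To check the title format without touching the installed mmdetection, the drawing can be previewed with plain matplotlib. The sketch below is not mmdetection code: draw_custom_titles is a hypothetical helper that reuses the same ax.text styling as imshow_det_bboxes (black box, alpha 0.8, top-left alignment) on a blank image.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen, no window needed
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np

def draw_custom_titles(img, boxes, titles, font_size=6, color='green'):
    """Draw boxes with custom titles in the style imshow_det_bboxes uses."""
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    for (x1, y1, x2, y2), title in zip(boxes, titles):
        # Unfilled rectangle for the bbox outline
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1,
                               fill=False, edgecolor=color))
        # Title anchored at the top-left corner of the box
        ax.text(x1, y1, title,
                bbox={'facecolor': 'black', 'alpha': 0.8,
                      'pad': 0.7, 'edgecolor': 'none'},
                color=color, fontsize=font_size,
                verticalalignment='top', horizontalalignment='left')
    return fig, ax
```

For example, draw_custom_titles(np.zeros((100, 100, 3), dtype=np.uint8), [(10, 10, 60, 60)], ['id:1|1500.00g']) returns a figure that can be written out with fig.savefig(...) and closed with plt.close(fig).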