Deep Learning

Deep Learning - 딥러닝의 활성화 히트맵 시각화하기

electronicprogrammer 2020. 11. 8. 01:13

이번 포스팅에서는 이미 훈련된 CNN의 활성화를 파악하는 방법 중 하나인 히트맵 시각화하는 방법에 대해서 정리해보도록 하겠습니다.

이 방법으로 저희는 이미지의 어느 부분이 CNN의 최종 분류 결정에 기여하는 가에 대해서 파악할 수 가 있습니다.

 

먼저 히트맵을 시각화하기 이전에 이미 훈련된 CNN 모델 중 하나인 VGG16 모델을 가져오도록 하겠습니다.

from keras.applications.vgg16 import VGG16

from keras.preprocessing import image

from keras.applications.vgg16 import preprocess_input , decode_predictions

 

그리고 이 모델을 이용해서 히트맵을 그리고자 하는 예시 사진을 하나 가져오겠습니다.

image_Path = './archive/training_set/training_set/dogs/dog.2240.jpg'

img = image.load_img(image_Path , target_size = (224 , 224))

plt.figure(figsize = (8 , 5))

plt.imshow(img)

plt.show()

 

VGG16에 입력하기 위해서 사진의 높이 및 너비를 모두 224로 설정해서 이미지 파일을 불러왔습니다.

 

이제 이 이미지를 VGG16 모델에 입력하기 앞서서 데이터 전처리를 하도록 하겠습니다.

img = image.img_to_array(img)

x = np.expand_dims(img , axis = 0)
x = preprocess_input(x)

model = VGG16(weights = 'imagenet')

preds = model.predict(x)

 

전처리한 이미지를 VGG16모델에 입력함으로써 preds 정보를 구할 수 있었습니다.

이제 이를 decode_predict 함수를 이용해서 해당 사진을 무엇으로 생각하는 지를 파악할 수 있습니다.

preds_info = decode_predictions(preds , top=3)[0]

print('Predicted')

for i in range(len(preds_info)) :
    
    print(preds_info[i])
    

 

groenendael 일 확률을 대략 47%로 판단하였음을 알 수 있습니다.

groenendael 이 무엇인지 확인해보니 다음과 같습니다.

 

이제 불러온 사진의 히트맵을 그리기 위해서 필요한 라이브러리를 불러오겠습니다.

from keras import preprocessing
from keras import backend as K

from keras import models

import tensorflow as tf

 

다음은 히트맵을 구하는 과정입니다.

히트맵을 구하는 과정은 GRAD-CAM 방법을 이용하는데 이 방법에 대한 원리에 대해서 저도 자세하게는 모르겠습니다.

아래는 구하는 과정만을 정리해서 구현한 코드입니다.

# GRAD-CAM에 대한 내용을 이용해서 구현

last_conv_layer = model.get_layer('block5_conv3')

heatmap_model = models.Model([model.input] , [last_conv_layer.output , model.output])

with tf.GradientTape() as gtape:
    
    conv_output, predictions = heatmap_model(x)
    
    # 가장 가능성이 높은 클래스에 대한 정보에서 loss 파악
    loss = predictions[:, np.argmax(predictions[0])]
    
    # block5_conv3의 특성 맵 출력에 대한 'groenendael' 클래스의 그래디언트
    grads = gtape.gradient(loss, conv_output)
    
    # 특성 맵 채널별 그래디언트 평균값이 담긴 벡터
    pooled_grads = K.mean(grads, axis=(0, 1, 2))

# 특성 맵의 출력
heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_output), axis=-1)

# 0 ~ 1 사이의 값으로 정규화
heatmap = np.maximum(heatmap, 0)

max_heat = np.max(heatmap)

if max_heat == 0:
    max_heat = 1e-10
    
heatmap /= max_heat

 

그리고 구한 히트맵을 출력한 결과입니다.

print('Heamap Shape : ', heatmap.shape)

plt.matshow(heatmap[0]);

 

최종적으로 이 히트맵과 위 본 이미지를 같이 그려서 이미지의 어떠한 부분이 해당 이미지가 무엇이다 라고 판단하기에 결정적인 역할을 했는지 파악해보겠습니다.

org_Image = cv2.imread(image_Path)

height , width , channel = org_Image.shape

# heatmap을 원본 이미지 사이즈에 맞춘다
heatmap_Resized = cv2.resize(heatmap[0]  , (width , height))

# 값을 0 ~ 255 사이 int 형으로 변경 // RGB 형식
heatmap_Resized = np.uint8(255 * heatmap_Resized)

# heatmap으로 변환
heatmap_Resized = cv2.applyColorMap(heatmap_Resized , cv2.COLORMAP_JET)

# 기존 이미지와 히트맵 이미지를 겹쳐서 그리기 위해
superimposed_img = np.zeros((height , width , channel))

for c in range(channel) :
    
    for i in range(height) :
        
        for j in range(width) :
            
            h = heatmap_Resized[i][j][c]
            v = org_Image[i][j][c]
            
            # 히트맵의 강도를 0.4로 설정
            superimposed_img[i][j][c] = h * 0.4 + v

cv2.imwrite('./heatmap_Result.jpg' , superimposed_img)

 

cv2,imwrite을 이용해서 jpg 파일을 만들어서 저장을 하였으며 이 파일을 다시 불러와서 출력해보겠습니다.

result = cv2.imread('./heatmap_Result.jpg')

plt.figure(figsize = (12 , 5))
plt.imshow(result);

 

이로써 이미지의 코 부분이 VGG16 모델이 해당 이미지를 그루넨달이라고 판단하는데 있어서 중요한 역할을 했다고 판단할 수 있습니다.

 

참고자료 : www.yes24.com/Product/Goods/65050162

 

케라스 창시자에게 배우는 딥러닝

단어 하나, 코드 한 줄 버릴 것이 없다!창시자의 철학까지 담은 딥러닝 입문서케라스 창시자이자 구글 딥러닝 연구원인 저자는 ‘인공 지능의 민주화’를 강조한다. 이 책 역시 많은 사람에게

www.yes24.com

참고자료 : stackoverflow.com/questions/58322147/how-to-generate-cnn-heatmaps-using-built-in-keras-in-tf2-0-tf-keras

 

How to generate CNN heatmaps using built-in Keras in TF2.0 (tf.keras)

I used to generate heatmaps for my Convolutional Neural Networks, based on the stand-alone Keras library on top of TensorFlow 1. That worked fine, however, after my switch to TF2.0 and built-in tf....

stackoverflow.com