Transfer Learning

  • 기존 존재하는 모델 가져온 후, pre-trained된 weight를 로드한다.
  • 그리고 이 pre-trained된 weight를 학습되지 못하도록 막는다.
import os
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.applications.inception_v3 import InceptionV3

local_weights_file = "@#$%"
pre_trained_model = InceptionV3(input_shape=(150, 150, 3),
                                include_top=False,	# FC -> False -> get straight to the conv
                                weights=None)

pre_trained_model.load_weights(local_weights_file)

for layer in pre_trained_model.layers:
	layer.trainable=False # lock layers (not trainable)

 

 

 

  • 기존 모델의 레이어 중 하나 선택하고, 새롭게 추가한 레이어와 연결한다.
  • 그림으로 이해하면 아래와 같다.

last_layer = pre_trained_model.get_layer('mixed7')
last_output = last_layer.output

x = layers.Flatten()(last_output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dense(1, activation='sigmoid')(x)

from tensorflow.keras.optimizers import RMSprop

model = Model(pre_trained_model.input, x)
model.compile(	optimizer=RMSprop(lr=0.0001),
                loss='binary_crossentropy',
                metrics=['acc'])

 

 

ImageDataGenerator 이용하여 데이터 로드

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
    
    
train_generator = train_datagen.flow_from_directory(
    train_dir,
    batch_size=20,
    class_mode='binary',
    target_size=(150, 150)
)

 

 

모델 학습

history = model.fit_generator(
    train_generator,
    validation_data = validation_generator,
    steps_per_epoch=100,
    epochs=100,
    validation_steps=50,
    verbose=2
)

Dropout

from tensorflow.keras.optimizers import RMSprop

x = layers.Flatten()(last_input)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)	# add dropout (0.2 -> 20% dropout)
x = layers.Dense(1, activation='sigmoid')(x)

model = Model(pre_trained_model.input, x)
model.compile(
    optimizer=RMSprop(lr=0.0001),
    loss='binary_crossentropy',
    metrics=['acc']
)

dropout은 overfitting을 방지하는 효과가 있다.

 

dropout에 대한 설명은 아래에 글에 있습니다.

hyoeun-log.tistory.com/entry/WEEK3-regularization-%EC%A0%95%EA%B7%9C%ED%99%94

 

WEEK3 : regularization (정규화)

L2 regularization dropout data augmentation early stopping 정규화 (variance를 줄이기 위한 방법) 1. L2 regularization L2 regularization이란 무엇인가? loss function에 W 행렬의 norm을 더해주어 W 행렬의..

hyoeun-log.tistory.com

 

+ Recent posts