Performing data augmentation with ImageDataGenerator

 

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale=1./255)
  • ์ €๋ฒˆ์—๋Š” ์ด๋ ‡๊ฒŒ๋งŒ ์‚ฌ์šฉ์„ ํ–ˆ์—ˆ์ง€๋งŒ, ์ด๋ฒˆ์—๋Š” data augmentation๊นŒ์ง€ ํ•ด๋ณด์ž!
  • ๋‹จ์ˆœํžˆ ์•„๋ž˜์™€ ๊ฐ™์ด ๋ช‡ ๊ฐ€์ง€ ๊ฐ’์„ ์ง€์ •ํ•ด์ฃผ๊ธฐ๋งŒ ํ•˜๋ฉด ๋งค์šฐ ์‰ฝ๊ฒŒ data augmentation์ด ๋๋‚œ๋‹ค.
  • ์ด ์ดํ›„์˜ ์ ˆ์ฐจ๋Š” ์ด์ „๊ณผ ๋‹ค๋ฅผ๋ฐ”๊ฐ€ ์—†๋‹ค.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,        # random rotation of up to ±40 degrees
    width_shift_range=0.2,    # horizontal shift (as a fraction of image width)
    height_shift_range=0.2,   # vertical shift (as a fraction of image height)
    shear_range=0.2,          # shear (counter-clockwise shear angle, in degrees)
    zoom_range=0.2,           # random zoom; 0.2 means a zoom factor in [0.8, 1.2]
    horizontal_flip=True,     # randomly flip images horizontally
    fill_mode='nearest'       # fill pixels left empty by a transform with the nearest pixel value
)
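To make `width_shift_range` and `fill_mode='nearest'` more concrete, here is a minimal NumPy sketch. It is not Keras's actual implementation, which applies an affine transform with interpolation. It shifts an image horizontally by a random amount of up to 20% of its width and fills the vacated columns by repeating the nearest edge column. The helper names `shift_columns_nearest` and `random_width_shift` are hypothetical, and shifts are rounded to whole pixels for simplicity.

```python
import numpy as np

def shift_columns_nearest(img, dx):
    """Shift an (H, W) or (H, W, C) image right by dx pixels (left if dx < 0).
    Columns that would come from outside the image are filled with the
    nearest edge column, mimicking fill_mode='nearest'."""
    width = img.shape[1]
    # For each output column x, read source column x - dx, clamped to the image.
    src_cols = np.clip(np.arange(width) - dx, 0, width - 1)
    return img[:, src_cols]

def random_width_shift(img, width_shift_range=0.2, rng=None):
    """Sample a shift in [-range * W, +range * W] pixels (rounded) and apply it."""
    rng = np.random.default_rng() if rng is None else rng
    max_shift = width_shift_range * img.shape[1]
    dx = int(round(rng.uniform(-max_shift, max_shift)))
    return shift_columns_nearest(img, dx)

# A 1x5 "image" whose pixel values are simply the column indices.
row = np.array([[0, 1, 2, 3, 4]])
print(shift_columns_nearest(row, 2))   # [[0 0 0 1 2]]
print(shift_columns_nearest(row, -1))  # [[1 2 3 4 4]]
```

The repeated `0`s and `4`s are the "nearest" fill: they copy the edge pixel instead of leaving black (zero) pixels behind.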
  • Augmentation adds diversity (randomness) to the training images, which can mitigate overfitting.
  • However, if the variations introduced during training (rotations, shifts, flips, and so on) do not resemble the variation actually present in the validation images, validation accuracy may not improve even with data augmentation.
  • (In other words, data augmentation does not always fix overfitting.)
  • Do not apply data augmentation to the validation data. Only rescale it, so that validation measures performance on unmodified images.
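The rule "rescale everything, augment only the training data" can be sketched in plain NumPy. This is a hypothetical `preprocess` function, not part of Keras, and it uses only a random horizontal flip as a stand-in for the full set of transforms. The point is that the validation path is deterministic: the same image always produces the same pixels.

```python
import numpy as np

def preprocess(img, training, rng):
    """Hypothetical preprocessing: always rescale, augment only when training."""
    x = img.astype(np.float32) / 255.0   # same effect as rescale=1./255
    if training:
        if rng.random() < 0.5:           # stand-in for horizontal_flip=True
            x = x[:, ::-1]               # flip along the width axis of (H, W, C)
        # other random transforms (shift, zoom, rotation) would be applied here
    return x

rng = np.random.default_rng(42)
img = np.random.default_rng(0).integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

# Validation path: no randomness, so repeated evaluations see identical inputs.
v1 = preprocess(img, training=False, rng=rng)
v2 = preprocess(img, training=False, rng=rng)
print(np.array_equal(v1, v2))  # True
```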

 

TRAINING_DIR = '/tmp/cats-v-dogs/training'
train_datagen = ImageDataGenerator(
    rescale=1./255, 
    width_shift_range=0.2,
    height_shift_range=0.2,
    rotation_range=0.2,  # unit is degrees: 0.2 gives almost no rotation (40 was used above)
    horizontal_flip=True
)

train_generator = train_datagen.flow_from_directory(
    TRAINING_DIR,
    target_size=(300, 300),
    batch_size=10,
    class_mode='binary'
)


VALIDATION_DIR = '/tmp/cats-v-dogs/testing'
validation_datagen = ImageDataGenerator(
    rescale=1./255
    # no data augmentation for validation data: rescale only
)

validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    target_size=(300, 300),
    batch_size=10,
    class_mode='binary'
)
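`rotation_range` is specified in degrees, and Keras samples the rotation angle uniformly from `[-rotation_range, rotation_range]`. The short calculation below shows how far a corner of a 300x300 image (the `target_size` used above) can move at most. It compares `rotation_range=0.2` from the second example with `rotation_range=40` from the first. The helper `max_corner_displacement` is a hypothetical name used only for this illustration. It computes the chord length traced by a corner rotating about the image center.

```python
import math

def max_corner_displacement(size, degrees):
    """Largest distance (in pixels) a corner of a size x size image moves
    when the image is rotated by `degrees` about its center."""
    radius = math.hypot(size / 2, size / 2)                  # center-to-corner distance
    return 2 * radius * math.sin(math.radians(degrees) / 2)  # chord length of the arc

for deg in (0.2, 40):
    print(deg, round(max_corner_displacement(300, deg), 2))
```

With 0.2 degrees, the corners move well under one pixel. With 40 degrees, they can move roughly 145 pixels. So if visible rotation is the goal, a value on the order of the first example is needed.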
