Tensorflow 2.0 圖像分類:安裝、加載數據、構建和訓練模型

已發表: 2020-04-21

圖像分類是模式識別的一個範疇。 它根據相鄰像素之間的關係對圖像進行分類。 換句話說,它使用上下文信息來組織圖像,並且在不同的技術中非常流行。 這是深度學習中的一個突出主題,如果你正在學習它,你一定會喜歡這篇文章。

在這裡,我們將執行 TensorFlow 圖像分類。 我們將建立一個模型,對其進行訓練,然後提高其分類仙人掌圖像的準確性。 TensorFlow 是一個開源機器學習平台,是谷歌的產品。

讓我們開始吧。

目錄

安裝 TensorFlow 2.0

首先,您需要在 Google Colab 上安裝 TensorFlow。 你可以通過 pip 安裝它:

!pip install tensorflow-gpu==2.0.0-alpha0

然後我們將驗證安裝:

將張量流導入為 tf

打印(tf.__版本)

# 輸出:2.0.0-alpha0

資源

學習:最受初學者歡迎的 5 個 TensorFlow 項目

加載數據

驗證後,我們可以使用 tf.data.dataset 加載數據。 我們將構建一個分類器來確定圖像是否包含仙人掌。 仙人掌必須是柱狀的。我們可以為此使用仙人掌航空照片數據集。 現在,我們將加載文件路徑及其標籤:

train_csv = pd.read_csv('data/train.csv')

# 在 train/ 中添加相對路徑的圖像文件名

文件名 = ['train/' + fname for fname in train_csv['id'].tolist()]

標籤 = train_csv['has_cactus'].tolist()

train_filenames, val_filenames, train_labels, val_labels =

train_test_split(文件名,

標籤,

火車尺寸=0.9,

隨機狀態=42)

一旦我們有了標籤和文件名,我們就可以創建 tf.data.Dataset 對象了:

train_data = tf.data.Dataset.from_tensor_slices(

(tf.constant(train_filenames), tf.constant(train_labels))

)

val_data = tf.data.Dataset.from_tensor_slices(

(tf.constant(val_filenames), tf.constant(val_labels))

)

資源

目前,我們的數據集沒有實際的圖像。 它只有他們的文件名。 我們需要一個函數來加載必要的圖像並對其進行處理,以便我們可以對它們執行 TensorFlow 圖像識別。

IMAGE_SIZE = 96 # 用於 MobileNetV2 的最小圖像尺寸

BATCH_SIZE = 32

# 加載和預處理每個圖像的函數

def _parse_fn(文件名,標籤):

img = tf.io.read_file(img)

img = tf.image.decode_jpeg(img)

img = (tf.cast(img, tf.float32)/127.5) – 1

img = tf.image.resize(img, (IMAGE_SIZE, IMAGE_SIZE))

返回img,標籤

# 對 train 和 val 數據集中的每個示例運行 _parse_fn

# 同時洗牌和創建批次

train_data = (train_data.map(_parse_fn)

.shuffle(buffer_size=10000)

.batch(BATCH_SIZE)

)

val_data = (val_data.map(_parse_fn)

.shuffle(buffer_size=10000)

.batch(BATCH_SIZE)

)

資源

建立模型

在這個 TensorFlow 圖像分類示例中,我們將創建一個遷移學習模型。 這些模型速度很快,因為它們可以使用之前經過訓練的現有圖像分類模型。 他們只需要重新訓練其網絡的上層,因為該層指定了所需圖像的類別。

我們將使用 TensorFlow 2.0 的 Keras API 來創建我們的圖像分類模型。 出於遷移學習的目的,我們將使用 MobileNetV2 作為屬性檢測器。 它是 MobileNet 的第二個版本,是 Google 的產品。 它的重量比 Inception 和 ResNet 等其他模型更輕,並且可以在移動設備上運行。 我們將把這個模型加載到 ImageNet 上,凍結權重,添加一個分類頭並在沒有頂層的情況下運行它。

IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

# 使用 MobileNetV2 預訓練模型

base_model = tf.keras.applications.MobileNetV2(

input_shape=IMG_SHAPE,

包括頂部=假,

權重='imagenet'

)

# 凍結預訓練的模型權重

base_model.trainable = 假

# 可訓練的分類頭

maxpool_layer = tf.keras.layers.GlobalMaxPooling2D()

prediction_layer = tf.keras.layers.Dense(1, activation='sigmoid')

# 帶有特徵檢測器的層分類頭

模型 = tf.keras.Sequential([

base_model,

最大池層,

預測層

])

學習率 = 0.0001

# 編譯模型

model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),

損失='binary_crossentropy',

指標=['準確性']

)

資源

如果你要訓練 tf.keras 模型,你應該使用 TensorFlow 優化器。 tf.keras.optimizers 和 tf.train API 中的優化器一起在 TensorFlow 2.0 的 tf.keras.optimizers 中。 在 TensorFlow 2.0 中,許多 tf.keras 的原始優化器都得到了升級和替換,以獲得更好的性能。 它們使我們能夠在不影響性能的情況下應用優化器並節省時間。

閱讀:面向初學者的 TensorFlow 對象檢測教程

學習世界頂尖大學的數據科學課程獲得行政 PG 課程、高級證書課程或碩士課程,以加快您的職業生涯。

訓練模型

在我們建立模型之後,我們可以教它。 TensorFlow 2.0 的 tf.keras API 支持 tf.data API,因此您必須為此目的使用 tf.data.Dataset 對象。 它可以有效地執行訓練,我們不必對性能做出任何妥協。

num_epochs = 30

steps_per_epoch = round(num_train)//BATCH_SIZE

val_steps = 20

model.fit(train_data.repeat(),

epochs=num_epochs,

steps_per_epoch = steps_per_epoch,

驗證數據=val_data.repeat(),

驗證步驟=驗證步驟)

資源

在 30 個 epoch 之後,模型的準確率大幅提高,但我們可以進一步提高。 還記得,我們​​提到在遷移學習期間凍結權重嗎? 好吧,現在我們已經訓練了分類頭,我們可以解凍這些層並進一步微調我們的數據集:

# 解凍MobileNetV2的所有層

base_model.trainable = True

# 重新凍結層直到我們想要微調的層

對於 base_model.layers[:100] 中的層:

layer.trainable = 假

# 使用較低的學習率

lr_finetune = learning_rate / 10

# 重新編譯模型

model.compile(loss='binary_crossentropy',

優化器 = tf.keras.optimizers.Adam(lr=lr_finetune),

指標=['準確性'])

# 增加訓練時間以進行微調

Fine_tune_epochs = 30

total_epochs = num_epochs + fine_tune_epochs

# 微調模型

# 注意:將 initial_epoch 設置為在 epoch 30 之後開始訓練,因為我們

# 之前訓練了 30 個 epoch。

model.fit(train_data.repeat(),

steps_per_epoch = steps_per_epoch,

epochs=total_epochs,

initial_epoch = num_epochs,

驗證數據=val_data.repeat(),

驗證步驟=驗證步驟)

資源

30 epochs 後,模型的準確性進一步提高。 隨著更多的時期,我們看到模型的準確性有了更多的提高。 現在,我們有了一個合適的 TensorFlow 圖像識別模型,可以高精度地識別圖像中的柱狀仙人掌。

另請閱讀:面向初學者的 TensorFlow 項目創意

了解有關 TensorFlow 圖像分類的更多信息

TensorFlow 的功能強大的 API 及其功能使其成為任何程序員都可以使用的強大技術。 它的高級 API 也消除了它的一般複雜性,使其更易於使用。

您是否有興趣了解有關 TensorFlow、圖像分類和相關主題的更多信息? 然後我們建議您查看 IIIT-B 和 upGrad 的機器學習和人工智能 PG 文憑,專為工作專業人士設計,提供 450 多個小時的嚴格培訓、30 多個案例研究和作業、IIIT-B 校友身份、5 多個實用與頂級公司的實踐頂峰項目和工作協助。

機器學習課程 | 在線學習,IIIT 班加羅爾‎

機器學習和人工智能的 PG 文憑與升級和 IIIT 班加羅爾。
現在申請