MNIST 入門訓練

MNIST資料集

MNIST 一組公開的資料集,包含了一組手寫數字的影像檔以及對應數字值的標籤檔,資料集可以從Yann LeCun’s website下載或透過Google提供的Python程式下載. 資料集分成訓練跟測試(10,000筆)兩組資料,將資料拆分出兩組在機器學習是很重要的概念,可以確保模型的一般性,避免模型因為過度學習反而失去對未知資料的解釋能力.

訓練集有60,000筆數字,每個數字是一個28 * 28 像素的影像,每個像素的數值由0-1強度代表影像內容. mnist.train.images is a tensor (an n-dimensional array) with a shape of [60000, 784].

訓練集標籤檔有60,000筆資料,有一組0,1的陣列表示0-9的數字,例如3用([0,0,0,1,0,0,0,0,0,0]). mnist.train.labels is a [60000, 10] array of floats.

Softmax 回歸

影像強度向量值透過下列函式轉換成代表0-9數字的類別值

  • 是每個像素的權重
  • 是每個數字類別的偏差修正
  • j 是像素指標,用來計算像素強度權重的加總

用Softmax函式將,轉換證據值成預測機率

Softmax將每個數字類別的證據值取指數後正規化,轉換成0~1 之間的機率

透過tensorflow實作

To use TensorFlow, we need to import it.

import tensorflow as tf

We describe these interacting operations by manipulating symbolic variables. Let's create one:

x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

定義數字的預測模型

y = tf.nn.softmax(tf.matmul(x,W) + b)

訓練模型

模型需要定義怎樣的結果是好的,在機器學習通常定義怎樣是不好的結果(Cost/Loss),透過最佳化程序去減小不好結果.

Cross-entropy用來描述預測結果如何沒有效率的描述實際結果,想進一步暸解可以參考連結

  • (y) 預測結果的分佈
  • (y’) 是實際數字結果的分佈
y_ = tf.placeholder("float", [None,10])

實作cross entropy定義

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

透過剛剛的定義TensorFlow 能知道完整計算流程,會透過類神經背向傳遞的演算法,能有效率的決定要如何調整參數透過定義的降低你定義的成本值(cross entropy).

使用GradientDescentOptimizer最佳化方式降低模型成本值

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

Tensorflow有提供其他最佳化處理的演算法

模型訓練相關設定完成了,初始化變數

init = tf.initialize_all_variables()

開啟Session執行模型

sess = tf.Session()
sess.run(init)

執行1000次訓練

for i in range(1000):  
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

評量模型結果

判斷模型預測結果是否等於實際結果

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

透過cast將boolean 轉換成數值,計算平均值可以計算正確率. For example, [True, False, True, True] would become [1,0,1,1] which would become 0.75.

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

使用測試資料集計算正確率

print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

results matching ""

    No results matching ""