[Tensorflow]11.MNIST 91%(2)
in Language on tensorflow
11.MNIST 91%(2)
10.MNIST 91%(1) 에 이어서 포스팅을 하겠습니다.
이번 장에서는 MNIST예제를 코드로 작성해보겠습니다.
tensorflow는 기본적으로 몇 천개의 학습데이터를 제공하는데요~~ (지금까지와는 차원이 다릅니다)
그것을 이용하여 저희가 직접 손글씨 숫자를 인식하는 프로그램을 작성해보겠습니다.
시작해볼까요???
먼저, 맨 처음 tensorflow를 import~~!
import tensorflow as tf
그리고 학습 데이터를 가져오기 위해 다음을 입력해줍니다
from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)
그러면은 위와 같이 tensorflow가 데이터를 알아서 가져와줍니다.
그 다음 학습에 필요한 변수들 선언!
#X의 shape는 28 * 28 크기의 이미지를 한 줄로 폈다고 생각하세요! X=tf.placeholder(tf.float32,[None,784]) Y=tf.placeholder(tf.float32,[None,10]) #W의 shape는 X의 shape와 행렬곱해서 Y의 shape가 되어야하므로 [784,10] W=tf.Variable(tf.random_normal([784,10]),name='weight') b=tf.Variable(tf.random_normal([10]),name='bias')
hypothesis 정의
#softmax로 정의된 hypothesis logits=tf.matmul(X,W)+b hypothesis=tf.nn.softmax(logits)
cost 함수 정의
cost_i=tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=Y) cost=tf.reduce_mean(cost_i)
cost 최소화
#실제 값과 예측 값의 차이를 계속 줄여준다. optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.0001) train=optimizer.minimize(cost)
노드들을(tensor) 그래프 화
sess=tf.Session() #변수 초기화 sess.run(tf.global_variables_initializer())
위의 코드들을 보면 저희가 배운 코드랑 하나도 다른게 없죠????
보시는 것과 같이 저희는 데이터만 주어진다면 언제라도 예측 프로그램을 만들 수 있습니다!!
학습을 시킬 때 데이터가 몇 천개가 되기 때문에 batch 시스템을 이용하는데요~~
batch 시스템은 대용량 데이터를 차례대로 가져올 때 유용하게 쓰입니다~~~
코드를 천천히 살펴보시면 조금이라도 이해하실 수 있을겁니다~
training_epochs=15 #15번 거쳐서 학습을 시킨다 batch_size=100 #100개 씩 데이터를 가져온다 for epoch in range(training_epochs): avg_cost=0 #학습 1번 당 평균 cost값 #데이터의 총 갯수를 100으로 나누어준다. -> 몇 번에 거쳐 데이터를 가져올 지 계산 total_batch=int(mnist.train.num_examples/batch_size) for i in range(total_batch): #100개 씩 데이터를 가져와 학습시킨다. batch_xs,batch_ys=mnist.train.next_batch(batch_size) c,_=sess.run([cost,train],feed_dict={X:batch_xs,Y:batch_ys}) avg_cost+=c/total_batch print('Epoch:','%04d'%(epoch+1),'avg_cost','{:.9f}'.format(avg_cost))
학습 끝! 학습할 때마다 점점 cost 값이 줄어드시는 것을 볼 수 있습니다.
과연 학습이 잘 되었나 확인해볼까요?
tensorflow에서는 기본적으로 test데이터도 제공하는데 이를 이용해서 정확도를 측정해보겠습니다.
놀랍게도 정확도는 91%까지 나옵니다~~ 신기하죠???
지금까지 배운 것만으로도 이렇게 높은 정확도가 나오는 것을 보실 수 있습니다.
그래도 이것이 정말 제대로 예측하는 지 눈으로 보고 싶은데요~
matplotlib을 이용하면 됩니다~~~~ (이것에 대한 자세한 설명은 생략하겠습니다)
위의 보시는 손글씨 3이 입력데이터인데 실제 값과 저희가 예측한 값이 일치하는 것을 볼 수 있습니다
정말 신기하죠??? 전까지 해왔던 예제와는 데이터 양부터가 차원이 다릅니다~~~
하지만 정확도가 91%인 것은 그리 높은 것이 아닙니다
틀릴 확률이 9% 나 되니깐요
정확도를 더 높이기 위해서 저희는 이제 Deep learning에 대해 다음 장부터 배우도록 하겠습니다!
음 머신러닝에 대한 강의는 여기까지입니다!! ㅋㅋㅋㅋ 수고하셨습니다 ~~!
전체 코드