본문 바로가기
3. 딥러닝 (neuralnet 패키지)/2) 정규분포 학습 예제

[R 딥러닝 예제] 표준정규분포함수 만들기 #2. 학습결과 확인하기

by makhimh 2019. 12. 14.

#2. 학습결과 확인하기


우리는 지난시간에 아래와 같은 train 데이터와 test 데이터를 정의했습니다. 


library(neuralnet)


#data 생성

set.seed(1)

input=rnorm(1000)

output=dnorm(input)


#train 데이터를 배정합니다.

input_train=input[1:500]

output_train=output[1:500]


#test데이터를 배정합니다.

input_test=input[501:1000]

output_test=output[501:1000]


#train 데이터 input,output을 하나의 행렬로 묶기

my_data_train=cbind(input_train,output_train)


input_train과 output_train 데이터를 이용하여 신경망을 학습시켰고, my_NN이라는 변수에 그 결과를 넣었습니다. 


my_NN=neuralnet(output_train~input_train, data=my_data_train,

                hidden=c(3,3),

                threshold=0.01)


학습이 잘 되었는지 확인하기 위해서 test 데이터에 신경망을 적용하겠습니다. input_test 데이터를 신경망에 넣고, 예측값을 계산합니다. 계산된 예측값을 실제값인 output_test와 비교할 것입니다. 예측값의 계산에는 compute 함수를 사용합니다. 함수에는 학습한 신경망, 테스트할 데이터 순서로 입력합니다. 


output_predict=compute(my_NN,input_test)


아마 아래와 같은 error가 발생할 것입니다. 


> output_predict=compute(my_NN,input_test)

Error in if (ncol(newdata) == length(object$model.list$variables)) { : 

  argument is of length zero


input test 데이터가 열벡터가 아니라 행벡터이기 때문입니다. 열벡터로 바꾸고 compute 함수를 다시 적용합니다. 


input_test=as.matrix(input[501:1000],ncol=1)

output_predict=compute(my_NN,input_test)


예측된 결과값은 output_predict$net.result 에 저장되어 있습니다. 


이제 두개의 그래프를 그릴겁니다. 한 그래프는 test_input을 x로 하고  test_output을 y로 하는 그래프입니다. 실제 값에 해당되는 그래프구요. 또 다른 그래프는 test_input을 x로 하고 output_predict를 y로 하는 그래프입니다. 우리가 학습시킨 신경망을 이용하여 예측한 값에 해당되는 그래프입니다. 


#그래프로 그릴 데이터를 data.frame에 넣어줌.

results=data.frame(input_test,output_test,output_predict$net.result)

names(results)=c("input_test","output_test","output_predict")


#데이터를 크기 순으로 정렬, 안그러면 선그래프가 이상하게 그려짐(1강 참고)

results=results[order(input_test),]

names(results)=c("input_test","output_test","output_predict")


#그래프

plot(0,type="n",

     xlim=c(-3,3),ylim=c(0,0.5))

points(results$input_test,results$output_test,col='red',type='l')

points(results$input_test,results$output_predict,col='blue',type='l')

legend("topright",c("actual","predict"),fill=c('red','blue'))


그래프는 아래와 같습니다. 상당히 비슷합니다. 끝부분은 표본의 수가 적기 때문에 잘 예측하지 못하는 것 같습니다. 

댓글