본문 바로가기
1-1. 지도학습 (이산자료)/3) 서포트 벡터머신

R 서포트 벡터 머신 예시 (대장암 예시)

by makhimh 2023. 7. 27.

먼저 아래 패키지를 설치합시다. 

install.packages("survival")


survival 패키지에는 대장암 데이터인 colon 데이터가 들어 있습니다. 

패키지를 불러옵니다. 

library(survival)


데이터를 변수에 저장해줍니다. 

data=colon

 

이후 과정은 번호를 붙여 진행하겠습니다. 

 

1. 데이터 살펴보기

str함수를 이용하여 변수들을 살펴봅시다. 

 

> str(colon)
'data.frame':	1858 obs. of  16 variables:
 $ id      : num  1 1 2 2 3 3 4 4 5 5 ...
 $ study   : num  1 1 1 1 1 1 1 1 1 1 ...
 $ rx      : Factor w/ 3 levels "Obs","Lev","Lev+5FU": 3 3 3 3 1 1 3 3 1 1 ...
 $ sex     : num  1 1 1 1 0 0 0 0 1 1 ...
 $ age     : num  43 43 63 63 71 71 66 66 69 69 ...
 $ obstruct: num  0 0 0 0 0 0 1 1 0 0 ...
 $ perfor  : num  0 0 0 0 0 0 0 0 0 0 ...
 $ adhere  : num  0 0 0 0 1 1 0 0 0 0 ...
 $ nodes   : num  5 5 1 1 7 7 6 6 22 22 ...
 $ status  : num  1 1 0 0 1 1 1 1 1 1 ...
 $ differ  : num  2 2 2 2 2 2 2 2 2 2 ...
 $ extent  : num  3 3 3 3 2 2 3 3 3 3 ...
 $ surg    : num  0 0 0 0 0 0 1 1 1 1 ...
 $ node4   : num  1 1 0 0 1 1 1 1 1 1 ...
 $ time    : num  1521 968 3087 3087 963 ...
 $ etype   : num  2 1 2 1 2 1 2 1 2 1 ...

 

종속변수는 status 입니다. 1이면 대장암 재발 또는 사망입니다. 독리변수는 rx, sex, age, obstruct, perfor, adhere, nodes, differ, extent, surg 로 놓겠습니다. 대장암 여부를 예측하는 모델을 만드는 것이 목적입니다. 

 

2. 결측치 확인 및 제거

결측치를 확인하고 제거해줍니다 .

 

#결측치 위치 확인 함수
where.na.df=function(df){ 
  
  res=data.frame(row=NA,col=NA)
  
  
  for (i in 1:dim(df)[1]){
    for (j in 1:dim(df)[2]){
      
      if (is.na(df[i,j])){
        res=rbind(res,c(i,j))
      }
      
    }
  }
  
  res=res[-1,]
  rownames(res)=NULL
  
  return(res)
  
}

where.na.df(data) #결측치 위치확인
data = na.omit(data) #결측치 제거

 

인스턴스의 개수가 1858에서 1776으로 줄었습니다. 

 

> dim(data)
[1] 1776   16

 

3. 훈련데이터셋, 테스트데이터셋 나누기

데이터를 훈련데이터셋과 테스트데이터셋으로 나누겠습니다. 전체의 25%를 테스트 데이터셋으로 만들겠습니다. 코드는 아래와 같습니다. 

set.seed(999)
id_test=sample(1:nrow(data),nrow(data)*0.25)

data_test=data[id_test,]
data_train=data[setdiff(1:nrow(data),y=id_test),]

 

4. 모델 만들기

대장암을 예측하는 모델을 만들 것입니다. 서포트 벡터머신을 사용합니다. e1071 패키지입니다. svm 함수를 사용합니다. kernal, cost, gamma 등의 파라미터를 설정해주어야 합니다. 나머지는 디폴트 값으로 놓고 kernal 만 radial 로 설정합니다. 기초강의이므로 자세히 다루지는 않겠습니다. 

model_svm = svm(
  status ~ rx + sex + age + obstruct + perfor + adhere + nodes + differ +
    extent + surg,
  data = data_train,
  kernel = "radial",
  probability = TRUE
)


모델 생성 결과는 아래와 같습니다. 

> model_svm

Call:
svm(formula = status ~ rx + sex + age + obstruct + perfor + adhere + nodes + differ + extent + surg, 
    data = data_train, kernel = "radial", probability = TRUE)


Parameters:
   SVM-Type:  eps-regression 
 SVM-Kernel:  radial 
       cost:  1 
      gamma:  0.08333333 
    epsilon:  0.1 

Sigma:  0.7715177 


Number of Support Vectors:  1163

 

파라미터들을 변경하며 튜닝할 수 있습니다. 

 

6. 모델 평가

predict 함수를 이용하여 테스트셋의 결과변수를 구해줍니다. type 을 response 로 설정하면 확률이 반환됩니다. 해당 확률이 0.5보다 높은 경우를 1로, 낮은 경우를 0으로 반환한 뒤, test 데이터의 status 와 비교합니다. 이때 TRUE 의 비율이 정확도입니다. mean 으로 구한 평균이 TRUE 비율입니다. 잘 생각해보시면 이해가 되실겁니다. 

 

pred = predict(model_svm, data_test, type = 'response')
mean(ifelse(pred > 0.5, 1, 0) == data_test$status)

 

정확도는 66.4% 입니다. 

 

test 데이터에 대해 ROC 곡선도 그려봅시다. 

 

library(Epi)
ROC(test=pred,stat=data_test$status,plot="ROC")

 

코드 모아보기

library(survival)


#1.데이터
data = colon

#2.결측치확인 및 제거

#결측치 위치 확인 함수
where.na.df = function(df) {
  res = data.frame(row = NA, col = NA)
  
  
  for (i in 1:dim(df)[1]) {
    for (j in 1:dim(df)[2]) {
      if (is.na(df[i, j])) {
        res = rbind(res, c(i, j))
      }
    }
  }
  res = res[-1, ]
  rownames(res) = NULL
  
  return(res)
}

where.na.df(data) #결측치 위치확인
data = na.omit(data) #결측치 제거



#3.훈련데이터셋, 테스트데이터셋 나누기
set.seed(999)
id_test = sample(1:nrow(data), nrow(data) * 0.25)

data_test = data[id_test, ]
data_train = data[setdiff(1:nrow(data), y = id_test), ]

#4.모델 만들기
library(e1071)
model_svm = svm(
  status ~ rx + sex + age + obstruct + perfor + adhere + nodes + differ +
    extent + surg,
  data = data_train,
  kernel = "radial",
  probability = TRUE
)


#5.모델 평가하기
pred = predict(model_svm, data_test, type = 'response')
mean(ifelse(pred > 0.5, 1, 0) == data_test$status)

#roc curve
library(Epi)
ROC(test = pred,
    stat = data_test$status,
    plot = "ROC")

댓글