MLOps 부트캠프 by 한경+토스뱅크/Machine Learning

로지스틱 회귀 - 손실 함수

나니니 2024. 9. 16. 12:32

선형 회귀를 통해 하려고 하는 건 학습 데이터에 최대한 잘 맞는 가설 함수를 찾는 것이다. 

그러기 위해선 가설 함수를 평가하는 어떤 기준이 있어야 하는데, 그 기준이 되는게 손실 함수이다. 

로지스틱 회귀에서도 마찬가지다. 

데이터에 잘 맞는 가설 함수를 찾고, 손실 함수를 이용해 가설 함수를 평가한다. 

로지스틱 회귀 가설 함수

선형 회귀의 손실 함수는 평균 제곱 오차라는 개념을 기반으로 하는데, 데이터 하나하나의 오차를 구한 후에 그 오차들을 모두 제곱항 평균을 내는 작업을 한다. 

로지스틱 회귀의 손실 함수는 평균 제곱 오차를 사용하지 않고, 대신 '로그 손실', 영어로는 log loss라는 것을 사용한다. 좀 더 어려운 푷ㄴ으로는 cross entropy라고도 한다. 

 

로그 손실

로그 손실은 아래와 같은데, 이를 로그 손실이라고 부르는 이유는 손실의 정도를 로그 함수로 결졍하기 때문이다. 

로그 손실

위 수식에 따른 그래프는 아래처럼 그릴 수 있다. 

아웃풋이 1, 0일 경우

$h(x)$는 어떤 입력 변수에 대한 가설 함수의 예측값이고, $y$는 실제 값이다.

로그 손실 함수는 예측값이 실제 결과랑 얼마나 괴리가 있는지 알려 주는 역할을 한다. 

그런데 로지스틱 회귀는 분류 알고리즘이고, 분류가 두 가지라고 가정하면 가능한 목표 변수가 1과 0밖에 없다.

아웃풋이 1인 경우

왼쪽 그래프를 보면 $h(x)$가 1이면 100%의 확률로 아웃풋이 1일 거라고 예측하는 것이다.

실제 결과가 1이기 때문에, 이 가설 함수는 완벽하게 맞춘 것이므로 손실이 0이다.

만약 $h(x)$가 0.8 정도면 80%의 확률로 아웃풋이 1일 거라고 예측하는 건데, 실제 결과가 1이기 때문에 이 가설 함수는 꽤 잘했다고 평가할 수 있다. 손실이 좀 있으나 크지 않은 수준이다. 

또, 왼쪽으로 갈 수록 손실이 커지는데, 처음에는 완만하게 커지다가 급격하게 가파라진다. $h(x)$가 1에서 멀어질수록 잘 못하고 있는 것이므로 손실을 엄청 키운다고 볼 수 있다. 

 

아웃풋이 0인 경우

오른쪽의 그래프는 실제 아웃풋이 0인 경우인데, 아웃풋이 1인 경우와 반대로 되어 있다.

$h(x)$가 0이라는 건, 아웃풋이 1일 확률이 0%라고 예측하는 건데, 실제 결과가 0이기 때문에 완벽하게 예측했다고 할 수 있다. 그래서 손실이 0인 것이다. 

만약 $h(x)$가 0.2 정도면 20%의 확률로 아웃풋이 1일 거라고 예측하는 건데, 실제 결과가 0이기 때문에 나쁘지 않게 예측했다고 할 수 있다. 그래서 손실은 좀 있으나 크지 않다. 

해당 그래프는 오른쪽으로 갈 수록 손실이 커지는데 처음에는 완만하다가 급격하게 가파라진다. $h(x)$가 0에서 멀어질 수록 못하고 있는 것이기 때문에 손실을 엄청나게 키우는 것이다. 

 

이런 로그 손실을 이용해서 로지스틱 회귀의 손실 함수를 만드는데, 보통 로지스틱 회귀에서 로그 손실을 쓸 때 아래와 같은 형태로 쓴다. 

 

$logloss(h_θ(x),y)=−ylog(h_θ(x))−(1−y)log(1−h_θ(x))$

 

위의 식은 아래와 완전히 동일하다. 

 

이를 한 줄에 표현하면 아래처럼 한 줄 방식으로 나타낼 수 있고, 보통 아래의 식을 주로 사용한다. 

 

$logloss(h_θ(x),y)=−log(1−h_θ(x))$

로지스틱 회귀 손실 함수

로지스틱 회귀 손실 함수를 계산하기 위해서는 각 데이터에 대해 손실을 구한 후, 손실의 평균을 내야 한다.

우선 위의 이미지에서 시그마 오른쪽 부분부터 살펴보면, 로그 손실 함수가 있다. 

 

이는 로그 손실을 이용해 i번째 데이터에 대한 손실을 구하는 것이다.

 

시그마 밑에 $i=1$이라고 되어 있고, 위에는 $m$이라고 되어 있는데,

$m$은 학습 데이터 개수이고 $i$는 1을 대입하고, 2를 대입하고, 3을 대입하고~ 이런 식으로 $i$부터 $m$까지 순서대로 대입하여 계산하고 계산된 결과를 모두 더하는 것이다. 

다 더하는 것까지가 이 시그마의 역할이다. 

 

마지막에 $m$으로 나눠주어서 평균을 구하게 된다. 

 

요약하자면,

모든 학습 데이터에 대해 로그 손실을 계산하고, 평균을 내는 것이다. 그걸로 가설 함수를 평가하게 된다. 

 

참고로 이 손실 함수의 인풋은 세타인데, 왜 그럴까?

가설 함수는 세타 값들을 어떻게 설정하느냐에 따라 바뀐다. 그래서 어떤 세타 값들을 설정하느냐에 따라 학습 데이터의 손실이 달라지게 된다. 그러므로 손실 함수의 인풋은 세타가 된다. 

 

그리고 만약 이 부분에 로그 손실 함수를 완전히 대입하고 싶다면 아래처럼 할 수 있다.