선형회귀 (Linear Regression)를 통해 생선의 무게를 구해보자!
지난 글에서 배운 선형회귀 (Linear Regression)를 이용하여 직접 생선의 길이로 무게를 예측하는 모델을 구현해보자.
[인공지능] 선형회귀에 대해 알아보자(이론) : itstory1592.tistory.com/4
먼저, 생선 데이터를 준비해보자!
import numpy as np
#농어의 길이
perch_length = np.array(
[8.4, 13.7, 15.0, 16.2, 17.4, 18.0, 18.7, 19.0, 19.6, 20.0,
21.0, 21.0, 21.0, 21.3, 22.0, 22.0, 22.0, 22.0, 22.0, 22.5,
22.5, 22.7, 23.0, 23.5, 24.0, 24.0, 24.6, 25.0, 25.6, 26.5,
27.3, 27.5, 27.5, 27.5, 28.0, 28.7, 30.0, 32.8, 34.5, 35.0,
36.5, 36.0, 37.0, 37.0, 39.0, 39.0, 39.0, 40.0, 40.0, 40.0,
40.0, 42.0, 43.0, 43.0, 43.5, 44.0])
#농어의 무게
perch_weight = np.array(
[5.9, 32.0, 40.0, 51.5, 70.0, 100.0, 78.0, 80.0, 85.0, 85.0,
110.0, 115.0, 125.0, 130.0, 120.0, 120.0, 130.0, 135.0, 110.0,
130.0, 150.0, 145.0, 150.0, 170.0, 225.0, 145.0, 188.0, 180.0,
197.0, 218.0, 300.0, 260.0, 265.0, 250.0, 250.0, 300.0, 320.0,
514.0, 556.0, 840.0, 685.0, 700.0, 700.0, 690.0, 900.0, 650.0,
820.0, 850.0, 900.0, 1015.0, 820.0, 1100.0, 1000.0, 1100.0,
1000.0, 1000.0])
첫번째 넘파이 리스트는 농어의 길이,
두번째 넘파이 리스트는 농어의 무게에 대한 데이터이다.
from sklearn.model_selection import train_test_split
# 훈련 세트와 테스트 세트로 나눕니다
train_input, test_input, train_target, test_target = train_test_split(
perch_length, perch_weight, random_state=42)
# 훈련 세트와 테스트 세트를 2차원 배열로 바꿉니다
train_input = train_input.reshape(-1, 1)
test_input = test_input.reshape(-1, 1)
그런 다음, 위의 데이터를 훈련을 위한 데이터와 테스트에 사용할 데이터로 나누어준다.
train_input은 농어의 길이이다. 그 이유는, 농어의 길이를 입력값으로 넣어주어야 하기 때문이다.
train_target은 농어의 무게인데, 이 때, target이란 단어가 생소하게 들릴 수 있다.
target(타겟)은 input을 넣었을 때 얻을 수 있는 값을 의미한다.
쉽게 말하자면, input값에 대한 정답이라고 표현할 수 있다.
여기에서는, input으로 길이를 입력하고 얻는 값이 무게이므로 target이 농어의 무게 데이터가 되는 것이다.
#산점도를 그리기위한 라이브러리
import matplotlib.pyplot as plt
plt.scatter(train_input, train_target)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
matplotlib은 파이썬에서 산점도를 그릴 수 있도록 도와주는 라이브러리이다.
as문을 통해 plt로 표현해준다.
plt.scatter에서 scatter() 함수의 첫번째 매개변수는 x값, 두번째 매개변수는 y값을 의미한다.
그럼 한번 plt.show() 함수를 통해 라이브러리가 그려주는 산점도를 확인해보자!
예상했던대로 농어의 길이가 길 수록 무게가 많이 나가는 선형적인 모습을 보이고 있다.
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(train_input, train_target)
이번에는 이 데이터를 훈련시킬 LinearRegression을 import(임포트) 해보자!
사이킷런(sklearn)에서는 이렇게 쉽게 모델을 학습시킬 수 있는 LinearRegression 클래스를 제공한다.
위 코드에서처럼 LinearRegression의 fit() 메소드에 input(농어의 길이 리스트)과 target(농어의 무게 리스트)를 매개변수로 입력해준다.
이후에도 계속 fit() 메소드가 등장할텐데, 사이킷런에서는 fit() 메소드를 훈련시킨다라고 이해하면 된다.
이렇게하면 자동으로 평균제곱오차(mse)를 최소화하는 방향으로 w(가중치)와 b(계수)를 구해준다
훈련을 마쳤다면, 한 번 예측을 해보자!
print(lr.predict([[50]]))
위 predict() 메소드를 통해 길이가 50인 농어의 무게를 예측해달라고 요청하였다.
[1241.83860323]
그럼 위와 같은 숫자가 출력될 것이다.
이는 곧, 길이가 50cm인 농어의 무게는 1241.83860323g 일 것이라고 예측하는 것이다.
가중치와 절편도 함께 확인해보자.
print(lr.coef_, lr.intercept_)
[39.01714496] -709.0186449535477
훈련을 마친 선형회귀 모델의 가중치와 절편은 각각 coef_와 intercept_변수에 담겨 있다.
이는 곧, 아래 식의 w와 b의 값이라고 할 수 있다.
실제로 값 (w=39.01714496, b=-709.0186449535477)을 대입하여 곱해보면,
생선의 길이 = 39.01714496 * 50 - 709.0186449535477 = 1241.83860323 이라는 값이 나온다.
따라서 우리는 y = 39.01714496 * x - 709.0186449535477 라는 방정식을 얻을 수 있다.
이 방정식을 그래프로 표현해보면 더 이해하기 쉬울 것이다.
# 훈련 세트의 산점도를 그립니다
plt.scatter(train_input, train_target)
# 15에서 50까지 1차 방정식 그래프를 그립니다
plt.plot([15, 50], [15*lr.coef_+lr.intercept_, 50*lr.coef_+lr.intercept_])
# 50cm 농어 데이터
plt.scatter(50, 1241.8, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
scatter() 메소드와 다르게 plot()은 방정식을 표현하는 메소드이다.
첫번째 매개변수로 X의 범위를 나타내는 리스트를 입력받고, 두번째 메소드는 Y의 범위를 나타내는 리스트를 입력받는다.
길이가 15부터 50까지인 생선의 무게를 나타내기 위한 식을 모델이 학습하여 얻어낸 가중치와 절편을 적용하여 표현하였다.
그리고, 아까 위에서 구한 길이가 50이며 구한 무게가1241.8g인 생선의 무게를 ▲로 표현하기 위해 scatter() 함수를 다시 사용하였다.
여기서 marker는 세모 모양을 표현하기 위한 기호이다.
위 코드를 실행하면 이러한 그래프가 나타날 것이다.
어느정도 선형관계가 보이고 있다.
마지막으로 훈련세트와 테스트세트의 정확도를 출력해보며 마무리하겠다.
print(lr.score(train_input, train_target))
print(lr.score(test_input, test_target))
0.9398463339976039
0.8247503123313558
훈련세트는 대략 94%의 정확도를 가지고 있지만, 테스트세트는 82%정도로 어느정도 차이를 보이고 있다.
해당 수치는 훈련세트에 비해 테스트 데이터에서는 82%정도만큼의 효율밖에 내지 못하나는 것을 의미한다.
이처럼, 훈련세트의 정확도가 테스트세트에 비해 높은 경우를 과대적합(Overfitting)되었다고 표현한다.
전체 소스코드는 아래 링크를 통해 참조할 수 있다.
전체 소스 코드 :
colab.research.google.com/drive/1fNofUKTraUovleX-iIcrWwpji9tB_Mvt#scrollTo=IvWaL4S2YXi7
(이해가 다소 힘들거나, 틀린 부분이 있다면 댓글 부탁드리겠습니다! 😊)
💖댓글과 공감은 큰 힘이 됩니다!💖