saito

記者紹介

saitoxu(さいとぅ)

株式会社Carat 最高技術責任者 京都大学大学院修了。新卒でITベンチャーに就職し、BtoBのWebサービス開発に従事。キカガクOnlineでは機械学習を学ぶ初心者が、次のステップを踏み出すのに役立つ情報を発信していく予定。

data

今回はビジネスの現場で役に立つ重回帰分析について、プログラミングを通して学んでいきたいと思います。

重回帰分析とは?

重回帰分析は単回帰分析の発展形と言えるもので、単回帰分析は1つの目的変数を1つの説明変数で予測する手法でしたが、重回帰分析は1つの目的変数を複数の説明変数で予測する手法になります。

式で表すと以下になります。

$$
y=\boldsymbol{a}^{\mathrm{T}}\boldsymbol{x}+b
$$

単回帰分析については前回の記事で取り上げました。

問題設定

今回はより実践的なデータを用いて重回帰分析を行いたいと思います。
Housingデータセットというボストン近郊の住宅情報のデータを使って、いくつかの住宅情報から住宅価格を予測するモデルを作成したいと思います。
説明変数(入力変数)と目的変数(出力変数)は以下を使うこととします。

説明変数 説明
RM 1戸あたりの平均部屋数
AGE 1940年よりも前に建てられた家屋の割合
DIS ボストンの主な5つの雇用圏までの距離
目的変数 説明
MEDV 住宅価格の中央値(単位は1,000ドル)

※実際のHousingデータセットはより多くの情報がありますが、スペースの都合で絞りました

プログラム

では実際にプログラミングをしていきます。
最初はHousingデータセットの読み込みを行います。
データの読み込みにはpandasというデータ分析でよく使われるライブラリを使用します。
今回はデータの読み込みにとどめていますが、pandasを使えば、データ解析の集計からプロットまで非常に低コストで実装でき、機械学習には必須のライブラリといえます。

# coding: utf-8
# 必要なモジュールの読み込み
import pandas as pd
import numpy as np
from sklearn import linear_model

# HousingデータセットのダウンロードURL
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'
# Housingデータセットの読み込み
df = pd.read_csv(url, header=None, sep='\s+')
# データセットの列定義。14種類の情報が定義されている
# 詳しくは https://archive.ics.uci.edu/ml/datasets/Housing を参照
df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
              'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']

# 説明変数としてRM, AGE, DISを使用
X = df[['RM', 'AGE', 'DIS']].values
# 目的変数としてMEDVを使用
y = df['MEDV'].values

次に、モデルの学習を行います。

reg = linear_model.LinearRegression()
reg.fit(X, y) # モデルの学習

print(reg.coef_) # 係数の表示 ⇒ [ 8.4406291  -0.09941823 -0.48038409]
print(reg.intercept_) # 切片の表示 ⇒ -21.8727877925

モデルの学習を行って、方程式の各係数と切片を求めることができました。

$$
y= \boldsymbol{a}^{\mathrm{T}}\boldsymbol{x}+b
$$

で表すと、

$$
\boldsymbol{a}=\left(\begin{array}{c}
8.4406291 \newline
-0.09941823 \newline
-0.48038409 \end{array}
\right), \ b=-21.8727877925
$$

となります。
係数を見ると、1番目のRM(平均部屋数)は正、2番目のAGE(古い家屋の割合)と3番目のDIS(仕事場までの距離)は負の値になっています。
これは、平均部屋数が増加すると住宅価格は高くなるが、古い家屋の割合・仕事場までの距離が増加すると住宅価格は低下するということを表しており、直感的に考えても合っていますね。
では、最後に新しい住宅情報に対する価格の予測を行ってみます。

学習データと比べて、

  • 部屋数は多め
  • 古い家屋の割合は低め
  • 仕事場までの距離は短め

なデータを与えてみます。

print(reg.predict(np.array([8., 30., 3.]).reshape(1, -1))) # ⇒ [ 41.22854567]

すると、住宅価格は約41(千ドル)と、たしかに比較的高い結果が出てきました。
このように、モデルを作成し予測してみるところまでは簡単にできるというのが分かってもらえたのではと思います。

まとめ

今回は重回帰分析の基礎を解説し、そのアルゴリズムをプログラミングで実装しました。
重回帰分析はビジネスの現場で役に立つことが多いので、ぜひこの機会に使えるように練習してみてください。