最近グラフ理論的なことを見てる中で久々にダイクストラ法に出会ったので、思い出しがてら書いた。隣接リスト方式のほうは初めて書いた。
-------------------------------------------------------
import numpy as np
inf=np.float('inf')
def dykstra_neighborlist(link, start):
end=np.zeros(np.max(link[:, 0:2])+1)+inf
end[start]=0
while np.sum(end)==inf:
candidate=np.zeros(end.shape[0])+inf
for i in range(candidate.shape[0]):
if end[i]!=inf:
neighbor_link=link[link[:, 0]==i]
for j in range(neighbor_link.shape[0]):
if end[neighbor_link[j, 1]]==inf:
length=end[i]+neighbor_link[j, 2]
if candidate[neighbor_link[j, 1]]>length:
candidate[neighbor_link[j, 1]]=length
candidate[end!=inf]=inf
end[np.argmin(candidate)]=np.min(candidate)
return end
def dykstra_linkmatrix(link, start):
end=np.zeros(link.shape[0])+inf
end[start]=0
while np.sum(end)==inf:
length=link+end.reshape(link.shape[0], 1)
length[:, end!=inf]=inf
end[np.argmin(length)%link.shape[0]]=np.min(length)
return end
--------------------------------------------------------
2017年6月25日日曜日
2017年5月20日土曜日
ボカロ曲レコメンドシステムを作るため嗜好傾向を調査中です。ご協力よろしくお願いします。
MikuMikuRecommendationというボカロ曲レコメンドシステムを作ろうとしており、グーグルフォームにて嗜好傾向を知るためのアンケートを開始しました。選曲はマジカルミライとMIKUEXPOからです。めちゃめちゃ多いですが"39"曲なら仕方ないといってご協力頂ける可能性にかけた。ご協力のほどよろしくお願いします。
グーグルフォームへのリンクはこちら
システムについての話も昨日のは貼っつけただけだったので、以下にまとめておきます。
なお昨日から少し変わってます。
ここおかしいだろ!って点あれば教えてください。
学習は下記のlearning関数です。
引数:iterは学習回数、learning_rateは学習率、dataはアンケート結果の行列で(i, j)が回答者(i)の曲(j)への評価。Rはdataに値が入ってるかどうかのブーリアン。featuresは(i, j)が特徴(i)を曲(j)がどれだけ持っているかのパラメタを意味する行列。thetaは(i, j)が回答者(i)の特徴(j)に対する好感度を意味する行列。lamは正規化項のパラメタ。
出力:学習後のthetaとfeatures、および各学習ごとの誤差関数の値が入ってるJ_list
predictはthetaとfeaturesの内積でつまりscoreの予測をする関数。
cost_funcは二乗誤差、delta_funcはthetaとfeatures各要素の偏微分を出力する関数。これが間違ってたら元も子もない。
---------------------------------------------------------------------------------
def learning(iter, learning_rate, data, features, R, theta, lam):
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
return theta, features, J_list
def predict(features, theta):
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam/np.sum(R)
return delta_theta, delta_features[1:, :]
---------------------------------------------------------------------------------
これがちゃんと動作するかテストするためにデータセットを作る。
引数:sは回答者の数、mは曲数、categoryは曲数を何個のカテゴリにするか、maxは評価の最高点、tightnessはカテゴリ間でのcorrelationの強さを調整するパラメタとしてとりあえず3としています。
出力:scoreはs*mの行列で、(i, j)は回答者(i)の曲(j)への評価を示す。
favo_maskは各回答者の好きなカテゴリ、music_maskは各曲の属するカテゴリ、correlation_rateは各カテゴリ間での好き嫌い関係を意味するものです。make_half関数で、(i, j)と(j, i)が同じ値になるようにし、対角成分は0にします。
その後はscoreの各成分に対してcorrelationにそって増減および乱数ちょこっと加えた後、1からmaxに標準化して出力。
---------------------------------------------------------------------------------
def make_bigscore(s, m, category, max, tightness=3):
score=np.zeros(s*m).reshape(s, m)
favo_mask=np.random.randint(0, category, s)
music_mask=np.random.randint(0, category, m)
correlation_rate=make_half(np.random.randn(category, category))
for i in range(s):
for j in range(m):
score[i, j]=score[i, j]+correlation_rate[favo_mask[i], music_mask[j]]*tightness+np.random.randn(1)
score=(score-np.min(score))*max/np.max(score-np.min(score))+1
for i in range(score.shape[0]):
for j in range(score.shape[1]):
score[i, j]=math.floor(score[i, j])
return score
def make_half(matrix):
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
if i>j:
matrix[j, i]=matrix[i, j]
elif i==j:
matrix[i, j]=0
return matrix
---------------------------------------------------------------------------------
その後、scoreの一部を0でマスクしたdataと、dataで数値が入っているところを真、0になっているところを偽とした行列Rを出力する関数。
引数:上のscoreと、どれだけを0に置き替えるかのcover_size
出力:dataとR
---------------------------------------------------------------------------------
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
これで試した結果が以下になります。
まずはJ_listの推移は以下のとおり。学習回数多すぎなのは気にしてはいけない。
あとはR=Falseの値に対する予測値をみてみる。そもそもdataの作り方が「ルールに基づいてたくさんの人の全曲に対する評価」を作ったうえでちょこちょこゼロでマスクしたものなので、そのマスク前の値scoreを比較対象とします。
下記はR=0に対する結果。左から予測値、R(=0に決まっている)、予測値-元のscore。なお評価は1-5の5段階評価。
[[ 2.91048635 0. -0.08951365]
[ 3.22132161 0. 0.22132161]
[ 2.87064662 0. 0.87064662]
[ 1.12139268 0. 0.12139268]
[ 2.41326846 0. 1.41326846]
[ 3.55421757 0. 0.55421757]
[ 3.97669655 0. -0.02330345]
[ 2.7501173 0. 0.7501173 ]
[ 3.6383573 0. -1.3616427 ]
[ 3.64936787 0. -0.35063213]
[ 2.0244445 0. 0.0244445 ]
[ 3.84411781 0. -0.15588219]
[ 3.01835681 0. 0.01835681]
[ 1.89188363 0. 0.89188363]
[ 2.75051635 0. -1.24948365]
[ 2.32287718 0. -1.67712282]
[ 2.05562347 0. 0.05562347]
[ 0.90773748 0. -1.09226252]
[ 1.80391987 0. -1.19608013]
[ 2.674269 0. 1.674269 ]
[ 3.00672833 0. 0.00672833]
[ 1.58966497 0. 0.58966497]
[ 0.94116798 0. -1.05883202]
[ 0.74880779 0. -1.25119221]
[ 1.09310716 0. -0.90689284]
[ 2.51634842 0. 0.51634842]
[ 2.9827374 0. -0.0172626 ]
[ 2.64218652 0. 1.64218652]
[ 2.1911775 0. 0.1911775 ]
[ 2.74421011 0. 0.74421011]
[ 2.21623794 0. 0.21623794]
[ 1.89532806 0. -0.10467194]
[ 2.37408885 0. 0.37408885]
[ 3.39205366 0. -0.60794634]
[ 1.61236102 0. -0.38763898]
[ 2.11890973 0. 1.11890973]
[ 3.17744208 0. 1.17744208]
[ 3.08776717 0. -0.91223283]
[ 3.61529334 0. -0.38470666]
[ 1.35736046 0. -0.64263954]
[ 2.06968813 0. 1.06968813]
[ 1.95003957 0. -0.04996043]
[ 1.56540138 0. -0.43459862]
[ 1.97601405 0. -0.02398595]
[ 2.92261337 0. -1.07738663]
[ 4.09495721 0. 1.09495721]
[ 3.19551809 0. -0.80448191]
[ 3.04745553 0. 0.04745553]
[ 0.79817745 0. -1.20182255]
[ 1.54938092 0. -0.45061908]
[ 1.36258335 0. -0.63741665]
[ 1.70375735 0. -0.29624265]
[ 2.93756708 0. -0.06243292]
[ 2.87062359 0. -0.12937641]
[ 2.38004315 0. 0.38004315]
[ 2.78595807 0. 0.78595807]
[ 3.1983942 0. -0.8016058 ]
[ 0.96060024 0. -1.03939976]
[ 3.52827113 0. -1.47172887]
[ 2.28749893 0. 1.28749893]
[ 2.35833173 0. 0.35833173]
[ 2.74344353 0. -0.25655647]
[ 0.74420262 0. -1.25579738]
[ 1.00412291 0. -0.99587709]
[ 3.2490832 0. 1.2490832 ]
[ 3.14802348 0. 0.14802348]
[ 3.13892027 0. 1.13892027]
[ 3.13161889 0. -0.86838111]
[ 3.0564188 0. 1.0564188 ]
[ 2.75968977 0. -0.24031023]
[ 3.3187245 0. 0.3187245 ]
[ 2.48239435 0. 0.48239435]
[ 3.16379722 0. 1.16379722]
[ 3.32180745 0. -0.67819255]
[ 1.2453015 0. 0.2453015 ]
[ 2.51397865 0. -1.48602135]
[ 3.46314521 0. 1.46314521]
[ 2.16083479 0. 0.16083479]
[ 1.7728266 0. -0.2271734 ]
[ 3.00928758 0. 0.00928758]
[ 3.27487873 0. -0.72512127]
[ 3.34124152 0. 1.34124152]
[ 1.23691874 0. 0.23691874]
[ 1.4981296 0. -1.5018704 ]
[ 2.37196425 0. 0.37196425]
[ 2.13404779 0. -0.86595221]
[ 3.61207011 0. -0.38792989]
[ 2.59456333 0. -0.40543667]
[ 2.62133109 0. -0.37866891]
[ 3.71468802 0. 1.71468802]
[ 1.78221455 0. -0.21778545]
[ 1.96780433 0. 0.96780433]
[ 2.43415152 0. 0.43415152]
[ 2.97109701 0. 0.97109701]
[ 2.78323318 0. 0.78323318]
[ 2.23812096 0. 0.23812096]
[ 2.71397993 0. -1.28602007]
[ 2.7065102 0. -1.2934898 ]
[ 3.62644254 0. -0.37355746]
[ 2.99239456 0. 0.99239456]
[ 3.54316629 0. 0.54316629]
[ 1.11549732 0. -0.88450268]
[ 1.56349647 0. -0.43650353]
[ 1.71123396 0. 0.71123396]
[ 3.50872977 0. 1.50872977]
[ 3.22804646 0. 2.22804646]
[ 2.01647371 0. -0.98352629]
[ 1.73846621 0. -0.26153379]
[ 3.98571845 0. 0.98571845]
[ 4.08772161 0. 0.08772161]
[ 2.77270122 0. -0.22729878]
[ 2.02429106 0. 0.02429106]
[ 1.6524353 0. -0.3475647 ]
[ 1.86554884 0. -0.13445116]
[ 4.13661877 0. 1.13661877]
[ 1.92517157 0. -1.07482843]
[ 1.32741021 0. 0.32741021]
[ 1.49096309 0. 0.49096309]
[ 4.26649726 0. 0.26649726]
[ 1.30860914 0. 0.30860914]
[ 3.65556042 0. 1.65556042]
[ 2.83933612 0. -0.16066388]
[ 3.95311969 0. -0.04688031]
[ 3.33503986 0. 1.33503986]
[ 1.86988111 0. 0.86988111]
[ 1.6803569 0. -0.3196431 ]
[ 2.84351168 0. 0.84351168]
[ 2.21883433 0. 0.21883433]
[ 2.92736767 0. -0.07263233]
[ 1.94865114 0. 0.94865114]
[ 2.84451654 0. -0.15548346]
[ 2.50099697 0. -0.49900303]
[ 1.84460977 0. -1.15539023]
[ 2.25103711 0. 0.25103711]
[ 3.2031821 0. -1.7968179 ]
[ 2.96381569 0. -0.03618431]
[ 2.9374384 0. 0.9374384 ]
[ 2.75750475 0. 0.75750475]
[ 2.14479067 0. 1.14479067]
[ 1.58906309 0. -1.41093691]
[ 3.95424966 0. 0.95424966]
[ 3.27139735 0. -0.72860265]
[ 2.15821289 0. 1.15821289]
[ 1.46000776 0. -0.53999224]
[ 2.77400627 0. 1.77400627]
[ 2.5346113 0. 0.5346113 ]
[ 1.88058697 0. -0.11941303]
[ 2.21373431 0. 0.21373431]
[ 1.01169338 0. -0.98830662]
[ 3.12535171 0. 1.12535171]
[ 2.8508623 0. -0.1491377 ]
[ 1.71767132 0. 0.71767132]
[ 3.61548509 0. 0.61548509]
[ 2.28357741 0. 1.28357741]
[ 2.18370363 0. 1.18370363]
[ 3.68006173 0. 0.68006173]
[ 3.15995915 0. 0.15995915]
[ 2.06859758 0. 0.06859758]
[ 2.5563126 0. 0.5563126 ]
[ 0.809939 0. -0.190061 ]
[ 2.21540911 0. 0.21540911]
[ 3.11465735 0. 0.11465735]
[ 2.58270229 0. 1.58270229]
[ 1.1271192 0. -0.8728808 ]
[ 1.93800718 0. 0.93800718]
[ 1.82362349 0. -0.17637651]
[ 1.41837389 0. -0.58162611]
[ 1.66102766 0. -0.33897234]
[ 3.2908487 0. 1.2908487 ]
[ 2.72840305 0. -1.27159695]
[ 2.64000898 0. 1.64000898]
[ 3.41794859 0. 1.41794859]
[ 1.72408758 0. 0.72408758]
[ 2.2381923 0. 0.2381923 ]
[ 2.75377796 0. -1.24622204]
[ 1.70660727 0. 0.70660727]
[ 1.6848583 0. -0.3151417 ]
[ 2.41792414 0. -0.58207586]
[ 3.10820851 0. -0.89179149]
[ 2.10151184 0. 0.10151184]
[ 1.36604167 0. -0.63395833]
[ 2.17376358 0. 0.17376358]
[ 1.70217391 0. 0.70217391]
[ 1.52573658 0. -0.47426342]
[ 3.21757008 0. 1.21757008]
[ 3.4661244 0. -1.5338756 ]
[ 2.28811177 0. 1.28811177]
[ 3.00192349 0. 1.00192349]
[ 2.18814338 0. 1.18814338]
[ 2.82484159 0. 0.82484159]]
いくらかデータが集まったら、まずはtheta同士の二乗誤差が一番近い人の好きな曲をレコメンドとして送信予定です。
(5/21:追記)
matplotlibの使い方覚えるためにR=0に限った誤差関数もプロットしてみたら…
ぬっ……と少しびびったのだけれど、考えてみれば(prediction-data)の[R==0]の二乗誤差でやっててそれただの評価の平均や!となり、こっそりscoreも与えて(prediction-score)の[R==0]で二乗誤差をプロットしてもらうと、
よかった……。
(prediction-data)の[R==0]の場合に固定される位置は、prediction-dataの[R==0]の平均の2乗を2で割った値なので確かに((1+5)/2)^2/2で4.5前後になりそうね。一件落着。
現時点でも少し回答を頂けております。ご協力ありがとうございます。気長に続けて行きます。
グーグルフォームへのリンクはこちら
システムについての話も昨日のは貼っつけただけだったので、以下にまとめておきます。
なお昨日から少し変わってます。
ここおかしいだろ!って点あれば教えてください。
学習は下記のlearning関数です。
引数:iterは学習回数、learning_rateは学習率、dataはアンケート結果の行列で(i, j)が回答者(i)の曲(j)への評価。Rはdataに値が入ってるかどうかのブーリアン。featuresは(i, j)が特徴(i)を曲(j)がどれだけ持っているかのパラメタを意味する行列。thetaは(i, j)が回答者(i)の特徴(j)に対する好感度を意味する行列。lamは正規化項のパラメタ。
出力:学習後のthetaとfeatures、および各学習ごとの誤差関数の値が入ってるJ_list
predictはthetaとfeaturesの内積でつまりscoreの予測をする関数。
cost_funcは二乗誤差、delta_funcはthetaとfeatures各要素の偏微分を出力する関数。これが間違ってたら元も子もない。
---------------------------------------------------------------------------------
def learning(iter, learning_rate, data, features, R, theta, lam):
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
return theta, features, J_list
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam/np.sum(R)
return delta_theta, delta_features[1:, :]
---------------------------------------------------------------------------------
これがちゃんと動作するかテストするためにデータセットを作る。
引数:sは回答者の数、mは曲数、categoryは曲数を何個のカテゴリにするか、maxは評価の最高点、tightnessはカテゴリ間でのcorrelationの強さを調整するパラメタとしてとりあえず3としています。
出力:scoreはs*mの行列で、(i, j)は回答者(i)の曲(j)への評価を示す。
favo_maskは各回答者の好きなカテゴリ、music_maskは各曲の属するカテゴリ、correlation_rateは各カテゴリ間での好き嫌い関係を意味するものです。make_half関数で、(i, j)と(j, i)が同じ値になるようにし、対角成分は0にします。
その後はscoreの各成分に対してcorrelationにそって増減および乱数ちょこっと加えた後、1からmaxに標準化して出力。
---------------------------------------------------------------------------------
def make_bigscore(s, m, category, max, tightness=3):
score=np.zeros(s*m).reshape(s, m)
favo_mask=np.random.randint(0, category, s)
music_mask=np.random.randint(0, category, m)
correlation_rate=make_half(np.random.randn(category, category))
for i in range(s):
for j in range(m):
score[i, j]=score[i, j]+correlation_rate[favo_mask[i], music_mask[j]]*tightness+np.random.randn(1)
score=(score-np.min(score))*max/np.max(score-np.min(score))+1
for i in range(score.shape[0]):
for j in range(score.shape[1]):
score[i, j]=math.floor(score[i, j])
return score
def make_half(matrix):
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
if i>j:
matrix[j, i]=matrix[i, j]
elif i==j:
matrix[i, j]=0
return matrix
---------------------------------------------------------------------------------
その後、scoreの一部を0でマスクしたdataと、dataで数値が入っているところを真、0になっているところを偽とした行列Rを出力する関数。
引数:上のscoreと、どれだけを0に置き替えるかのcover_size
出力:dataとR
---------------------------------------------------------------------------------
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
---------------------------------------------------------------------------------
これで試した結果が以下になります。
まずはJ_listの推移は以下のとおり。学習回数多すぎなのは気にしてはいけない。
あとはR=Falseの値に対する予測値をみてみる。そもそもdataの作り方が「ルールに基づいてたくさんの人の全曲に対する評価」を作ったうえでちょこちょこゼロでマスクしたものなので、そのマスク前の値scoreを比較対象とします。
下記はR=0に対する結果。左から予測値、R(=0に決まっている)、予測値-元のscore。なお評価は1-5の5段階評価。
[[ 2.91048635 0. -0.08951365]
[ 3.22132161 0. 0.22132161]
[ 2.87064662 0. 0.87064662]
[ 1.12139268 0. 0.12139268]
[ 2.41326846 0. 1.41326846]
[ 3.55421757 0. 0.55421757]
[ 3.97669655 0. -0.02330345]
[ 2.7501173 0. 0.7501173 ]
[ 3.6383573 0. -1.3616427 ]
[ 3.64936787 0. -0.35063213]
[ 2.0244445 0. 0.0244445 ]
[ 3.84411781 0. -0.15588219]
[ 3.01835681 0. 0.01835681]
[ 1.89188363 0. 0.89188363]
[ 2.75051635 0. -1.24948365]
[ 2.32287718 0. -1.67712282]
[ 2.05562347 0. 0.05562347]
[ 0.90773748 0. -1.09226252]
[ 1.80391987 0. -1.19608013]
[ 2.674269 0. 1.674269 ]
[ 3.00672833 0. 0.00672833]
[ 1.58966497 0. 0.58966497]
[ 0.94116798 0. -1.05883202]
[ 0.74880779 0. -1.25119221]
[ 1.09310716 0. -0.90689284]
[ 2.51634842 0. 0.51634842]
[ 2.9827374 0. -0.0172626 ]
[ 2.64218652 0. 1.64218652]
[ 2.1911775 0. 0.1911775 ]
[ 2.74421011 0. 0.74421011]
[ 2.21623794 0. 0.21623794]
[ 1.89532806 0. -0.10467194]
[ 2.37408885 0. 0.37408885]
[ 3.39205366 0. -0.60794634]
[ 1.61236102 0. -0.38763898]
[ 2.11890973 0. 1.11890973]
[ 3.17744208 0. 1.17744208]
[ 3.08776717 0. -0.91223283]
[ 3.61529334 0. -0.38470666]
[ 1.35736046 0. -0.64263954]
[ 2.06968813 0. 1.06968813]
[ 1.95003957 0. -0.04996043]
[ 1.56540138 0. -0.43459862]
[ 1.97601405 0. -0.02398595]
[ 2.92261337 0. -1.07738663]
[ 4.09495721 0. 1.09495721]
[ 3.19551809 0. -0.80448191]
[ 3.04745553 0. 0.04745553]
[ 0.79817745 0. -1.20182255]
[ 1.54938092 0. -0.45061908]
[ 1.36258335 0. -0.63741665]
[ 1.70375735 0. -0.29624265]
[ 2.93756708 0. -0.06243292]
[ 2.87062359 0. -0.12937641]
[ 2.38004315 0. 0.38004315]
[ 2.78595807 0. 0.78595807]
[ 3.1983942 0. -0.8016058 ]
[ 0.96060024 0. -1.03939976]
[ 3.52827113 0. -1.47172887]
[ 2.28749893 0. 1.28749893]
[ 2.35833173 0. 0.35833173]
[ 2.74344353 0. -0.25655647]
[ 0.74420262 0. -1.25579738]
[ 1.00412291 0. -0.99587709]
[ 3.2490832 0. 1.2490832 ]
[ 3.14802348 0. 0.14802348]
[ 3.13892027 0. 1.13892027]
[ 3.13161889 0. -0.86838111]
[ 3.0564188 0. 1.0564188 ]
[ 2.75968977 0. -0.24031023]
[ 3.3187245 0. 0.3187245 ]
[ 2.48239435 0. 0.48239435]
[ 3.16379722 0. 1.16379722]
[ 3.32180745 0. -0.67819255]
[ 1.2453015 0. 0.2453015 ]
[ 2.51397865 0. -1.48602135]
[ 3.46314521 0. 1.46314521]
[ 2.16083479 0. 0.16083479]
[ 1.7728266 0. -0.2271734 ]
[ 3.00928758 0. 0.00928758]
[ 3.27487873 0. -0.72512127]
[ 3.34124152 0. 1.34124152]
[ 1.23691874 0. 0.23691874]
[ 1.4981296 0. -1.5018704 ]
[ 2.37196425 0. 0.37196425]
[ 2.13404779 0. -0.86595221]
[ 3.61207011 0. -0.38792989]
[ 2.59456333 0. -0.40543667]
[ 2.62133109 0. -0.37866891]
[ 3.71468802 0. 1.71468802]
[ 1.78221455 0. -0.21778545]
[ 1.96780433 0. 0.96780433]
[ 2.43415152 0. 0.43415152]
[ 2.97109701 0. 0.97109701]
[ 2.78323318 0. 0.78323318]
[ 2.23812096 0. 0.23812096]
[ 2.71397993 0. -1.28602007]
[ 2.7065102 0. -1.2934898 ]
[ 3.62644254 0. -0.37355746]
[ 2.99239456 0. 0.99239456]
[ 3.54316629 0. 0.54316629]
[ 1.11549732 0. -0.88450268]
[ 1.56349647 0. -0.43650353]
[ 1.71123396 0. 0.71123396]
[ 3.50872977 0. 1.50872977]
[ 3.22804646 0. 2.22804646]
[ 2.01647371 0. -0.98352629]
[ 1.73846621 0. -0.26153379]
[ 3.98571845 0. 0.98571845]
[ 4.08772161 0. 0.08772161]
[ 2.77270122 0. -0.22729878]
[ 2.02429106 0. 0.02429106]
[ 1.6524353 0. -0.3475647 ]
[ 1.86554884 0. -0.13445116]
[ 4.13661877 0. 1.13661877]
[ 1.92517157 0. -1.07482843]
[ 1.32741021 0. 0.32741021]
[ 1.49096309 0. 0.49096309]
[ 4.26649726 0. 0.26649726]
[ 1.30860914 0. 0.30860914]
[ 3.65556042 0. 1.65556042]
[ 2.83933612 0. -0.16066388]
[ 3.95311969 0. -0.04688031]
[ 3.33503986 0. 1.33503986]
[ 1.86988111 0. 0.86988111]
[ 1.6803569 0. -0.3196431 ]
[ 2.84351168 0. 0.84351168]
[ 2.21883433 0. 0.21883433]
[ 2.92736767 0. -0.07263233]
[ 1.94865114 0. 0.94865114]
[ 2.84451654 0. -0.15548346]
[ 2.50099697 0. -0.49900303]
[ 1.84460977 0. -1.15539023]
[ 2.25103711 0. 0.25103711]
[ 3.2031821 0. -1.7968179 ]
[ 2.96381569 0. -0.03618431]
[ 2.9374384 0. 0.9374384 ]
[ 2.75750475 0. 0.75750475]
[ 2.14479067 0. 1.14479067]
[ 1.58906309 0. -1.41093691]
[ 3.95424966 0. 0.95424966]
[ 3.27139735 0. -0.72860265]
[ 2.15821289 0. 1.15821289]
[ 1.46000776 0. -0.53999224]
[ 2.77400627 0. 1.77400627]
[ 2.5346113 0. 0.5346113 ]
[ 1.88058697 0. -0.11941303]
[ 2.21373431 0. 0.21373431]
[ 1.01169338 0. -0.98830662]
[ 3.12535171 0. 1.12535171]
[ 2.8508623 0. -0.1491377 ]
[ 1.71767132 0. 0.71767132]
[ 3.61548509 0. 0.61548509]
[ 2.28357741 0. 1.28357741]
[ 2.18370363 0. 1.18370363]
[ 3.68006173 0. 0.68006173]
[ 3.15995915 0. 0.15995915]
[ 2.06859758 0. 0.06859758]
[ 2.5563126 0. 0.5563126 ]
[ 0.809939 0. -0.190061 ]
[ 2.21540911 0. 0.21540911]
[ 3.11465735 0. 0.11465735]
[ 2.58270229 0. 1.58270229]
[ 1.1271192 0. -0.8728808 ]
[ 1.93800718 0. 0.93800718]
[ 1.82362349 0. -0.17637651]
[ 1.41837389 0. -0.58162611]
[ 1.66102766 0. -0.33897234]
[ 3.2908487 0. 1.2908487 ]
[ 2.72840305 0. -1.27159695]
[ 2.64000898 0. 1.64000898]
[ 3.41794859 0. 1.41794859]
[ 1.72408758 0. 0.72408758]
[ 2.2381923 0. 0.2381923 ]
[ 2.75377796 0. -1.24622204]
[ 1.70660727 0. 0.70660727]
[ 1.6848583 0. -0.3151417 ]
[ 2.41792414 0. -0.58207586]
[ 3.10820851 0. -0.89179149]
[ 2.10151184 0. 0.10151184]
[ 1.36604167 0. -0.63395833]
[ 2.17376358 0. 0.17376358]
[ 1.70217391 0. 0.70217391]
[ 1.52573658 0. -0.47426342]
[ 3.21757008 0. 1.21757008]
[ 3.4661244 0. -1.5338756 ]
[ 2.28811177 0. 1.28811177]
[ 3.00192349 0. 1.00192349]
[ 2.18814338 0. 1.18814338]
[ 2.82484159 0. 0.82484159]]
いくらかデータが集まったら、まずはtheta同士の二乗誤差が一番近い人の好きな曲をレコメンドとして送信予定です。
(5/21:追記)
matplotlibの使い方覚えるためにR=0に限った誤差関数もプロットしてみたら…
ぬっ……と少しびびったのだけれど、考えてみれば(prediction-data)の[R==0]の二乗誤差でやっててそれただの評価の平均や!となり、こっそりscoreも与えて(prediction-score)の[R==0]で二乗誤差をプロットしてもらうと、
よかった……。
(prediction-data)の[R==0]の場合に固定される位置は、prediction-dataの[R==0]の平均の2乗を2で割った値なので確かに((1+5)/2)^2/2で4.5前後になりそうね。一件落着。
現時点でも少し回答を頂けております。ご協力ありがとうございます。気長に続けて行きます。
2017年5月19日金曜日
昨日のMikuMikuRecommendationに手を加えた
まずMikuMikuRecommendationという名前をつけました(笑)
というのはさておき本題に入りまして、何を変えたかというとデータセットの作り方なのでこれも本質は何も変わらない。
昨日のやつだとデータセットの作り方がそもそもダサすぎるし大きいデータセット作れないので、乱数にcorrelation的サムシングを加えて1-10に標準化したものをデータセットとして使ってみました。いろいろ問題ありそうだけどパッと見いい感じなデータセットが出来たのでよしとします。
correlation的サムシングをどれくらい強くするかをtightnessという変数で決められるようにしたのですが、暫定的に3にしてあります。なぜかというと2にしたらうまくいかなかったから。。。
それとインラインで作業するときにコピペできるといいなと思って、学習部分を関数にまとめた。
現状下記です。
【全体像】
いちいちコメント書くのは面倒なのでドーンとコピペ
main.py
# coding:utf-8
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
sample_size=200
music_size=10
feature_size=10
missing_rate=0.1
cover_size=np.int(sample_size*music_size*missing_rate)
category=4
min=1
max=10
score=make_bigscore(sample_size, music_size, category, max)
data, R=make_data(score, cover_size)
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
learning_rate=0.1
lam=0.01
iter=5000
theta, features, J_list=learning(iter, learning_rate, data, features, R, theta, lam)
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
print(result[R.reshape(-1)==0])
functions_cf.py
import numpy as np
import math
def predict(features, theta):
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam/np.sum(R)
return delta_theta, delta_features[1:, :]
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
def learning(iter, learning_rate, data, features, R, theta, lam):
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
return theta, features, J_list
def make_bigscore(s, m, category, max, tightness=3):
score=np.zeros(s*m).reshape(s, m)+5
favo_mask=np.random.randint(0, category, s)
music_mask=np.random.randint(0, category, m)
correlation_rate=make_half(np.random.randn(category, category))
for i in range(s):
for j in range(m):
score[i, j]=score[i, j]+correlation_rate[favo_mask[i], music_mask[j]]*tightness+np.random.randn(1)
score=(score-np.min(score))*max/np.max(score)+1
for i in range(score.shape[0]):
for j in range(score.shape[1]):
score[i, j]=math.floor(score[i, j])
return score
def make_half(matrix):
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
if i>j:
matrix[j, i]=matrix[i, j]
elif i==j:
matrix[i, j]=0
return matrix
【結果観測】
まずはR=0となっている「抜け値」を埋めること。
上のコードをそのまま実行するとこんな感じの結果になる。
resultは下記。左から順にprediction, data(=0), sub(=prediction-score)
[[ 6.79221021e+00 0.00000000e+00 1.79221021e+00]
[ 7.53311390e+00 0.00000000e+00 -4.66886104e-01]
[ 7.09569830e+00 0.00000000e+00 9.56982987e-02]
[ 4.08265903e+00 0.00000000e+00 -1.91734097e+00]
[ 5.11038322e+00 0.00000000e+00 1.10383221e-01]
[ 6.83060034e+00 0.00000000e+00 -3.16939966e+00]
[ 4.94816109e+00 0.00000000e+00 9.48161091e-01]
[ 8.94097605e+00 0.00000000e+00 -5.90239520e-02]
[ 5.55085831e+00 0.00000000e+00 1.55085831e+00]
[ 6.49319455e+00 0.00000000e+00 4.93194551e-01]
[ 3.51685950e+00 0.00000000e+00 -1.48314050e+00]
[ 3.30233833e+00 0.00000000e+00 -6.97661671e-01]
[ 8.43240061e+00 0.00000000e+00 4.32400607e-01]
[ 8.41920554e+00 0.00000000e+00 1.41920554e+00]
[ 8.48562862e+00 0.00000000e+00 4.85628622e-01]
[ 7.91982280e+00 0.00000000e+00 -1.08017720e+00]
[ 7.69216022e+00 0.00000000e+00 -1.30783978e+00]
[ 7.55579536e+00 0.00000000e+00 5.55795356e-01]
[ 4.36991203e+00 0.00000000e+00 -2.63008797e+00]
[ 8.37558195e+00 0.00000000e+00 -6.24418049e-01]
[ 7.60526264e+00 0.00000000e+00 -3.94737356e-01]
[ 7.27028265e+00 0.00000000e+00 2.70282648e-01]
[ 3.85643006e+00 0.00000000e+00 -1.43569940e-01]
[ 7.31909333e+00 0.00000000e+00 -1.68090667e+00]
[ 6.51557893e+00 0.00000000e+00 -1.48442107e+00]
[ 4.57743090e+00 0.00000000e+00 5.77430899e-01]
[ 4.38028851e+00 0.00000000e+00 3.80288514e-01]
[ 7.15614184e+00 0.00000000e+00 -8.43858160e-01]
[ 4.86797258e+00 0.00000000e+00 8.67972580e-01]
[ 4.77330759e+00 0.00000000e+00 -2.26692406e-01]
[ 5.24170637e+00 0.00000000e+00 -7.58293631e-01]
[ 5.85907532e+00 0.00000000e+00 1.85907532e+00]
[ 7.01088219e+00 0.00000000e+00 -9.89117812e-01]
[ 5.12001067e+00 0.00000000e+00 1.20010669e-01]
[ 4.12287125e+00 0.00000000e+00 -1.87712875e+00]
[ 4.27710625e+00 0.00000000e+00 2.77106249e-01]
[ 5.09901728e+00 0.00000000e+00 -9.00982720e-01]
[ 9.11113983e+00 0.00000000e+00 1.11139833e-01]
[ 4.42793609e+00 0.00000000e+00 4.27936088e-01]
[ 8.48836557e+00 0.00000000e+00 4.88365570e-01]
[ 7.51564493e+00 0.00000000e+00 5.15644932e-01]
[ 8.46280310e+00 0.00000000e+00 -5.37196902e-01]
[ 5.43836821e+00 0.00000000e+00 2.43836821e+00]
[ 4.65313035e+00 0.00000000e+00 -3.46869649e-01]
[ 7.30430048e+00 0.00000000e+00 3.04300479e-01]
[ 4.96641442e+00 0.00000000e+00 -1.03358558e+00]
[ 4.52025568e+00 0.00000000e+00 5.20255679e-01]
[ 4.92889597e+00 0.00000000e+00 -7.11040257e-02]
[ 5.58917669e+00 0.00000000e+00 -1.41082331e+00]
[ 4.91782026e+00 0.00000000e+00 9.17820259e-01]
[ 8.68151818e+00 0.00000000e+00 6.81518179e-01]
[ 6.85840196e+00 0.00000000e+00 -1.41598037e-01]
[ 3.50970449e+00 0.00000000e+00 -2.49029551e+00]
[ 3.54968073e+00 0.00000000e+00 -4.50319275e-01]
[ 8.32713967e+00 0.00000000e+00 -6.72860330e-01]
[ 2.88564867e+00 0.00000000e+00 -1.14351328e-01]
[ 4.13329265e+00 0.00000000e+00 2.13329265e+00]
[ 6.02285979e+00 0.00000000e+00 -1.97714021e+00]
[ 5.48115372e+00 0.00000000e+00 -5.18846284e-01]
[ 6.22298632e+00 0.00000000e+00 2.22986317e-01]
[ 5.90004546e+00 0.00000000e+00 -9.99545425e-02]
[ 4.40149963e+00 0.00000000e+00 4.01499634e-01]
[ 7.00155610e+00 0.00000000e+00 1.00155610e+00]
[ 6.12984227e+00 0.00000000e+00 2.12984227e+00]
[ 4.43785983e+00 0.00000000e+00 4.37859834e-01]
[ 6.33879076e+00 0.00000000e+00 3.38790763e-01]
[ 4.33571888e+00 0.00000000e+00 -1.66428112e+00]
[ 2.48997197e+00 0.00000000e+00 1.48997197e+00]
[ 7.08099180e+00 0.00000000e+00 -9.19008203e-01]
[ 5.10492050e+00 0.00000000e+00 -1.89507950e+00]
[ 5.02968334e+00 0.00000000e+00 1.02968334e+00]
[ 7.13977640e+00 0.00000000e+00 1.39776399e-01]
[ 7.55119719e+00 0.00000000e+00 -4.48802809e-01]
[ 8.53262155e+00 0.00000000e+00 -4.67378448e-01]
[ 7.51423486e+00 0.00000000e+00 5.14234860e-01]
[ 2.99267833e+00 0.00000000e+00 -1.00732167e+00]
[ 3.32044090e+00 0.00000000e+00 -6.79559105e-01]
[ 8.25607045e+00 0.00000000e+00 2.25607045e+00]
[ 8.27929557e+00 0.00000000e+00 2.27929557e+00]
[ 5.79284155e+00 0.00000000e+00 -2.20715845e+00]
[ 5.63182840e+00 0.00000000e+00 6.31828398e-01]
[ 4.27518546e+00 0.00000000e+00 -1.72481454e+00]
[ 2.82775364e+00 0.00000000e+00 -1.17224636e+00]
[ 3.27150448e+00 0.00000000e+00 -7.28495524e-01]
[ 4.86288111e+00 0.00000000e+00 8.62881106e-01]
[ 4.15287914e+00 0.00000000e+00 1.15287914e+00]
[ 5.48403925e+00 0.00000000e+00 4.84039252e-01]
[ 7.43184309e+00 0.00000000e+00 4.31843086e-01]
[ 5.42698644e+00 0.00000000e+00 -1.57301356e+00]
[ 7.65437641e+00 0.00000000e+00 -1.34562359e+00]
[ 7.55554759e+00 0.00000000e+00 5.55547595e-01]
[ 2.67930823e+00 0.00000000e+00 -1.32069177e+00]
[ 4.00815548e+00 0.00000000e+00 8.15547591e-03]
[ 6.37997063e+00 0.00000000e+00 3.79970634e-01]
[ 6.96754636e+00 0.00000000e+00 -3.03245364e+00]
[ 7.78552434e+00 0.00000000e+00 1.78552434e+00]
[ 7.13891523e+00 0.00000000e+00 -8.61084772e-01]
[ 5.15818977e+00 0.00000000e+00 -3.84181023e+00]
[ 7.47774417e+00 0.00000000e+00 -5.22255829e-01]
[ 6.35284199e+00 0.00000000e+00 3.52841987e-01]
[ 6.76815734e+00 0.00000000e+00 7.68157336e-01]
[ 5.94075467e+00 0.00000000e+00 -1.05924533e+00]
[ 7.12827211e+00 0.00000000e+00 2.12827211e+00]
[ 7.07915735e+00 0.00000000e+00 2.07915735e+00]
[ 7.12728376e+00 0.00000000e+00 3.12728376e+00]
[ 4.78467425e+00 0.00000000e+00 -1.21532575e+00]
[ 4.11857739e+00 0.00000000e+00 1.11857739e+00]
[ 4.94899594e+00 0.00000000e+00 -2.05100406e+00]
[ 3.62200612e+00 0.00000000e+00 -3.77993876e-01]
[ 4.13980901e+00 0.00000000e+00 1.13980901e+00]
[ 4.79519851e+00 0.00000000e+00 7.95198510e-01]
[ 9.58568796e+00 0.00000000e+00 -4.14312037e-01]
[ 2.40409199e+00 0.00000000e+00 1.40409199e+00]
[ 7.23112510e+00 0.00000000e+00 2.31125099e-01]
[ 8.03322342e+00 0.00000000e+00 2.03322342e+00]
[ 4.59414694e+00 0.00000000e+00 5.94146940e-01]
[ 2.43238578e+00 0.00000000e+00 4.32385779e-01]
[ 5.30828543e+00 0.00000000e+00 1.30828543e+00]
[ 7.66747838e+00 0.00000000e+00 2.66747838e+00]
[ 8.80101019e+00 0.00000000e+00 -1.19898981e+00]
[ 8.32916719e+00 0.00000000e+00 3.29167190e-01]
[ 8.59429933e+00 0.00000000e+00 5.94299329e-01]
[ 3.75772349e+00 0.00000000e+00 7.57723494e-01]
[ 4.88486628e+00 0.00000000e+00 8.84866279e-01]
[ 5.53191466e+00 0.00000000e+00 -4.68085340e-01]
[ 2.88859452e+00 0.00000000e+00 -1.11405479e-01]
[ 3.47813361e+00 0.00000000e+00 -1.52186639e+00]
[ 8.17178509e+00 0.00000000e+00 1.71785091e-01]
[ 5.26310717e+00 0.00000000e+00 2.63107170e-01]
[ 7.66027706e+00 0.00000000e+00 -3.39722939e-01]
[ 7.06787062e+00 0.00000000e+00 1.06787062e+00]
[ 6.46431101e+00 0.00000000e+00 1.46431101e+00]
[ 3.89064273e+00 0.00000000e+00 8.90642734e-01]
[ 4.90327565e+00 0.00000000e+00 9.03275650e-01]
[ 1.03351212e+01 0.00000000e+00 1.33512121e+00]
[ 8.87801850e-01 0.00000000e+00 -1.11219815e+00]
[ 7.12930310e-01 0.00000000e+00 -2.28706969e+00]
[ 7.24062819e+00 0.00000000e+00 1.24062819e+00]
[ 6.85799559e+00 0.00000000e+00 -1.14200441e+00]
[ 1.85540284e+00 0.00000000e+00 -1.14459716e+00]
[ 2.32319731e+00 0.00000000e+00 -6.76802692e-01]
[ 2.48839620e+00 0.00000000e+00 4.88396199e-01]
[ 5.65267581e+00 0.00000000e+00 6.52675811e-01]
[ 6.75195833e+00 0.00000000e+00 -2.48041665e-01]
[ 5.82486383e+00 0.00000000e+00 -1.17513617e+00]
[ 6.15176564e+00 0.00000000e+00 1.15176564e+00]
[ 7.77587521e+00 0.00000000e+00 -2.24124785e-01]
[ 7.89983313e+00 0.00000000e+00 -1.00166870e-01]
[ 7.95130816e+00 0.00000000e+00 -1.04869184e+00]
[ 5.09700416e+00 0.00000000e+00 1.09700416e+00]
[ 4.58884212e+00 0.00000000e+00 5.88842121e-01]
[ 3.52519701e+00 0.00000000e+00 5.25197009e-01]
[ 1.09648602e+01 0.00000000e+00 1.96486018e+00]
[ 1.13242890e+01 0.00000000e+00 3.32428897e+00]
[ 5.88963853e+00 0.00000000e+00 8.89638530e-01]
[ 9.09428528e+00 0.00000000e+00 -9.05714716e-01]
[ 8.22378866e+00 0.00000000e+00 1.22378866e+00]
[ 5.81835879e+00 0.00000000e+00 -1.81641213e-01]
[ 5.62012171e+00 0.00000000e+00 -3.79878291e-01]
[ 4.45124746e+00 0.00000000e+00 4.51247460e-01]
[ 2.90545701e+00 0.00000000e+00 -1.09454299e+00]
[ 6.10119442e+00 0.00000000e+00 -1.89880558e+00]
[ 6.03296647e+00 0.00000000e+00 -9.67033530e-01]
[ 6.67224047e+00 0.00000000e+00 -3.27759529e-01]
[ 2.81798517e+00 0.00000000e+00 -1.18201483e+00]
[ 2.27543560e+00 0.00000000e+00 -1.72456440e+00]
[ 7.22559961e+00 0.00000000e+00 1.22559961e+00]
[ 8.13112548e+00 0.00000000e+00 -8.68874517e-01]
[ 7.67090557e+00 0.00000000e+00 1.67090557e+00]
[ 3.03188884e+00 0.00000000e+00 3.18888430e-02]
[ 7.84613021e+00 0.00000000e+00 8.46130212e-01]
[ 3.04744195e+00 0.00000000e+00 -1.95255805e+00]
[ 5.16166895e+00 0.00000000e+00 -1.83833105e+00]
[ 6.67414529e+00 0.00000000e+00 -3.25854707e-01]
[ 2.75351809e+00 0.00000000e+00 -1.24648191e+00]
[ 5.74327034e+00 0.00000000e+00 -1.25672966e+00]
[ 5.63928383e+00 0.00000000e+00 -3.60716172e-01]
[ 7.64732695e+00 0.00000000e+00 -1.35267305e+00]
[ 5.78103852e+00 0.00000000e+00 7.81038524e-01]
[ 7.45234275e+00 0.00000000e+00 -1.54765725e+00]
[ 4.26299512e+00 0.00000000e+00 -7.37004877e-01]
[ 6.52988398e+00 0.00000000e+00 -4.70116016e-01]
[ 8.96044500e+00 0.00000000e+00 9.60445003e-01]
[ 7.35391597e+00 0.00000000e+00 -1.64608403e+00]
[ 6.70063979e+00 0.00000000e+00 -1.29936021e+00]]
たまにでかいのあるけどおおむねいい感じではないですか(非常に主観的)
実用ではmissing_rateをもっと大きくした状態での推測をしなければいけないだろうと思われるので、今後も色々試してみます。
あとこれと同時に作られるthetaとfeaturesについても活用できそうなのでまたやってみます。thetaの傾向が似てる人が好きな曲とかさ。
というのはさておき本題に入りまして、何を変えたかというとデータセットの作り方なのでこれも本質は何も変わらない。
昨日のやつだとデータセットの作り方がそもそもダサすぎるし大きいデータセット作れないので、乱数にcorrelation的サムシングを加えて1-10に標準化したものをデータセットとして使ってみました。いろいろ問題ありそうだけどパッと見いい感じなデータセットが出来たのでよしとします。
correlation的サムシングをどれくらい強くするかをtightnessという変数で決められるようにしたのですが、暫定的に3にしてあります。なぜかというと2にしたらうまくいかなかったから。。。
それとインラインで作業するときにコピペできるといいなと思って、学習部分を関数にまとめた。
現状下記です。
【全体像】
いちいちコメント書くのは面倒なのでドーンとコピペ
main.py
# coding:utf-8
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
sample_size=200
music_size=10
feature_size=10
missing_rate=0.1
cover_size=np.int(sample_size*music_size*missing_rate)
category=4
min=1
max=10
score=make_bigscore(sample_size, music_size, category, max)
data, R=make_data(score, cover_size)
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
learning_rate=0.1
lam=0.01
iter=5000
theta, features, J_list=learning(iter, learning_rate, data, features, R, theta, lam)
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
print(result[R.reshape(-1)==0])
functions_cf.py
import numpy as np
import math
def predict(features, theta):
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam/np.sum(R)
return delta_theta, delta_features[1:, :]
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
def learning(iter, learning_rate, data, features, R, theta, lam):
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
return theta, features, J_list
def make_bigscore(s, m, category, max, tightness=3):
score=np.zeros(s*m).reshape(s, m)+5
favo_mask=np.random.randint(0, category, s)
music_mask=np.random.randint(0, category, m)
correlation_rate=make_half(np.random.randn(category, category))
for i in range(s):
for j in range(m):
score[i, j]=score[i, j]+correlation_rate[favo_mask[i], music_mask[j]]*tightness+np.random.randn(1)
score=(score-np.min(score))*max/np.max(score)+1
for i in range(score.shape[0]):
for j in range(score.shape[1]):
score[i, j]=math.floor(score[i, j])
return score
def make_half(matrix):
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
if i>j:
matrix[j, i]=matrix[i, j]
elif i==j:
matrix[i, j]=0
return matrix
【結果観測】
まずはR=0となっている「抜け値」を埋めること。
上のコードをそのまま実行するとこんな感じの結果になる。
resultは下記。左から順にprediction, data(=0), sub(=prediction-score)
[[ 6.79221021e+00 0.00000000e+00 1.79221021e+00]
[ 7.53311390e+00 0.00000000e+00 -4.66886104e-01]
[ 7.09569830e+00 0.00000000e+00 9.56982987e-02]
[ 4.08265903e+00 0.00000000e+00 -1.91734097e+00]
[ 5.11038322e+00 0.00000000e+00 1.10383221e-01]
[ 6.83060034e+00 0.00000000e+00 -3.16939966e+00]
[ 4.94816109e+00 0.00000000e+00 9.48161091e-01]
[ 8.94097605e+00 0.00000000e+00 -5.90239520e-02]
[ 5.55085831e+00 0.00000000e+00 1.55085831e+00]
[ 6.49319455e+00 0.00000000e+00 4.93194551e-01]
[ 3.51685950e+00 0.00000000e+00 -1.48314050e+00]
[ 3.30233833e+00 0.00000000e+00 -6.97661671e-01]
[ 8.43240061e+00 0.00000000e+00 4.32400607e-01]
[ 8.41920554e+00 0.00000000e+00 1.41920554e+00]
[ 8.48562862e+00 0.00000000e+00 4.85628622e-01]
[ 7.91982280e+00 0.00000000e+00 -1.08017720e+00]
[ 7.69216022e+00 0.00000000e+00 -1.30783978e+00]
[ 7.55579536e+00 0.00000000e+00 5.55795356e-01]
[ 4.36991203e+00 0.00000000e+00 -2.63008797e+00]
[ 8.37558195e+00 0.00000000e+00 -6.24418049e-01]
[ 7.60526264e+00 0.00000000e+00 -3.94737356e-01]
[ 7.27028265e+00 0.00000000e+00 2.70282648e-01]
[ 3.85643006e+00 0.00000000e+00 -1.43569940e-01]
[ 7.31909333e+00 0.00000000e+00 -1.68090667e+00]
[ 6.51557893e+00 0.00000000e+00 -1.48442107e+00]
[ 4.57743090e+00 0.00000000e+00 5.77430899e-01]
[ 4.38028851e+00 0.00000000e+00 3.80288514e-01]
[ 7.15614184e+00 0.00000000e+00 -8.43858160e-01]
[ 4.86797258e+00 0.00000000e+00 8.67972580e-01]
[ 4.77330759e+00 0.00000000e+00 -2.26692406e-01]
[ 5.24170637e+00 0.00000000e+00 -7.58293631e-01]
[ 5.85907532e+00 0.00000000e+00 1.85907532e+00]
[ 7.01088219e+00 0.00000000e+00 -9.89117812e-01]
[ 5.12001067e+00 0.00000000e+00 1.20010669e-01]
[ 4.12287125e+00 0.00000000e+00 -1.87712875e+00]
[ 4.27710625e+00 0.00000000e+00 2.77106249e-01]
[ 5.09901728e+00 0.00000000e+00 -9.00982720e-01]
[ 9.11113983e+00 0.00000000e+00 1.11139833e-01]
[ 4.42793609e+00 0.00000000e+00 4.27936088e-01]
[ 8.48836557e+00 0.00000000e+00 4.88365570e-01]
[ 7.51564493e+00 0.00000000e+00 5.15644932e-01]
[ 8.46280310e+00 0.00000000e+00 -5.37196902e-01]
[ 5.43836821e+00 0.00000000e+00 2.43836821e+00]
[ 4.65313035e+00 0.00000000e+00 -3.46869649e-01]
[ 7.30430048e+00 0.00000000e+00 3.04300479e-01]
[ 4.96641442e+00 0.00000000e+00 -1.03358558e+00]
[ 4.52025568e+00 0.00000000e+00 5.20255679e-01]
[ 4.92889597e+00 0.00000000e+00 -7.11040257e-02]
[ 5.58917669e+00 0.00000000e+00 -1.41082331e+00]
[ 4.91782026e+00 0.00000000e+00 9.17820259e-01]
[ 8.68151818e+00 0.00000000e+00 6.81518179e-01]
[ 6.85840196e+00 0.00000000e+00 -1.41598037e-01]
[ 3.50970449e+00 0.00000000e+00 -2.49029551e+00]
[ 3.54968073e+00 0.00000000e+00 -4.50319275e-01]
[ 8.32713967e+00 0.00000000e+00 -6.72860330e-01]
[ 2.88564867e+00 0.00000000e+00 -1.14351328e-01]
[ 4.13329265e+00 0.00000000e+00 2.13329265e+00]
[ 6.02285979e+00 0.00000000e+00 -1.97714021e+00]
[ 5.48115372e+00 0.00000000e+00 -5.18846284e-01]
[ 6.22298632e+00 0.00000000e+00 2.22986317e-01]
[ 5.90004546e+00 0.00000000e+00 -9.99545425e-02]
[ 4.40149963e+00 0.00000000e+00 4.01499634e-01]
[ 7.00155610e+00 0.00000000e+00 1.00155610e+00]
[ 6.12984227e+00 0.00000000e+00 2.12984227e+00]
[ 4.43785983e+00 0.00000000e+00 4.37859834e-01]
[ 6.33879076e+00 0.00000000e+00 3.38790763e-01]
[ 4.33571888e+00 0.00000000e+00 -1.66428112e+00]
[ 2.48997197e+00 0.00000000e+00 1.48997197e+00]
[ 7.08099180e+00 0.00000000e+00 -9.19008203e-01]
[ 5.10492050e+00 0.00000000e+00 -1.89507950e+00]
[ 5.02968334e+00 0.00000000e+00 1.02968334e+00]
[ 7.13977640e+00 0.00000000e+00 1.39776399e-01]
[ 7.55119719e+00 0.00000000e+00 -4.48802809e-01]
[ 8.53262155e+00 0.00000000e+00 -4.67378448e-01]
[ 7.51423486e+00 0.00000000e+00 5.14234860e-01]
[ 2.99267833e+00 0.00000000e+00 -1.00732167e+00]
[ 3.32044090e+00 0.00000000e+00 -6.79559105e-01]
[ 8.25607045e+00 0.00000000e+00 2.25607045e+00]
[ 8.27929557e+00 0.00000000e+00 2.27929557e+00]
[ 5.79284155e+00 0.00000000e+00 -2.20715845e+00]
[ 5.63182840e+00 0.00000000e+00 6.31828398e-01]
[ 4.27518546e+00 0.00000000e+00 -1.72481454e+00]
[ 2.82775364e+00 0.00000000e+00 -1.17224636e+00]
[ 3.27150448e+00 0.00000000e+00 -7.28495524e-01]
[ 4.86288111e+00 0.00000000e+00 8.62881106e-01]
[ 4.15287914e+00 0.00000000e+00 1.15287914e+00]
[ 5.48403925e+00 0.00000000e+00 4.84039252e-01]
[ 7.43184309e+00 0.00000000e+00 4.31843086e-01]
[ 5.42698644e+00 0.00000000e+00 -1.57301356e+00]
[ 7.65437641e+00 0.00000000e+00 -1.34562359e+00]
[ 7.55554759e+00 0.00000000e+00 5.55547595e-01]
[ 2.67930823e+00 0.00000000e+00 -1.32069177e+00]
[ 4.00815548e+00 0.00000000e+00 8.15547591e-03]
[ 6.37997063e+00 0.00000000e+00 3.79970634e-01]
[ 6.96754636e+00 0.00000000e+00 -3.03245364e+00]
[ 7.78552434e+00 0.00000000e+00 1.78552434e+00]
[ 7.13891523e+00 0.00000000e+00 -8.61084772e-01]
[ 5.15818977e+00 0.00000000e+00 -3.84181023e+00]
[ 7.47774417e+00 0.00000000e+00 -5.22255829e-01]
[ 6.35284199e+00 0.00000000e+00 3.52841987e-01]
[ 6.76815734e+00 0.00000000e+00 7.68157336e-01]
[ 5.94075467e+00 0.00000000e+00 -1.05924533e+00]
[ 7.12827211e+00 0.00000000e+00 2.12827211e+00]
[ 7.07915735e+00 0.00000000e+00 2.07915735e+00]
[ 7.12728376e+00 0.00000000e+00 3.12728376e+00]
[ 4.78467425e+00 0.00000000e+00 -1.21532575e+00]
[ 4.11857739e+00 0.00000000e+00 1.11857739e+00]
[ 4.94899594e+00 0.00000000e+00 -2.05100406e+00]
[ 3.62200612e+00 0.00000000e+00 -3.77993876e-01]
[ 4.13980901e+00 0.00000000e+00 1.13980901e+00]
[ 4.79519851e+00 0.00000000e+00 7.95198510e-01]
[ 9.58568796e+00 0.00000000e+00 -4.14312037e-01]
[ 2.40409199e+00 0.00000000e+00 1.40409199e+00]
[ 7.23112510e+00 0.00000000e+00 2.31125099e-01]
[ 8.03322342e+00 0.00000000e+00 2.03322342e+00]
[ 4.59414694e+00 0.00000000e+00 5.94146940e-01]
[ 2.43238578e+00 0.00000000e+00 4.32385779e-01]
[ 5.30828543e+00 0.00000000e+00 1.30828543e+00]
[ 7.66747838e+00 0.00000000e+00 2.66747838e+00]
[ 8.80101019e+00 0.00000000e+00 -1.19898981e+00]
[ 8.32916719e+00 0.00000000e+00 3.29167190e-01]
[ 8.59429933e+00 0.00000000e+00 5.94299329e-01]
[ 3.75772349e+00 0.00000000e+00 7.57723494e-01]
[ 4.88486628e+00 0.00000000e+00 8.84866279e-01]
[ 5.53191466e+00 0.00000000e+00 -4.68085340e-01]
[ 2.88859452e+00 0.00000000e+00 -1.11405479e-01]
[ 3.47813361e+00 0.00000000e+00 -1.52186639e+00]
[ 8.17178509e+00 0.00000000e+00 1.71785091e-01]
[ 5.26310717e+00 0.00000000e+00 2.63107170e-01]
[ 7.66027706e+00 0.00000000e+00 -3.39722939e-01]
[ 7.06787062e+00 0.00000000e+00 1.06787062e+00]
[ 6.46431101e+00 0.00000000e+00 1.46431101e+00]
[ 3.89064273e+00 0.00000000e+00 8.90642734e-01]
[ 4.90327565e+00 0.00000000e+00 9.03275650e-01]
[ 1.03351212e+01 0.00000000e+00 1.33512121e+00]
[ 8.87801850e-01 0.00000000e+00 -1.11219815e+00]
[ 7.12930310e-01 0.00000000e+00 -2.28706969e+00]
[ 7.24062819e+00 0.00000000e+00 1.24062819e+00]
[ 6.85799559e+00 0.00000000e+00 -1.14200441e+00]
[ 1.85540284e+00 0.00000000e+00 -1.14459716e+00]
[ 2.32319731e+00 0.00000000e+00 -6.76802692e-01]
[ 2.48839620e+00 0.00000000e+00 4.88396199e-01]
[ 5.65267581e+00 0.00000000e+00 6.52675811e-01]
[ 6.75195833e+00 0.00000000e+00 -2.48041665e-01]
[ 5.82486383e+00 0.00000000e+00 -1.17513617e+00]
[ 6.15176564e+00 0.00000000e+00 1.15176564e+00]
[ 7.77587521e+00 0.00000000e+00 -2.24124785e-01]
[ 7.89983313e+00 0.00000000e+00 -1.00166870e-01]
[ 7.95130816e+00 0.00000000e+00 -1.04869184e+00]
[ 5.09700416e+00 0.00000000e+00 1.09700416e+00]
[ 4.58884212e+00 0.00000000e+00 5.88842121e-01]
[ 3.52519701e+00 0.00000000e+00 5.25197009e-01]
[ 1.09648602e+01 0.00000000e+00 1.96486018e+00]
[ 1.13242890e+01 0.00000000e+00 3.32428897e+00]
[ 5.88963853e+00 0.00000000e+00 8.89638530e-01]
[ 9.09428528e+00 0.00000000e+00 -9.05714716e-01]
[ 8.22378866e+00 0.00000000e+00 1.22378866e+00]
[ 5.81835879e+00 0.00000000e+00 -1.81641213e-01]
[ 5.62012171e+00 0.00000000e+00 -3.79878291e-01]
[ 4.45124746e+00 0.00000000e+00 4.51247460e-01]
[ 2.90545701e+00 0.00000000e+00 -1.09454299e+00]
[ 6.10119442e+00 0.00000000e+00 -1.89880558e+00]
[ 6.03296647e+00 0.00000000e+00 -9.67033530e-01]
[ 6.67224047e+00 0.00000000e+00 -3.27759529e-01]
[ 2.81798517e+00 0.00000000e+00 -1.18201483e+00]
[ 2.27543560e+00 0.00000000e+00 -1.72456440e+00]
[ 7.22559961e+00 0.00000000e+00 1.22559961e+00]
[ 8.13112548e+00 0.00000000e+00 -8.68874517e-01]
[ 7.67090557e+00 0.00000000e+00 1.67090557e+00]
[ 3.03188884e+00 0.00000000e+00 3.18888430e-02]
[ 7.84613021e+00 0.00000000e+00 8.46130212e-01]
[ 3.04744195e+00 0.00000000e+00 -1.95255805e+00]
[ 5.16166895e+00 0.00000000e+00 -1.83833105e+00]
[ 6.67414529e+00 0.00000000e+00 -3.25854707e-01]
[ 2.75351809e+00 0.00000000e+00 -1.24648191e+00]
[ 5.74327034e+00 0.00000000e+00 -1.25672966e+00]
[ 5.63928383e+00 0.00000000e+00 -3.60716172e-01]
[ 7.64732695e+00 0.00000000e+00 -1.35267305e+00]
[ 5.78103852e+00 0.00000000e+00 7.81038524e-01]
[ 7.45234275e+00 0.00000000e+00 -1.54765725e+00]
[ 4.26299512e+00 0.00000000e+00 -7.37004877e-01]
[ 6.52988398e+00 0.00000000e+00 -4.70116016e-01]
[ 8.96044500e+00 0.00000000e+00 9.60445003e-01]
[ 7.35391597e+00 0.00000000e+00 -1.64608403e+00]
[ 6.70063979e+00 0.00000000e+00 -1.29936021e+00]]
たまにでかいのあるけどおおむねいい感じではないですか(非常に主観的)
実用ではmissing_rateをもっと大きくした状態での推測をしなければいけないだろうと思われるので、今後も色々試してみます。
あとこれと同時に作られるthetaとfeaturesについても活用できそうなのでまたやってみます。thetaの傾向が似てる人が好きな曲とかさ。
機械学習で好きなボカロ曲を探せたらいいなって話
CourseraのMachineLearningコースの何週目かで出てきたCollaborative filteringってのが、ボカロ曲の好みを機械学習してオヌヌメ曲を提示することができそうだなと思ってたので、復習も兼ねて作ってみました。AndrewNg先生が「長くても24時間でプロトタイプを作れ」みたいなこと仰ってたので、本日の実shおっとこれ以上は言えない。
ちゃんとしたデータがないので現実的ではないでしょうけど、グーグルアンケートでやれるだけやってみようかな。
この手の記事の書き方がわからないので、地の文で書き続けていきます。
【0.ライブラリ読み込み】
作った関数はfunctions_cf.pyに入ってます。末尾参照。
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
【1.データ作成】
登場する変数
music_size:扱う楽曲数 今回は5とした
sample_size:アンケート回答者 今回は100とした
cover_size:アンケート回答者が知らなかった楽曲数 今回は50
楽曲5曲はそれぞれA,B,C,D,Eと呼ぶことにします。100人の回答者がその5曲に対して順位付けを5-1とつけてもらうことにします(数字が大きいほど好き)。それを100*5の行列scoreとします。
ひとまず次のような設定にしました。
・AとB、CとDは同系統
・AとBが好きな人はCとDが好きではなく、Eはさらに好きではない
・CとDが好きな人は、Eが好きではなく、AとBはさらに好きではない
・Eが好きな人は、CとDは好きではなく、AとBはさらに好きではない
という前提でデータ作り。つまり言い換えれば
・(A,B)が(5,4)or(4,5)な人:(C,D)は(3,2)(2,3)のいずれか、Eは1
・(C,D)が(5,4)or(4,5)な人:(A,B)は(2,1)(1,2)のいずれか、Eは3
・Eが5な人:(A,B)は(2,1)or(1,2)、(C,D)は(4,3)or(3,4)
こんな行列を作るmaking_score関数を作って
score=making_score(sample_size, music_size)
で作ると。
ところが現実問題、全員が全員アンケートの全曲を知っているとは限らないので、scoreからcover_sizeの数の分だけランダムに0にしたdata、および「dataのうち0でないものを真、0のものを偽」とした行列Rを作ります。こちらもmaking_data関数を作って
data, R=making_data(score, cover_size)
で吐いてもらいました。
なお0が意味するのは評価が0というわけではなく、値がないことを意味する便宜上の0です。学習するときにはこのdataを学習し、scoreを予測することを目標とします。
【2.学習準備】
登場する変数
feature_size:機械学習で学んでもらう特徴の数 今回は3としました
theta:各サンプルの好みの特徴を表す行列(これを学習してもらう)
features:各楽曲の持つ特徴の行列(これも学習してもらう)
learning_rate:学習率
lam:正規化項
iter:学習回数
J_list:誤差関数記録
まず学習する2つを乱数で決めてやる。
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
featuresは実際に使う時に全部1の列を先頭に追加するので、thetaの余分な1行との内積でバイアスとして使われます。
他初期設定
learning_rate=0.3
lam=0.003
iter=2000
J_list=np.array([])
なおlearning_rateとlamはクロスバリデーションで求めました。クロスバリデーションといっても、データセットの出所は同じ関数ですけど。
【3.学習】
0. iter回繰り返し
1. featuresとthetaからscoreの予測値predictionを出させる
2. 記録用に二乗誤差Jを求める
3. 誤差関数の偏微分delta_theta, delta_featuresを求める
4. thetaの調整
5. featuresの調整
6. 記録用に今回のJをJ_listに追加
てな具合。以下。
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
np.save('result.npy', result)
二乗誤差をプロットしたのがこんな感じ
ちなみにresult[0:100]についてはこんな感じ。
左から順にprediction, data, prediction-scoreです。
array([[ 1.24964192, 1. , 0.24964192],
[ 1.37293008, 2. , -0.62706992],
[ 3.89820064, 4. , -0.10179936],
[ 3.65735034, 3. , 0.65735034],
[ 4.47454032, 5. , -0.52545968],
[ 1.52484356, 0. , -0.47515644],
[ 1.55074509, 1. , 0.55074509],
[ 3.98696449, 5. , -1.01303551],
[ 3.97059504, 4. , -0.02940496],
[ 3.85918495, 3. , 0.85918495],
[ 1.24810015, 1. , 0.24810015],
[ 1.62459262, 2. , -0.37540738],
[ 3.66986787, 4. , -0.33013213],
[ 3.54198014, 3. , 0.54198014],
[ 4.79943488, 5. , -0.20056512],
[ 1.69793326, 1. , 0.69793326],
[ 1.58617317, 2. , -0.41382683],
[ 4.1376663 , 4. , 0.1376663 ],
[ 4.10344883, 5. , -0.89655117],
[ 3.63891658, 3. , 0.63891658],
[ 1.71531344, 1. , 0.71531344],
[ 1.69460638, 2. , -0.30539362],
[ 4.02300241, 5. , -0.97699759],
[ 3.93272179, 4. , -0.06727821],
[ 3.91481096, 3. , 0.91481096],
[ 1.57404885, 1. , 0.57404885],
[ 1.29974807, 2. , -0.70025193],
[ 3.43582389, 3. , 0.43582389],
[ 3.5243598 , 4. , -0.4756402 ],
[ 2.27410224, 0. , -2.72589776],
[ 1.10700249, 1. , 0.10700249],
[ 1.16714062, 2. , -0.83285938],
[ 4.07213299, 4. , 0.07213299],
[ 4.12135044, 3. , 1.12135044],
[ 3.87078936, 5. , -1.12921064],
[ 1.68449955, 2. , -0.31550045],
[ 1.60428171, 1. , 0.60428171],
[ 4.14871254, 5. , -0.85128746],
[ 4.28263841, 4. , 0.28263841],
[ 3.41619002, 3. , 0.41619002],
[ 4.62560833, 4. , 0.62560833],
[ 4.05205221, 5. , -0.94794779],
[ 2.55317212, 3. , -0.44682788],
[ 2.66680138, 2. , 0.66680138],
[ 0.83417602, 1. , -0.16582398],
[ 4.58030892, 4. , 0.58030892],
[ 4.0197366 , 5. , -0.9802634 ],
[ 2.57726657, 3. , -0.42273343],
[ 2.79122591, 2. , 0.79122591],
[ 0.69624294, 1. , -0.30375706],
[ 1.20624761, 2. , -0.79375239],
[ 1.10612152, 1. , 0.10612152],
[ 4.20693688, 3. , 1.20693688],
[ 4.23614686, 4. , 0.23614686],
[ 3.56631529, 5. , -1.43368471],
[ 1.73888174, 2. , -0.26111826],
[ 1.64337964, 1. , 0.64337964],
[ 4.12220661, 5. , -0.87779339],
[ 4.14964689, 4. , 0.14964689],
[ 3.55789362, 3. , 0.55789362],
[ 1.13147051, 1. , 0.13147051],
[ 1.32433383, 2. , -0.67566617],
[ 3.94218514, 3. , 0.94218514],
[ 4.01280332, 4. , 0.01280332],
[ 4.10728622, 5. , -0.89271378],
[ 1.28113118, 2. , -0.71886882],
[ 1.63935595, 1. , 0.63935595],
[ 3.71823625, 3. , 0.71823625],
[ 3.71832546, 4. , -0.28167454],
[ 4.5450203 , 5. , -0.4549797 ],
[ 1.77791569, 2. , -0.22208431],
[ 1.66842039, 1. , 0.66842039],
[ 4.11773873, 4. , 0.11773873],
[ 4.09306547, 5. , -0.90693453],
[ 3.61683769, 3. , 0.61683769],
[ 1.47290005, 1. , 0.47290005],
[ 1.50925726, 2. , -0.49074274],
[ 3.92312391, 0. , 0.92312391],
[ 3.9062157 , 4. , -0.0937843 ],
[ 3.82848362, 0. , -1.17151638],
[ 1.46103827, 1. , 0.46103827],
[ 1.58404067, 2. , -0.41595933],
[ 3.50582218, 0. , -1.49417782],
[ 3.73732263, 4. , -0.26267737],
[ 3.22673005, 3. , 0.22673005],
[ 1.72104299, 2. , -0.27895701],
[ 1.838129 , 1. , 0.838129 ],
[ 3.95118648, 4. , -0.04881352],
[ 4.11197255, 5. , -0.88802745],
[ 3.77340792, 3. , 0.77340792],
[ 1.34400899, 2. , -0.65599101],
[ 1.40548812, 1. , 0.40548812],
[ 3.93117907, 4. , -0.06882093],
[ 3.67388799, 3. , 0.67388799],
[ 4.36215808, 5. , -0.63784192],
[ 2.19485857, 0. , 1.19485857],
[ 1.96226504, 2. , -0.03773496],
[ 4.32091318, 4. , 0.32091318],
[ 4.46082193, 5. , -0.53917807],
[ 3.15809731, 3. , 0.15809731]])
【全体像】
貼っておきます。誤差関数と偏微分あたりがエッセンスかと思われるので、その辺に間違いがないかチェックしてもらえると助かります。
main.py
# coding:utf-8
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
sample_size=100
music_size=5
feature_size=3
cover_size=np.int(sample_size*music_size/10)
score=make_score(sample_size, music_size)
data, R=make_data(score, cover_size)
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
learning_rate=0.3
lam=0.003
iter=2000
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
if i%500==0:
print(J)
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
np.load('result.npy', result)
functions_cf.py
import numpy as np
def predict(features, theta):
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam
return delta_theta, delta_features[1:, :]
def make_score(s, m):
score=np.zeros(s*m).reshape(s, m)
mask=np.random.randint(0, 3, s)
for i in range(mask.shape[0]):
if mask[i]==0:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=5
score[i, 1]=4
else:
score[i, 0]=4
score[i, 1]=5
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=3
score[i, 3]=2
else:
score[i, 2]=2
score[i, 3]=3
score[i, 4]=1
elif mask[i]==1:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=1
score[i, 1]=2
else:
score[i, 0]=2
score[i, 1]=1
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=5
score[i, 3]=4
else:
score[i, 2]=4
score[i, 3]=5
score[i, 4]=3
else:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=1
score[i, 1]=2
else:
score[i, 0]=2
score[i, 1]=1
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=3
score[i, 3]=4
else:
score[i, 2]=4
score[i, 3]=3
score[i, 4]=5
return score
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
PS
ガチの機械学習の人に見られたらボコボコにされそう。
ちゃんとしたデータがないので現実的ではないでしょうけど、グーグルアンケートでやれるだけやってみようかな。
この手の記事の書き方がわからないので、地の文で書き続けていきます。
【0.ライブラリ読み込み】
作った関数はfunctions_cf.pyに入ってます。末尾参照。
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
【1.データ作成】
登場する変数
music_size:扱う楽曲数 今回は5とした
sample_size:アンケート回答者 今回は100とした
cover_size:アンケート回答者が知らなかった楽曲数 今回は50
楽曲5曲はそれぞれA,B,C,D,Eと呼ぶことにします。100人の回答者がその5曲に対して順位付けを5-1とつけてもらうことにします(数字が大きいほど好き)。それを100*5の行列scoreとします。
ひとまず次のような設定にしました。
・AとB、CとDは同系統
・AとBが好きな人はCとDが好きではなく、Eはさらに好きではない
・CとDが好きな人は、Eが好きではなく、AとBはさらに好きではない
・Eが好きな人は、CとDは好きではなく、AとBはさらに好きではない
という前提でデータ作り。つまり言い換えれば
・(A,B)が(5,4)or(4,5)な人:(C,D)は(3,2)(2,3)のいずれか、Eは1
・(C,D)が(5,4)or(4,5)な人:(A,B)は(2,1)(1,2)のいずれか、Eは3
・Eが5な人:(A,B)は(2,1)or(1,2)、(C,D)は(4,3)or(3,4)
こんな行列を作るmaking_score関数を作って
score=making_score(sample_size, music_size)
で作ると。
ところが現実問題、全員が全員アンケートの全曲を知っているとは限らないので、scoreからcover_sizeの数の分だけランダムに0にしたdata、および「dataのうち0でないものを真、0のものを偽」とした行列Rを作ります。こちらもmaking_data関数を作って
data, R=making_data(score, cover_size)
で吐いてもらいました。
なお0が意味するのは評価が0というわけではなく、値がないことを意味する便宜上の0です。学習するときにはこのdataを学習し、scoreを予測することを目標とします。
【2.学習準備】
登場する変数
feature_size:機械学習で学んでもらう特徴の数 今回は3としました
theta:各サンプルの好みの特徴を表す行列(これを学習してもらう)
features:各楽曲の持つ特徴の行列(これも学習してもらう)
learning_rate:学習率
lam:正規化項
iter:学習回数
J_list:誤差関数記録
まず学習する2つを乱数で決めてやる。
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
featuresは実際に使う時に全部1の列を先頭に追加するので、thetaの余分な1行との内積でバイアスとして使われます。
他初期設定
learning_rate=0.3
lam=0.003
iter=2000
J_list=np.array([])
なおlearning_rateとlamはクロスバリデーションで求めました。クロスバリデーションといっても、データセットの出所は同じ関数ですけど。
【3.学習】
0. iter回繰り返し
1. featuresとthetaからscoreの予測値predictionを出させる
2. 記録用に二乗誤差Jを求める
3. 誤差関数の偏微分delta_theta, delta_featuresを求める
4. thetaの調整
5. featuresの調整
6. 記録用に今回のJをJ_listに追加
てな具合。以下。
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
【4.結果観測】
matplotlibの使い方はこの3行以外知らない模様。
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
np.save('result.npy', result)
二乗誤差をプロットしたのがこんな感じ
ちなみにresult[0:100]についてはこんな感じ。
左から順にprediction, data, prediction-scoreです。
array([[ 1.24964192, 1. , 0.24964192],
[ 1.37293008, 2. , -0.62706992],
[ 3.89820064, 4. , -0.10179936],
[ 3.65735034, 3. , 0.65735034],
[ 4.47454032, 5. , -0.52545968],
[ 1.52484356, 0. , -0.47515644],
[ 1.55074509, 1. , 0.55074509],
[ 3.98696449, 5. , -1.01303551],
[ 3.97059504, 4. , -0.02940496],
[ 3.85918495, 3. , 0.85918495],
[ 1.24810015, 1. , 0.24810015],
[ 1.62459262, 2. , -0.37540738],
[ 3.66986787, 4. , -0.33013213],
[ 3.54198014, 3. , 0.54198014],
[ 4.79943488, 5. , -0.20056512],
[ 1.69793326, 1. , 0.69793326],
[ 1.58617317, 2. , -0.41382683],
[ 4.1376663 , 4. , 0.1376663 ],
[ 4.10344883, 5. , -0.89655117],
[ 3.63891658, 3. , 0.63891658],
[ 1.71531344, 1. , 0.71531344],
[ 1.69460638, 2. , -0.30539362],
[ 4.02300241, 5. , -0.97699759],
[ 3.93272179, 4. , -0.06727821],
[ 3.91481096, 3. , 0.91481096],
[ 1.57404885, 1. , 0.57404885],
[ 1.29974807, 2. , -0.70025193],
[ 3.43582389, 3. , 0.43582389],
[ 3.5243598 , 4. , -0.4756402 ],
[ 2.27410224, 0. , -2.72589776],
[ 1.10700249, 1. , 0.10700249],
[ 1.16714062, 2. , -0.83285938],
[ 4.07213299, 4. , 0.07213299],
[ 4.12135044, 3. , 1.12135044],
[ 3.87078936, 5. , -1.12921064],
[ 1.68449955, 2. , -0.31550045],
[ 1.60428171, 1. , 0.60428171],
[ 4.14871254, 5. , -0.85128746],
[ 4.28263841, 4. , 0.28263841],
[ 3.41619002, 3. , 0.41619002],
[ 4.62560833, 4. , 0.62560833],
[ 4.05205221, 5. , -0.94794779],
[ 2.55317212, 3. , -0.44682788],
[ 2.66680138, 2. , 0.66680138],
[ 0.83417602, 1. , -0.16582398],
[ 4.58030892, 4. , 0.58030892],
[ 4.0197366 , 5. , -0.9802634 ],
[ 2.57726657, 3. , -0.42273343],
[ 2.79122591, 2. , 0.79122591],
[ 0.69624294, 1. , -0.30375706],
[ 1.20624761, 2. , -0.79375239],
[ 1.10612152, 1. , 0.10612152],
[ 4.20693688, 3. , 1.20693688],
[ 4.23614686, 4. , 0.23614686],
[ 3.56631529, 5. , -1.43368471],
[ 1.73888174, 2. , -0.26111826],
[ 1.64337964, 1. , 0.64337964],
[ 4.12220661, 5. , -0.87779339],
[ 4.14964689, 4. , 0.14964689],
[ 3.55789362, 3. , 0.55789362],
[ 1.13147051, 1. , 0.13147051],
[ 1.32433383, 2. , -0.67566617],
[ 3.94218514, 3. , 0.94218514],
[ 4.01280332, 4. , 0.01280332],
[ 4.10728622, 5. , -0.89271378],
[ 1.28113118, 2. , -0.71886882],
[ 1.63935595, 1. , 0.63935595],
[ 3.71823625, 3. , 0.71823625],
[ 3.71832546, 4. , -0.28167454],
[ 4.5450203 , 5. , -0.4549797 ],
[ 1.77791569, 2. , -0.22208431],
[ 1.66842039, 1. , 0.66842039],
[ 4.11773873, 4. , 0.11773873],
[ 4.09306547, 5. , -0.90693453],
[ 3.61683769, 3. , 0.61683769],
[ 1.47290005, 1. , 0.47290005],
[ 1.50925726, 2. , -0.49074274],
[ 3.92312391, 0. , 0.92312391],
[ 3.9062157 , 4. , -0.0937843 ],
[ 3.82848362, 0. , -1.17151638],
[ 1.46103827, 1. , 0.46103827],
[ 1.58404067, 2. , -0.41595933],
[ 3.50582218, 0. , -1.49417782],
[ 3.73732263, 4. , -0.26267737],
[ 3.22673005, 3. , 0.22673005],
[ 1.72104299, 2. , -0.27895701],
[ 1.838129 , 1. , 0.838129 ],
[ 3.95118648, 4. , -0.04881352],
[ 4.11197255, 5. , -0.88802745],
[ 3.77340792, 3. , 0.77340792],
[ 1.34400899, 2. , -0.65599101],
[ 1.40548812, 1. , 0.40548812],
[ 3.93117907, 4. , -0.06882093],
[ 3.67388799, 3. , 0.67388799],
[ 4.36215808, 5. , -0.63784192],
[ 2.19485857, 0. , 1.19485857],
[ 1.96226504, 2. , -0.03773496],
[ 4.32091318, 4. , 0.32091318],
[ 4.46082193, 5. , -0.53917807],
[ 3.15809731, 3. , 0.15809731]])
【全体像】
貼っておきます。誤差関数と偏微分あたりがエッセンスかと思われるので、その辺に間違いがないかチェックしてもらえると助かります。
main.py
# coding:utf-8
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt
sample_size=100
music_size=5
feature_size=3
cover_size=np.int(sample_size*music_size/10)
score=make_score(sample_size, music_size)
data, R=make_data(score, cover_size)
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)
learning_rate=0.3
lam=0.003
iter=2000
J_list=np.array([])
for i in range(iter):
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
J=cost_func(data, prediction, R, theta, lam)
delta_theta, delta_features=delta_func(data, prediction, np.r_[np.ones(features.shape[1]).reshape(1, -1), features], R, theta, lam)
theta=theta-learning_rate*delta_theta
features=features-learning_rate*delta_features
J_list=np.r_[J_list, J]
if i%500==0:
print(J)
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
np.load('result.npy', result)
functions_cf.py
import numpy as np
def predict(features, theta):
out=np.dot(theta, features)
return out
def cost_func(data, prediction, R, theta, lam):
sub=prediction-data
J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta*theta)/(np.sum(R)*2)*lam
return J
def delta_func(data, prediction, features, R, theta, lam):
delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)
delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
delta_features=delta_features+features*lam
return delta_theta, delta_features[1:, :]
def make_score(s, m):
score=np.zeros(s*m).reshape(s, m)
mask=np.random.randint(0, 3, s)
for i in range(mask.shape[0]):
if mask[i]==0:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=5
score[i, 1]=4
else:
score[i, 0]=4
score[i, 1]=5
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=3
score[i, 3]=2
else:
score[i, 2]=2
score[i, 3]=3
score[i, 4]=1
elif mask[i]==1:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=1
score[i, 1]=2
else:
score[i, 0]=2
score[i, 1]=1
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=5
score[i, 3]=4
else:
score[i, 2]=4
score[i, 3]=5
score[i, 4]=3
else:
temp=np.random.randn(1)
if temp>=0:
score[i, 0]=1
score[i, 1]=2
else:
score[i, 0]=2
score[i, 1]=1
temp=np.random.randn(1)
if temp>=0:
score[i, 2]=3
score[i, 3]=4
else:
score[i, 2]=4
score[i, 3]=3
score[i, 4]=5
return score
def make_data(score, cover_size):
data=score.copy()
data=data.reshape(-1)
mask=np.random.choice(data.size, cover_size)
data[mask]=0
data=data.reshape(score.shape)
R=(data!=0)
return data, R
PS
ガチの機械学習の人に見られたらボコボコにされそう。
2017年5月3日水曜日
2017年1~4月
研究室に通い始めて半年経ち、諸々の前処理は調べ調べではあるものの一人で進めることができるようになりました。ここからはひたすらインプットだけでなくresearch questionを探したり、研究っぽくなってきたんじゃないですかね。
機械学習についてはCourseraのMachineLearningコースを先月頭から始めまして、現在Week8の教師なし学習をやっております。Week7のサポートベクターマシンがやりたくて進めてきたんですが、SVMの細かい話はなく課題も細々とした穴埋めで終わってしまったため、結局は自分で勉強しなきゃダメそうですね。まぁ難しいんでしょう。
Pythonについては現在絶賛放置中でして「必要になったらその都度調べればいいっしょ」的なスタンスでいたら、指導してもらってる先生に紹介される際に
指導教官「彼はPythonで色々書いたりしてるんだよ」(確かに書いたりはしている)
他の先生「Python書けるんだ~」「誰々と会わせてみようか」
という間違った伝わり方をしたため、死ぬ前に死ぬほど頑張ろうと思います(Python超速で勉強始めます)。具体的にはまずオライリー「入門Python3」の知らないところをしらみつぶしして、その後もオライリーのPython関係をひたすらつぶしていきましょうというところですか。
(5/4追記)CourseraのMachineLearning終わったらPyQやるか
(5/4追記)CourseraのMachineLearning終わったらPyQやるか
MATLABは上述のCourseraを進める中で慣れた。
Linuxは絶賛放置!これは指導教官もなんとなく俺がやりたがっていないのを察されている気がする。具体的に困らないとやはりやる気が起きないというのが本音。
それと先日「Rをやるといいよ!」と言われたので、ここにきてまた新たに習得することになるのかもしれない。「僕もRはやらないといけないと思っているから、一緒に勉強してもいいよね」と仰っていただけたので、そうなったらそうします。そもそも何がいいかって、Rの勉強する際に統計の勉強もできる点だそうで、確かにそこはいいかもしれない。
医学は、どうしよう。ただの通過点とはいえ総合試験も手放しで受けるわけにもいかないので、計画だけは立てないといけないと思うわけです。CCに合わせていくとすれば、
5月 神経系
6月 代謝内分泌系
7月 消化器系
8月 小児産婦
9月 循環器系
10月 呼吸器 膠リア
11月 腎泌尿 血液
12月 総合試験
てな具合ですかな。やるとしたらQBしかないよなぁ。やるかぁ。やるかぁ~。
そもそも既にCCで回った診療科が今後の勉強の予定に組み込まれてること自体ダメなンすわ。幸い危機感が出てきた今月はやる気のある神経内科なので、ここで医学へのモチベーションを高めて6月以降も研究室と並行して医学を進められれば良いですね。(他人事)
2017年5月2日火曜日
png画像で"真黒"以外を真白にする
もういっちょ。
手書き画像を写真で撮ってペイントで線なぞった後、以前は消しゴムでくまなく消すというアホらしい作業をしていたのですが、もう今となればちょちょいのちょいですね。
線でなぞるところも自動化するのはおいおい。
> python line.py filename new_filename
import sys
import numpy as np
from PIL import Image
filename=sys.argv[1]
new_filename=sys.argv[2]
img=np.array(Image.open(filename))
img[~(img==0)]=255
Image.fromarray(img).save(new_filename)
さっきも一つ備忘録残したんすけどこれ事情があって、ファイル管理下手くそすぎてどっか行っちゃうンすわ。
今気づいたけど一色だけパラメータ振り切れてたりしても消せないね。まぁいいやとりあえず困ってないし。
手書き画像を写真で撮ってペイントで線なぞった後、以前は消しゴムでくまなく消すというアホらしい作業をしていたのですが、もう今となればちょちょいのちょいですね。
線でなぞるところも自動化するのはおいおい。
> python line.py filename new_filename
import sys
import numpy as np
from PIL import Image
filename=sys.argv[1]
new_filename=sys.argv[2]
img=np.array(Image.open(filename))
img[~(img==0)]=255
Image.fromarray(img).save(new_filename)
さっきも一つ備忘録残したんすけどこれ事情があって、ファイル管理下手くそすぎてどっか行っちゃうンすわ。
今気づいたけど一色だけパラメータ振り切れてたりしても消せないね。まぁいいやとりあえず困ってないし。
登録:
投稿 (Atom)