First things first: I named it MikuMikuRecommendation (lol)
Joking aside, on to the main topic. What changed is how the dataset is built, so nothing essential is different.
Yesterday's version built the dataset in such a clumsy way that it couldn't scale, so this time I took random numbers, added a correlation-ish something, normalized the result to the 1-10 range, and used that as the dataset. It probably has all sorts of problems, but at a glance it looks like a decent dataset, so I'll call it good.
How strong that correlation-ish something is can be set via a variable called tightness, which is tentatively fixed at 3. Why? Because 2 didn't work out...
I also wrapped the training part into a function so it can be copy-pasted when working inline.
The current state is below.
【Overview】
Rather than explaining piece by piece, here's the whole thing in one go.
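To get a feel for what tightness does, here is a minimal standalone sketch (separate from the actual make_bigscore further down; all names and numbers are illustrative): a score is a base value plus a group-affinity term scaled by tightness, plus unit noise, so a larger tightness makes the group signal stand out more against the noise.

```python
import numpy as np

rng = np.random.default_rng(0)

# shared ingredients so that only tightness varies (all names illustrative)
n_users, n_items, n_groups = 200, 10, 4
affinity = rng.standard_normal((n_groups, n_groups))  # group-vs-group affinity
u = rng.integers(0, n_groups, n_users)                # each user's taste group
m = rng.integers(0, n_groups, n_items)                # each song's group
noise = rng.standard_normal((n_users, n_items))

def group_scores(tightness):
    # base score 5 plus the affinity term scaled by tightness, plus noise
    return 5 + tightness * affinity[u][:, m] + noise

# a larger tightness widens the spread of raw scores, because the group
# signal increasingly dominates the per-entry noise
for t in (2, 3):
    print(t, round(group_scores(t).std(), 2))
```

This is only the generation idea; whether tightness 2 vs 3 trains well also depends on how the rescaling and flooring squash that spread afterwards.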
main.py
# coding:utf-8
import numpy as np
from functions_cf import *
import matplotlib.pyplot as plt

# dataset shape and how much of it to hide
sample_size=200    # users
music_size=10      # songs
feature_size=10    # latent feature dimension
missing_rate=0.1   # fraction of ratings to hide
cover_size=int(sample_size*music_size*missing_rate)  # np.int is deprecated; use the builtin
category=4         # number of taste groups
min_score=1        # bottom of the rating scale (currently unused)
max_score=10       # top of the rating scale (renamed so it doesn't shadow the builtin max)

# build the full score matrix, then hide cover_size entries
score=make_bigscore(sample_size, music_size, category, max_score)
data, R=make_data(score, cover_size)

# random init; theta gets one extra column for the bias
features=np.random.randn(feature_size, music_size)
theta=np.random.randn(sample_size, feature_size+1)

learning_rate=0.1
lam=0.01
n_iter=5000  # renamed so it doesn't shadow the builtin iter
theta, features, J_list=learning(n_iter, learning_rate, data, features, R, theta, lam)

# cost history
fig=plt.figure()
plt.plot(np.arange(0, J_list.shape[0], 1), J_list)
plt.show()

# predict everything (prepend the bias row of ones) and look at the hidden entries
prediction=predict(np.r_[np.ones(features.shape[1]).reshape(1, -1), features], theta)
sub=prediction-score
result=np.c_[prediction.reshape(-1), data.reshape(-1), sub.reshape(-1)]
print(result[R.reshape(-1)==0])
functions_cf.py
import numpy as np

def predict(features, theta):
    # features is expected to already carry the bias row of ones on top
    out=np.dot(theta, features)
    return out

def cost_func(data, prediction, R, theta, lam):
    # squared error over the observed entries only (R==1), plus an L2
    # penalty on theta that skips the bias column (matching delta_func)
    sub=prediction-data
    J=np.sum(sub*sub*R)/(np.sum(R)*2)+np.sum(theta[:, 1:]*theta[:, 1:])/(np.sum(R)*2)*lam
    return J

def delta_func(data, prediction, features, R, theta, lam):
    # gradients for theta and for the latent features
    delta_theta=np.dot((prediction-data)*R, features.T)/np.sum(R)
    delta_theta[:, 1:]=delta_theta[:, 1:]+theta[:, 1:]*lam/np.sum(R)  # leave the bias unregularized
    delta_features=np.dot(theta.T, (prediction-data)*R)/np.sum(R)
    delta_features=delta_features+features*lam/np.sum(R)
    return delta_theta, delta_features[1:, :]  # drop the gradient for the bias row

def make_data(score, cover_size):
    # zero out exactly cover_size entries; replace=False so the same index
    # can't be drawn twice (otherwise fewer entries would get hidden)
    data=score.copy()
    data=data.reshape(-1)
    mask=np.random.choice(data.size, cover_size, replace=False)
    data[mask]=0
    data=data.reshape(score.shape)
    R=(data!=0)  # indicator of observed entries
    return data, R

def learning(n_iter, learning_rate, data, features, R, theta, lam):
    # plain batch gradient descent, logging the cost at every step
    J_list=np.array([])
    for i in range(n_iter):
        X=np.r_[np.ones(features.shape[1]).reshape(1, -1), features]  # bias row on top
        prediction=predict(X, theta)
        J=cost_func(data, prediction, R, theta, lam)
        delta_theta, delta_features=delta_func(data, prediction, X, R, theta, lam)
        theta=theta-learning_rate*delta_theta
        features=features-learning_rate*delta_features
        J_list=np.r_[J_list, J]
    return theta, features, J_list

def make_bigscore(s, m, category, max_score, tightness=3):
    # each user and each song belongs to one of `category` taste groups;
    # score = 5 + tightness * group affinity + unit noise
    score=np.zeros((s, m))+5
    favo_mask=np.random.randint(0, category, s)
    music_mask=np.random.randint(0, category, m)
    correlation_rate=make_half(np.random.randn(category, category))
    for i in range(s):
        for j in range(m):
            score[i, j]=score[i, j]+correlation_rate[favo_mask[i], music_mask[j]]*tightness+np.random.randn()
    # rescale into roughly the 1..max_score range, then round down to integers
    score=(score-np.min(score))*max_score/np.max(score)+1
    return np.floor(score)

def make_half(matrix):
    # make the random affinity matrix symmetric with a zero diagonal
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if i>j:
                matrix[j, i]=matrix[i, j]
            elif i==j:
                matrix[i, j]=0
    return matrix
【Observing the results】
First up: filling in the "missing values" where R=0.
Running the code above as-is gives a result like the following.
result is shown below; the columns are, from left, prediction, data (=0), and sub (=prediction-score).
[[ 6.79221021e+00 0.00000000e+00 1.79221021e+00]
[ 7.53311390e+00 0.00000000e+00 -4.66886104e-01]
[ 7.09569830e+00 0.00000000e+00 9.56982987e-02]
[ 4.08265903e+00 0.00000000e+00 -1.91734097e+00]
[ 5.11038322e+00 0.00000000e+00 1.10383221e-01]
[ 6.83060034e+00 0.00000000e+00 -3.16939966e+00]
[ 4.94816109e+00 0.00000000e+00 9.48161091e-01]
[ 8.94097605e+00 0.00000000e+00 -5.90239520e-02]
[ 5.55085831e+00 0.00000000e+00 1.55085831e+00]
[ 6.49319455e+00 0.00000000e+00 4.93194551e-01]
[ 3.51685950e+00 0.00000000e+00 -1.48314050e+00]
[ 3.30233833e+00 0.00000000e+00 -6.97661671e-01]
[ 8.43240061e+00 0.00000000e+00 4.32400607e-01]
[ 8.41920554e+00 0.00000000e+00 1.41920554e+00]
[ 8.48562862e+00 0.00000000e+00 4.85628622e-01]
[ 7.91982280e+00 0.00000000e+00 -1.08017720e+00]
[ 7.69216022e+00 0.00000000e+00 -1.30783978e+00]
[ 7.55579536e+00 0.00000000e+00 5.55795356e-01]
[ 4.36991203e+00 0.00000000e+00 -2.63008797e+00]
[ 8.37558195e+00 0.00000000e+00 -6.24418049e-01]
[ 7.60526264e+00 0.00000000e+00 -3.94737356e-01]
[ 7.27028265e+00 0.00000000e+00 2.70282648e-01]
[ 3.85643006e+00 0.00000000e+00 -1.43569940e-01]
[ 7.31909333e+00 0.00000000e+00 -1.68090667e+00]
[ 6.51557893e+00 0.00000000e+00 -1.48442107e+00]
[ 4.57743090e+00 0.00000000e+00 5.77430899e-01]
[ 4.38028851e+00 0.00000000e+00 3.80288514e-01]
[ 7.15614184e+00 0.00000000e+00 -8.43858160e-01]
[ 4.86797258e+00 0.00000000e+00 8.67972580e-01]
[ 4.77330759e+00 0.00000000e+00 -2.26692406e-01]
[ 5.24170637e+00 0.00000000e+00 -7.58293631e-01]
[ 5.85907532e+00 0.00000000e+00 1.85907532e+00]
[ 7.01088219e+00 0.00000000e+00 -9.89117812e-01]
[ 5.12001067e+00 0.00000000e+00 1.20010669e-01]
[ 4.12287125e+00 0.00000000e+00 -1.87712875e+00]
[ 4.27710625e+00 0.00000000e+00 2.77106249e-01]
[ 5.09901728e+00 0.00000000e+00 -9.00982720e-01]
[ 9.11113983e+00 0.00000000e+00 1.11139833e-01]
[ 4.42793609e+00 0.00000000e+00 4.27936088e-01]
[ 8.48836557e+00 0.00000000e+00 4.88365570e-01]
[ 7.51564493e+00 0.00000000e+00 5.15644932e-01]
[ 8.46280310e+00 0.00000000e+00 -5.37196902e-01]
[ 5.43836821e+00 0.00000000e+00 2.43836821e+00]
[ 4.65313035e+00 0.00000000e+00 -3.46869649e-01]
[ 7.30430048e+00 0.00000000e+00 3.04300479e-01]
[ 4.96641442e+00 0.00000000e+00 -1.03358558e+00]
[ 4.52025568e+00 0.00000000e+00 5.20255679e-01]
[ 4.92889597e+00 0.00000000e+00 -7.11040257e-02]
[ 5.58917669e+00 0.00000000e+00 -1.41082331e+00]
[ 4.91782026e+00 0.00000000e+00 9.17820259e-01]
[ 8.68151818e+00 0.00000000e+00 6.81518179e-01]
[ 6.85840196e+00 0.00000000e+00 -1.41598037e-01]
[ 3.50970449e+00 0.00000000e+00 -2.49029551e+00]
[ 3.54968073e+00 0.00000000e+00 -4.50319275e-01]
[ 8.32713967e+00 0.00000000e+00 -6.72860330e-01]
[ 2.88564867e+00 0.00000000e+00 -1.14351328e-01]
[ 4.13329265e+00 0.00000000e+00 2.13329265e+00]
[ 6.02285979e+00 0.00000000e+00 -1.97714021e+00]
[ 5.48115372e+00 0.00000000e+00 -5.18846284e-01]
[ 6.22298632e+00 0.00000000e+00 2.22986317e-01]
[ 5.90004546e+00 0.00000000e+00 -9.99545425e-02]
[ 4.40149963e+00 0.00000000e+00 4.01499634e-01]
[ 7.00155610e+00 0.00000000e+00 1.00155610e+00]
[ 6.12984227e+00 0.00000000e+00 2.12984227e+00]
[ 4.43785983e+00 0.00000000e+00 4.37859834e-01]
[ 6.33879076e+00 0.00000000e+00 3.38790763e-01]
[ 4.33571888e+00 0.00000000e+00 -1.66428112e+00]
[ 2.48997197e+00 0.00000000e+00 1.48997197e+00]
[ 7.08099180e+00 0.00000000e+00 -9.19008203e-01]
[ 5.10492050e+00 0.00000000e+00 -1.89507950e+00]
[ 5.02968334e+00 0.00000000e+00 1.02968334e+00]
[ 7.13977640e+00 0.00000000e+00 1.39776399e-01]
[ 7.55119719e+00 0.00000000e+00 -4.48802809e-01]
[ 8.53262155e+00 0.00000000e+00 -4.67378448e-01]
[ 7.51423486e+00 0.00000000e+00 5.14234860e-01]
[ 2.99267833e+00 0.00000000e+00 -1.00732167e+00]
[ 3.32044090e+00 0.00000000e+00 -6.79559105e-01]
[ 8.25607045e+00 0.00000000e+00 2.25607045e+00]
[ 8.27929557e+00 0.00000000e+00 2.27929557e+00]
[ 5.79284155e+00 0.00000000e+00 -2.20715845e+00]
[ 5.63182840e+00 0.00000000e+00 6.31828398e-01]
[ 4.27518546e+00 0.00000000e+00 -1.72481454e+00]
[ 2.82775364e+00 0.00000000e+00 -1.17224636e+00]
[ 3.27150448e+00 0.00000000e+00 -7.28495524e-01]
[ 4.86288111e+00 0.00000000e+00 8.62881106e-01]
[ 4.15287914e+00 0.00000000e+00 1.15287914e+00]
[ 5.48403925e+00 0.00000000e+00 4.84039252e-01]
[ 7.43184309e+00 0.00000000e+00 4.31843086e-01]
[ 5.42698644e+00 0.00000000e+00 -1.57301356e+00]
[ 7.65437641e+00 0.00000000e+00 -1.34562359e+00]
[ 7.55554759e+00 0.00000000e+00 5.55547595e-01]
[ 2.67930823e+00 0.00000000e+00 -1.32069177e+00]
[ 4.00815548e+00 0.00000000e+00 8.15547591e-03]
[ 6.37997063e+00 0.00000000e+00 3.79970634e-01]
[ 6.96754636e+00 0.00000000e+00 -3.03245364e+00]
[ 7.78552434e+00 0.00000000e+00 1.78552434e+00]
[ 7.13891523e+00 0.00000000e+00 -8.61084772e-01]
[ 5.15818977e+00 0.00000000e+00 -3.84181023e+00]
[ 7.47774417e+00 0.00000000e+00 -5.22255829e-01]
[ 6.35284199e+00 0.00000000e+00 3.52841987e-01]
[ 6.76815734e+00 0.00000000e+00 7.68157336e-01]
[ 5.94075467e+00 0.00000000e+00 -1.05924533e+00]
[ 7.12827211e+00 0.00000000e+00 2.12827211e+00]
[ 7.07915735e+00 0.00000000e+00 2.07915735e+00]
[ 7.12728376e+00 0.00000000e+00 3.12728376e+00]
[ 4.78467425e+00 0.00000000e+00 -1.21532575e+00]
[ 4.11857739e+00 0.00000000e+00 1.11857739e+00]
[ 4.94899594e+00 0.00000000e+00 -2.05100406e+00]
[ 3.62200612e+00 0.00000000e+00 -3.77993876e-01]
[ 4.13980901e+00 0.00000000e+00 1.13980901e+00]
[ 4.79519851e+00 0.00000000e+00 7.95198510e-01]
[ 9.58568796e+00 0.00000000e+00 -4.14312037e-01]
[ 2.40409199e+00 0.00000000e+00 1.40409199e+00]
[ 7.23112510e+00 0.00000000e+00 2.31125099e-01]
[ 8.03322342e+00 0.00000000e+00 2.03322342e+00]
[ 4.59414694e+00 0.00000000e+00 5.94146940e-01]
[ 2.43238578e+00 0.00000000e+00 4.32385779e-01]
[ 5.30828543e+00 0.00000000e+00 1.30828543e+00]
[ 7.66747838e+00 0.00000000e+00 2.66747838e+00]
[ 8.80101019e+00 0.00000000e+00 -1.19898981e+00]
[ 8.32916719e+00 0.00000000e+00 3.29167190e-01]
[ 8.59429933e+00 0.00000000e+00 5.94299329e-01]
[ 3.75772349e+00 0.00000000e+00 7.57723494e-01]
[ 4.88486628e+00 0.00000000e+00 8.84866279e-01]
[ 5.53191466e+00 0.00000000e+00 -4.68085340e-01]
[ 2.88859452e+00 0.00000000e+00 -1.11405479e-01]
[ 3.47813361e+00 0.00000000e+00 -1.52186639e+00]
[ 8.17178509e+00 0.00000000e+00 1.71785091e-01]
[ 5.26310717e+00 0.00000000e+00 2.63107170e-01]
[ 7.66027706e+00 0.00000000e+00 -3.39722939e-01]
[ 7.06787062e+00 0.00000000e+00 1.06787062e+00]
[ 6.46431101e+00 0.00000000e+00 1.46431101e+00]
[ 3.89064273e+00 0.00000000e+00 8.90642734e-01]
[ 4.90327565e+00 0.00000000e+00 9.03275650e-01]
[ 1.03351212e+01 0.00000000e+00 1.33512121e+00]
[ 8.87801850e-01 0.00000000e+00 -1.11219815e+00]
[ 7.12930310e-01 0.00000000e+00 -2.28706969e+00]
[ 7.24062819e+00 0.00000000e+00 1.24062819e+00]
[ 6.85799559e+00 0.00000000e+00 -1.14200441e+00]
[ 1.85540284e+00 0.00000000e+00 -1.14459716e+00]
[ 2.32319731e+00 0.00000000e+00 -6.76802692e-01]
[ 2.48839620e+00 0.00000000e+00 4.88396199e-01]
[ 5.65267581e+00 0.00000000e+00 6.52675811e-01]
[ 6.75195833e+00 0.00000000e+00 -2.48041665e-01]
[ 5.82486383e+00 0.00000000e+00 -1.17513617e+00]
[ 6.15176564e+00 0.00000000e+00 1.15176564e+00]
[ 7.77587521e+00 0.00000000e+00 -2.24124785e-01]
[ 7.89983313e+00 0.00000000e+00 -1.00166870e-01]
[ 7.95130816e+00 0.00000000e+00 -1.04869184e+00]
[ 5.09700416e+00 0.00000000e+00 1.09700416e+00]
[ 4.58884212e+00 0.00000000e+00 5.88842121e-01]
[ 3.52519701e+00 0.00000000e+00 5.25197009e-01]
[ 1.09648602e+01 0.00000000e+00 1.96486018e+00]
[ 1.13242890e+01 0.00000000e+00 3.32428897e+00]
[ 5.88963853e+00 0.00000000e+00 8.89638530e-01]
[ 9.09428528e+00 0.00000000e+00 -9.05714716e-01]
[ 8.22378866e+00 0.00000000e+00 1.22378866e+00]
[ 5.81835879e+00 0.00000000e+00 -1.81641213e-01]
[ 5.62012171e+00 0.00000000e+00 -3.79878291e-01]
[ 4.45124746e+00 0.00000000e+00 4.51247460e-01]
[ 2.90545701e+00 0.00000000e+00 -1.09454299e+00]
[ 6.10119442e+00 0.00000000e+00 -1.89880558e+00]
[ 6.03296647e+00 0.00000000e+00 -9.67033530e-01]
[ 6.67224047e+00 0.00000000e+00 -3.27759529e-01]
[ 2.81798517e+00 0.00000000e+00 -1.18201483e+00]
[ 2.27543560e+00 0.00000000e+00 -1.72456440e+00]
[ 7.22559961e+00 0.00000000e+00 1.22559961e+00]
[ 8.13112548e+00 0.00000000e+00 -8.68874517e-01]
[ 7.67090557e+00 0.00000000e+00 1.67090557e+00]
[ 3.03188884e+00 0.00000000e+00 3.18888430e-02]
[ 7.84613021e+00 0.00000000e+00 8.46130212e-01]
[ 3.04744195e+00 0.00000000e+00 -1.95255805e+00]
[ 5.16166895e+00 0.00000000e+00 -1.83833105e+00]
[ 6.67414529e+00 0.00000000e+00 -3.25854707e-01]
[ 2.75351809e+00 0.00000000e+00 -1.24648191e+00]
[ 5.74327034e+00 0.00000000e+00 -1.25672966e+00]
[ 5.63928383e+00 0.00000000e+00 -3.60716172e-01]
[ 7.64732695e+00 0.00000000e+00 -1.35267305e+00]
[ 5.78103852e+00 0.00000000e+00 7.81038524e-01]
[ 7.45234275e+00 0.00000000e+00 -1.54765725e+00]
[ 4.26299512e+00 0.00000000e+00 -7.37004877e-01]
[ 6.52988398e+00 0.00000000e+00 -4.70116016e-01]
[ 8.96044500e+00 0.00000000e+00 9.60445003e-01]
[ 7.35391597e+00 0.00000000e+00 -1.64608403e+00]
[ 6.70063979e+00 0.00000000e+00 -1.29936021e+00]]
There's the occasional big miss, but on the whole it looks pretty good, no? (highly subjective)
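A slightly less subjective check is to summarize the sub column numerically. A minimal sketch, using just the first handful of error values from the printout above as stand-in data:

```python
import numpy as np

# `errors` stands in for the third column of `result` above
# (prediction - score on the hidden entries); only the first few rows
errors = np.array([1.79, -0.47, 0.10, -1.92, 0.11, -3.17, 0.95])
rmse = np.sqrt(np.mean(errors ** 2))   # overall error size
hit = np.mean(np.abs(errors) <= 1.0)   # fraction within one rating point
print(round(rmse, 2), round(hit, 2))
```

Computed over the full `result[:, 2]` column instead, this would give one number to compare across runs and settings.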
In practice we'd presumably have to make predictions with a much larger missing_rate, so I'll keep trying things out.
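One thing to watch when pushing missing_rate up: np.random.choice with the default replace=True can draw the same index more than once, so fewer entries than cover_size actually get hidden, and the effect grows with the rate. A small standalone sketch (the 2000-element stand-in for a 200x10 score matrix is made up):

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in for a flattened 200x10 score matrix; +1 so no entry is 0 to begin with
flat = np.arange(2000, dtype=float) + 1
for rate in (0.1, 0.5):
    cover = int(flat.size * rate)
    masked = flat.copy()
    # replace=False guarantees exactly `cover` distinct entries get zeroed
    masked[rng.choice(flat.size, cover, replace=False)] = 0
    print(rate, (masked == 0).sum())
```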
The theta and features that come out alongside the predictions also look usable, so I'll give that a go too. Songs liked by users whose theta trends are similar, say.
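That "users with similar theta" idea can be sketched as cosine similarity between the learned per-user rows. The theta below is a tiny made-up example (the real one comes out of learning()):

```python
import numpy as np

# made-up theta for three users: column 0 is the bias, the rest are tastes
theta = np.array([[0.2, 1.0, -0.5],
                  [0.1, 0.9, -0.4],
                  [0.3, -1.0, 0.6]])
vecs = theta[:, 1:]                                        # skip the bias column
norm = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)  # unit-length rows
similarity = norm @ norm.T                                 # users x users cosine similarity
nearest = np.argsort(-similarity[0])[1]                    # most similar user to user 0
print(nearest)
```

From there, recommending to user 0 the highly rated songs of their nearest neighbors would be the obvious next step.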