Logistic Regression: Why Does the Model Focus More on Learning Columns with Large Values?
source link: https://blog.51cto.com/u_15767241/5981244
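import numpy as np
import matplotlib.pyplot as plt
# arrayGenReg, w_cal_rec and lr_gd are helper functions that are not defined in this post
np.random.seed(24)
features,labels = arrayGenReg(w=[1,-1,1])
# Scale the first feature up by a factor of 100
features[:,:1] = features[:,:1] * 100
features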
---
array([[ 132.92121726, -0.77003345, 1. ],
[ -31.62803596, -0.99081039, 1. ],
[-107.08162556, -1.43871328, 1. ],
...,
[ 155.07577972, -0.35986144, 1. ],
[-136.26716091, -0.61353562, 1. ],
[-144.02913135, 0.50439425, 1. ]])
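`arrayGenReg` is not defined in this post. To make the experiment reproducible without it, here is a hypothetical stand-in, `make_reg_data`, assuming it draws standard-normal features, appends a constant column of 1s as the intercept (matching the third column in the output above), and adds small Gaussian noise to the labels. The sample size `n=1000` and `noise=0.1` are assumptions; the real helper may differ.

```python
import numpy as np

def make_reg_data(w, n=1000, noise=0.1, seed=24):
    """Hypothetical stand-in for arrayGenReg: linear labels plus Gaussian noise.

    The last entry of w is treated as the intercept, matched by a constant column of 1s.
    """
    rng = np.random.default_rng(seed)
    k = len(w) - 1                                   # number of non-constant features
    features = np.hstack([rng.standard_normal((n, k)), np.ones((n, 1))])
    w_true = np.asarray(w, dtype=float).reshape(-1, 1)
    labels = features @ w_true + noise * rng.standard_normal((n, 1))
    return features, labels

features, labels = make_reg_data(w=[1, -1, 1])
features[:, :1] = features[:, :1] * 100             # scale the first feature up by 100
print(features[:3])
print(features.std(axis=0))                         # about 100, about 1, and 0 for the intercept
```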
np.linalg.lstsq(features,labels,rcond=-1)
---
(array([[ 0.00999619],
[-0.99985281],
[ 0.99970541]]),
array([0.09300731]),
3,
array([3138.44895283, 31.98889632, 30.9814256 ]))
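`np.linalg.lstsq` returns a tuple of (solution, residual sum of squares, rank, singular values). The first coefficient, about 0.01, is the true coefficient 1 divided by 100: scaling a column by $c$ divides its least-squares coefficient by $c$ and leaves the others unchanged. The singular values (about 3138 against about 31) also show that the scaled matrix is roughly 100 times worse conditioned. A minimal check on synthetic data, assuming two standard-normal features plus an intercept column:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
X = np.hstack([rng.standard_normal((n, 2)), np.ones((n, 1))])
y = X @ np.array([[1.0], [-1.0], [1.0]]) + 0.1 * rng.standard_normal((n, 1))

w_orig = np.linalg.lstsq(X, y, rcond=None)[0]

X_scaled = X.copy()
X_scaled[:, 0] *= 100                               # same scaling as in the post
w_scaled, _, _, sv = np.linalg.lstsq(X_scaled, y, rcond=None)

# Column 0's coefficient shrinks by 100 (up to rounding); the other coefficients are unchanged
print(w_orig.ravel(), w_scaled.ravel())
print(sv, sv[0] / sv[-1])                           # ratio of largest to smallest singular value
```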
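# Start from w = 0 and run 100 gradient-descent iterations with lr=0.0001, recording w after each step
w = np.array([[0.0,0.0,0.0]]).T
w,w_rec = w_cal_rec(features,w,labels,lr_gd,lr=0.0001,itera_times=100)
# Left: w1 per iteration; right: w2 per iteration
plt.subplots(1,2,figsize=(10,4))
plt.subplot(121)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,0])
plt.subplot(122)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,1])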
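Neither `w_cal_rec` nor `lr_gd` is shown in the post. The sketch below is a hypothetical stand-in, assuming `lr_gd` returns the gradient of the mean squared error and `w_cal_rec` runs plain gradient descent while recording $w$ after every step. It uses lr=5e-5 instead of 0.0001 so that it stays stable on this synthetic data, whose exact scale is not the author's.

```python
import numpy as np

def mse_grad(X, w, y):
    """Gradient of the mean squared error (1/n)*||Xw - y||^2 with respect to w."""
    return 2 * X.T @ (X @ w - y) / X.shape[0]

def gd_with_record(X, w, y, grad, lr, itera_times):
    """Hypothetical stand-in for w_cal_rec: gradient descent that records w after every step."""
    w_rec = []
    for _ in range(itera_times):
        w = w - lr * grad(X, w, y)
        w_rec.append(w.ravel().copy())
    return w, w_rec

rng = np.random.default_rng(24)
n = 1000
X = np.hstack([rng.standard_normal((n, 2)), np.ones((n, 1))])
y = X @ np.array([[1.0], [-1.0], [1.0]]) + 0.1 * rng.standard_normal((n, 1))
X[:, 0] *= 100                                      # first feature now ~100x larger

w, w_rec = gd_with_record(X, np.zeros((3, 1)), y, mse_grad, lr=5e-5, itera_times=100)
w_rec = np.array(w_rec)
print(w_rec[-1])   # expected: w1 close to 0.01, w2 still close to 0 (its target is -1)
```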
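As the plots show, the coefficient $w_1$ of the first feature reaches the analytic (least-squares) solution within roughly 100 iterations and then oscillates around it. $w_2$, by contrast, is clearly still far from its analytic solution, yet the model is still mainly adjusting $w_1$, the coefficient of the large-valued feature.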
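Intuitively, moving $w_1$ even slightly changes the loss (the evaluation metric) a lot, whereas moving $w_2$ by the same amount barely affects it. The gradient component for each coefficient is a sum of terms multiplied by that feature's values, so scaling a column by 100 scales its gradient component by roughly 100 as well, and gradient descent spends its steps on that column first. This is why the model focuses more on learning the columns with large values.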
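The intuition can be made precise, assuming the loss is the mean squared error (the exact form of `lr_gd` is not given, but a constant factor does not change the argument). With $x_{ij}$ the value of feature $j$ in sample $i$ and $\mathbf{x}_i$ the $i$-th row:

```latex
L(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\left(\mathbf{x}_i^{\top}\mathbf{w} - y_i\right)^2,
\qquad
\frac{\partial L}{\partial w_j} = \frac{2}{n}\sum_{i=1}^{n} x_{ij}\left(\mathbf{x}_i^{\top}\mathbf{w} - y_i\right),
\qquad
\frac{\partial^2 L}{\partial w_j^2} = \frac{2}{n}\sum_{i=1}^{n} x_{ij}^2
```

Multiplying column $j$ by $c$ multiplies every term of the gradient sum by $c$ (at $\mathbf{w}=\mathbf{0}$, where the residual is $-y_i$ regardless of scaling) and the curvature by $c^2$. With $c = 100$, the loss is about $10^4$ times more curved along $w_1$ than along $w_2$.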
# Trajectory of (w1, w2) across the iterations
plt.plot(np.array(w_rec)[:,0],np.array(w_rec)[:,1],'-')
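This plot shows the trajectory of $(w_1, w_2)$. From it we can roughly infer that, if we drew the contour lines of the loss, they would be ellipses whose long axis is parallel to the $y$ axis (the $w_2$ axis) and far longer than the short axis. Consequently, wherever the initial $w$ lies off the long axis, the gradient vector is nearly or exactly perpendicular to the long axis. Combined with a relatively large learning rate, each iteration lands at roughly the mirror-image position on the opposite side of the valley, where the next gradient is again nearly perpendicular to the long axis, with almost no component in the $y$ direction. This is why the model focuses on learning the columns with large values.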
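The elliptical contours can be checked numerically. For linear regression the loss is quadratic, so its contours are ellipses whose axes point along the eigenvectors of the Hessian $H = \frac{2}{n}X^{\top}X$, with semi-axis lengths proportional to $1/\sqrt{\lambda}$. A sketch on synthetic data, assuming standard-normal features with the first scaled by 100, MSE loss, and only the $(w_1, w_2)$ slice with the intercept held fixed:

```python
import numpy as np

rng = np.random.default_rng(24)
n = 1000
X = np.hstack([rng.standard_normal((n, 2)), np.ones((n, 1))])
X[:, 0] *= 100                                      # first feature ~100x larger, as in the post

H = 2 * X.T @ X / n                                 # Hessian of the MSE loss (constant in w)
H_slice = H[:2, :2]                                 # (w1, w2) slice, intercept held fixed
eigvals, eigvecs = np.linalg.eigh(H_slice)          # eigenvalues in ascending order
axis_ratio = np.sqrt(eigvals[-1] / eigvals[0])      # long semi-axis / short semi-axis
long_axis = eigvecs[:, 0]                           # direction of the flattest curvature

print(eigvals)
print(axis_ratio)     # roughly 100
print(long_axis)      # almost exactly (0, +/-1), i.e. parallel to the w2 axis
```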
The oscillation comes from the choice of learning rate; with a smaller learning rate it may disappear:
# Same run with the learning rate reduced tenfold to 0.00001
w = np.array([[0.0,0.0,0.0]]).T
w,w_rec = w_cal_rec(features,w,labels,lr_gd,lr=0.00001,itera_times=100)
plt.subplots(1,2,figsize=(10,4))
plt.subplot(121)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,0])
plt.subplot(122)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,1])
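The role of the learning rate follows from the same quadratic picture. Along a Hessian eigen-direction with eigenvalue $\lambda$, one gradient step multiplies the remaining error by $1 - \text{lr}\cdot\lambda$: if this factor is negative, $w$ jumps to the opposite side of the valley on every step (the zig-zag in the trajectory plot), and if its magnitude exceeds 1 the iteration diverges. A minimal sketch, assuming the MSE gradient and synthetic data of the same shape; with the author's `lr_gd` the thresholds may shift by a constant factor:

```python
import numpy as np

rng = np.random.default_rng(24)
n = 1000
X = np.hstack([rng.standard_normal((n, 2)), np.ones((n, 1))])
X[:, 0] *= 100

# Largest curvature of the MSE loss; for this data it lies almost entirely along w1
lam_max = np.linalg.eigvalsh(2 * X.T @ X / n).max()

def error_factor(lr):
    """Per-step multiplier (1 - lr*lambda) of the error along the steepest direction."""
    return 1 - lr * lam_max

def describe(f):
    """Classify how gradient descent behaves for a given per-step error factor."""
    if abs(f) > 1:
        return "diverges"
    if f < 0:
        return "zig-zags across the valley"
    return "approaches the minimum from one side"

for lr in (1e-4, 1e-5):
    f = error_factor(lr)
    print(f"lr={lr:g}: factor={f:.3f} -> {describe(f)}")

print("no zig-zag along w1 if lr <", 1 / lam_max)
```

Any learning rate below $1/\lambda_{\max}$ removes the zig-zag, but the slow direction ($w_2$, with a much smaller eigenvalue) then contracts by only $1 - \text{lr}\cdot\lambda_{\min}$ per step, which is why a badly scaled problem stays slow whichever single learning rate is chosen.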
After normalization, the model learns each feature much more evenly:
# features is assumed to have been normalized at this point (the normalization step is not shown in this post)
w = np.array([[0.0,0.0,0.0]]).T
w,w_rec = w_cal_rec(features,w,labels,lr_gd,lr=0.3,itera_times=500)
plt.subplots(1,2,figsize=(10,4))
plt.subplot(121)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,0])
plt.subplot(122)
plt.plot(range(len(w_rec)),np.array(w_rec)[:,1])
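The post does not show how the features were normalized. One common choice is z-score standardization; the sketch below assumes that, uses a hypothetical helper `standardize`, and leaves the constant intercept column untouched (standardizing it would divide by a zero standard deviation). With every feature on unit scale, the same lr=0.3 and 500 iterations used above bring all coefficients to the least-squares solution.

```python
import numpy as np

def standardize(features):
    """Hypothetical helper: z-score all columns except the last, assumed to be the intercept."""
    out = features.astype(float)                    # astype returns a copy
    mu = out[:, :-1].mean(axis=0)
    sigma = out[:, :-1].std(axis=0)
    out[:, :-1] = (out[:, :-1] - mu) / sigma
    return out, mu, sigma

def mse_grad(X, w, y):
    """Gradient of (1/n)*||Xw - y||^2 with respect to w."""
    return 2 * X.T @ (X @ w - y) / X.shape[0]

rng = np.random.default_rng(24)
n = 1000
X = np.hstack([rng.standard_normal((n, 2)), np.ones((n, 1))])
y = X @ np.array([[1.0], [-1.0], [1.0]]) + 0.1 * rng.standard_normal((n, 1))
X[:, 0] *= 100                                      # reproduce the badly scaled first column

Xs, mu, sigma = standardize(X)
w = np.zeros((3, 1))
for _ in range(500):                                # same lr and iteration count as the post
    w = w - 0.3 * mse_grad(Xs, w, y)

w_ls = np.linalg.lstsq(Xs, y, rcond=None)[0]
print(w.ravel(), w_ls.ravel())                      # gradient descent matches least squares on every column
```

Coefficients learned on standardized features are in standardized units; dividing each by the corresponding `sigma` converts it back to the original scale, with the intercept adjusted for the means.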