ruby – Logistic回归给出不正确的结果

前端之家收集整理的这篇文章主要介绍了ruby – Logistic回归给出不正确的结果前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。
我正在一个网站上工作,收集人们玩过的国际象棋游戏的结果.观察球员的评分以及他们的评分与对手的评分之间的差异,我绘制了一个带有代表胜利(绿色),平局(蓝色)和损失(红色)的点的图表.

有了这些信息,我还实施了逻辑回归算法来对获胜和获胜/抽奖的截止值进行分类.使用评级和差异作为我的两个特征,我得到一个分类器,然后在图表上绘制分类器改变其预测的边界.

我的梯度下降代码,成本函数和sigmoid函数如下.

def gradient_descent()
    oldJ = 0    
    newJ = J()
    alpha = 1.0     # Learning rate
    run = 0
    while (run < 100) do
      tmpTheta = Array.new
      for j in 0...numFeatures do
        sum = 0
        for i in 0...m do
          sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
        end
        tmpTheta[j] = Array.new
        tmpTheta[j][0] = theta[j,0] - (alpha / m) * sum  # Alpha * partial derivative of J with respect to theta_j
      end
      self.theta = Matrix.rows(tmpTheta)
      oldJ = newJ
      newJ = J()
      run += 1
      if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end   # Do 20 more if the error is still going down a fair amount.
      if (oldJ < newJ)
        alpha /= 10
      end
    end
  end

  def J()
    sum = 0
    for i in 0...m
      sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i]))) 
          + ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
    end
    return (-1.0 / m) * sum
  end

  def h(x)
    if (x.class != 'Matrix')    # In case it comes in as a row vector or an array
      x = Matrix.rows([x])      # [x] because if it's a row vector we want [[a,b]] to get an array whose first row is x.
    end
    x = x.transpose   # x is supposed to be a column vector,and theta^ a row vector,so theta^*x is a number.
    return g((theta.transpose * x)[0,0])  # theta^ * x gives [[z]],so get [0,0] of that for the number z.
  end

  def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp == 1.0) then tmp = 0.99999 end    # These two things are here because ln(0) DNE,so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp == 0.0) then tmp = 0.00001 end
    return tmp
  end

当我在代表我自己的国际象棋简档的数据集上测试时,我得到了合理的结果,我可以满意:

有一段时间,我很高兴.我试过的所有例子都给出了有趣的图表.然后我尝试了一个名叫Kevin Cao的球员,他有超过250场比赛的名字,因此有1000场比赛,用于一个非常大的训练集.结果显然不正确:

嗯,这不好.所以我将初始学习率从1.0增加到100.0作为我的第一个想法.这对Kevin来说看起来是正确的结果:

不幸的是,当我在我自己和我的小数据集上尝试它时,我得到了一个奇怪的现象,它只是给出了一个预测值为0的扁平线:

我检查了θ,它说它是[[2.3707682771730836],[21.22408286825226],[ – 19081.906528679192]].第三个训练变量(实际上是第二个,因为x_0 = 1)是等级的差异,所以当差异只是最小的正位时,逻辑回归的公式变为负,并且sigmoid函数预测y = 0.差异只是稍微有点正面,同样,它会向上跳跃并预测y = 1.

我将初始学习率从100.0降低到1.0,并决定尝试更慢地减少它.因此,当成本函数增加时,不是将其减少十倍,而是将其减少了两倍.

不幸的是,这并没有改变我的结果.即使我将梯度下降的循环次数从100增加到1000,它仍然可以预测错误的结果.

我仍然是逻辑回归的初学者(我刚刚在coursera上完成了机器学习课程,这是我第一次尝试实现我在那里学到的任何算法),所以我已经达到了我的直觉程度.如果有人能帮助我弄清楚这里出了什么问题,我做错了什么,以及我如何解决它,我将非常感激.

编辑:我也尝试了另一个数据集,它有大约300个数据点,并再次得到一条平绿线和一条正常的蓝线.两者的算法基本相同,只是y的一些不同结果,因为我正在进行多类分类.

编辑:由于人们已经要求它,这里是J,Alpha和Theta每次迭代下降的线条平坦线:

J: 1.7679949412730092  Alpha: 1.0  Theta: Matrix[[-0.004477611940298508],[0.2835820895522388],[-123.63880597014925]]
J: 0.6873432218114784  Alpha: 0.1  Theta: Matrix[[-0.008057848266678727],[-8.033992854843122],[-118.62571350649955]]
J: 2.7493579020963597  Alpha: 0.1  Theta: Matrix[[0.0035837099422764904],[10.036108977992713],[-114.29679460799208]]
J: 2.5431564907845736  Alpha: 0.01  Theta: Matrix[[0.002061352330336195],[7.255061503962862],[-113.88091708799209]]
J: 2.268221136398013  Alpha: 0.01  Theta: Matrix[[0.0008076454646645536],[4.923257856798684],[-113.43169704202194]]
J: 2.02765281325063  Alpha: 0.01  Theta: Matrix[[-0.00014755931145485107],[3.0843409102315205],[-112.95644762679805]]
J: 1.821451342237053  Alpha: 0.01  Theta: Matrix[[-0.0008639634905593289],[1.6548476959031622],[-112.46627318829059]]
J: 1.8214513720879484  Alpha: 0.01  Theta: Matrix[[-0.0013117163263802246],[0.6758826956046561],[-111.9660989569473]]
J: 1.8214513720879484  Alpha: 0.001  Theta: Matrix[[-0.0013535066248876874],[0.5834935043210742],[-111.91600392423089]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0013952969233951501],[0.49110431303749225],[-111.86590889151448]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014341021771264934],[0.40365238581361185],[-111.81578997843985]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014729074308578367],[0.31620045858973145],[-111.76567106536523]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.0015115010626209136],[0.22904945780472585],[-111.71555130580272]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.001544336226800018],[0.15110191314800955],[-111.66540851236988]]
J: 1.770809597429665  Alpha: 0.001  Theta: Matrix[[-0.0015771713909791226],[0.07315436849129325],[-111.61526571893704]]
J: 1.7297985323807161  Alpha: 0.0001  Theta: Matrix[[-0.00158045491336022],[0.06535960382896211],[-111.61025143962061]]
J: 1.718350722631126  Alpha: 0.0001  Theta: Matrix[[-0.0015837319880072584],[0.05757622586497872],[-111.60523715385645]]
J: 1.7183505768797593  Alpha: 0.0001  Theta: Matrix[[-0.0015867170175074515],[0.05030859963032436],[-111.60022257604714]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015897020324328638],[0.04304099913473299],[-111.59520799822326]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015926870473582369],[0.03577339863921061],[-111.59019342039937]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.00159567206228361],[0.028505798143688237],[-111.58517884257549]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001598657077208983],[0.02123819764816586],[-111.5801642647516]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001601642092134356],[0.013970597152643486],[-111.57514968692772]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001604627107059729],[0.006702996657121109],[-111.57013510910383]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016076121219851022],[-0.0005646038384012671],[-111.56512053127994]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016105971369104752],[-0.007832204333923645],[-111.56010595345606]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016135821518358483],[-0.01509980482944602],[-111.55509137563217]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016165671667612213],[-0.022367405324968396],[-111.55007679780829]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016195521816865944],[-0.02963500582049077],[-111.5450622199844]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016225371966119674],[-0.03690260631601315],[-111.54004764216052]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016255222115373405],[-0.04417020681153553],[-111.53503306433663]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016285072264627136],[-0.05143780730705791],[-111.53001848651274]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016314922443731613],[-0.05870541239661013],[-111.52500390868587]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016344772622834192],[-0.06597301748587016],[-111.519989330859]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016374622664495802],[-0.07324060142296517],[-111.51497475304588]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001640217664533409],[-0.08015482159935092],[-111.50996040483884]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016455906875599943],[-0.0937712290880118],[-111.49993184619791]]
J: 1.994702022407994  Alpha: 0.0001  Theta: Matrix[[-0.0016482771980077554],[-0.10057943119248941],[-111.49491756687851]]
J: 1.9789198631246232  Alpha: 1.0e-05  Theta: Matrix[[-0.0016485458502465615],[-0.10126025363935508],[-111.49441613894419]]
J: 1.948354991984789  Alpha: 1.0e-05  Theta: Matrix[[-0.0016490831547241735],[-0.10262189853308641],[-111.49341328307554]]
J: 1.9331013621188657  Alpha: 1.0e-05  Theta: Matrix[[-0.0016493518069629796],[-0.10330272097995208],[-111.49291185514122]]
J: 1.9178620371528292  Alpha: 1.0e-05  Theta: Matrix[[-0.0016496204592017856],[-0.10398354342681772],[-111.49241042720689]]
J: 1.902623825636303  Alpha: 1.0e-05  Theta: Matrix[[-0.0016498891114405914],[-0.10466436587368326],[-111.49190899927257]]
J: 1.8873858680247269  Alpha: 1.0e-05  Theta: Matrix[[-0.0016501577636793972],[-0.10534518832054848],[-111.49140757133824]]
J: 1.8721478527437034  Alpha: 1.0e-05  Theta: Matrix[[-0.0016504264159182024],[-0.10602601076741257],[-111.49090614340392]]
J: 1.8569098083540256  Alpha: 1.0e-05  Theta: Matrix[[-0.0016506950681570054],[-0.10670683321427255],[-111.4904047154696]]
J: 1.8416717846532462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016509637203958004],[-0.10738765566111781],[-111.48990328753527]]
J: 1.8264337702403803  Alpha: 1.0e-05  Theta: Matrix[[-0.0016512323726345674],[-0.10806847810791036],[-111.48940185960095]]
J: 1.8111957469624462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016515010251717409],[-0.1087493010703349],[-111.48890043166602]]
J: 1.7959577228777213  Alpha: 1.0e-05  Theta: Matrix[[-0.001651769677708553],[-0.10943012403208266],[-111.4883990037311]]
J: 1.7807196990939538  Alpha: 1.0e-05  Theta: Matrix[[-0.0016520383302440706],[-0.11011094699140556],[-111.48789757579618]]
J: 1.7654816767669712  Alpha: 1.0e-05  Theta: Matrix[[-0.0016523069827749494],[-0.11079176994204029],[-111.48739614786128]]
J: 1.7197677244765115  Alpha: 1.0e-05  Theta: Matrix[[-0.0016531129399852717],[-0.11283423807786983],[-111.4858918640573]]
J: 1.7045300185036796  Alpha: 1.0e-05  Theta: Matrix[[-0.0016533815914621833],[-0.11351505905442376],[-111.48539043612449]]
J: 1.689293134633683  Alpha: 1.0e-05  Theta: Matrix[[-0.0016536502402002386],[-0.11419587490110002],[-111.48488900819716]]
J: 1.674059195452273  Alpha: 1.0e-05  Theta: Matrix[[-0.001653918879126327],[-0.1148766723699622],[-111.48438758028945]]
J: 1.6588357959146847  Alpha: 1.0e-05  Theta: Matrix[[-0.0016541874829120791],[-0.11555740402097447],[-111.48388615245203]]
J: 1.6436500186219352  Alpha: 1.0e-05  Theta: Matrix[[-0.0016544559609891405],[-0.1162379002196091],[-111.48338472486603]]
J: 1.6285972611659707  Alpha: 1.0e-05  Theta: Matrix[[-0.001654723991174496],[-0.11691755751707966],[-111.4828832981758]]
J: 1.6139994752963014  Alpha: 1.0e-05  Theta: Matrix[[-0.0016549904481917704],[-0.11759426827073645],[-111.48238187463193]]
J: 1.600799606845299  Alpha: 1.0e-05  Theta: Matrix[[-0.0016552516449943116],[-0.11826112664220582],[-111.48188046160847]]
J: 1.5908244528084288  Alpha: 1.0e-05  Theta: Matrix[[-0.0016554977759847996],[-0.1188997667477244],[-111.48137907871664]]
J: 1.5851960976828814  Alpha: 1.0e-05  Theta: Matrix[[-0.0016557144987826046],[-0.11948332530842007],[-111.4808777546412]]
J: 1.5826817076400923  Alpha: 1.0e-05  Theta: Matrix[[-0.0016558999497352893],[-0.12000831170339445],[-111.48037649310945]]
J: 1.5816354848004566  Alpha: 1.0e-05  Theta: Matrix[[-0.0016560658987327093],[-0.12049677093659837],[-111.4798752705816]]
J: 1.581199878569286  Alpha: 1.0e-05  Theta: Matrix[[-0.0016562224426970157],[-0.12096761454376066],[-111.47937406686383]]
J: 1.5810169018926878  Alpha: 1.0e-05  Theta: Matrix[[-0.0016563748211790893],[-0.12143065620486218],[-111.47887287147701]]
J: 1.5809396242131868  Alpha: 1.0e-05  Theta: Matrix[[-0.0016565254040880424],[-0.1218903347622732],[-111.47837167968135]]
J: 1.5809069017613124  Alpha: 1.0e-05  Theta: Matrix[[-0.0016566752202995195],[-0.12234857730581448],[-111.47787048941908]]
J: 1.5808930296490606  Alpha: 1.0e-05  Theta: Matrix[[-0.001656824710233385],[-0.12280620875454971],[-111.47736929980935]]
J: 1.580887145848097  Alpha: 1.0e-05  Theta: Matrix[[-0.0016569740612930289],[-0.12326358014294572],[-111.47686811047738]]
J: 1.580884649719601  Alpha: 1.0e-05  Theta: Matrix[[-0.0016571233527736234],[-0.12372084005243131],[-111.47636692126457]]
J: 1.5808835906710963  Alpha: 1.0e-05  Theta: Matrix[[-0.0016572726175860411],[-0.12417805026085695],[-111.47586573210509]]
J: 1.5808831413239819  Alpha: 1.0e-05  Theta: Matrix[[-0.00165742186803091],[-0.12463523410670607],[-111.47536454297435]]
.........

对于创建正确预测的人:

J: 4.330234652497978  Alpha: 1.0  Theta: Matrix[[0.12388059701492538],[211.9910447761194],[-111.13731343283582]]
J: 4.330234652497978  Alpha: 0.1  Theta: Matrix[[0.08626965671641812],[152.3222144059701],[-118.07202388059702]]
J: 4.2958677406623815  Alpha: 0.1  Theta: Matrix[[0.048658716417910856],[92.65338403582082],[-125.0067343283582]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.011644779104478219],[33.61767533134318],[-131.44443979104477]]
J: 0.4467735852246924  Alpha: 0.1  Theta: Matrix[[-0.014623104477611202],[-11.126378913433022],[-132.24166105074627]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.01194378805970217],[31.177094038805805],[-126.89243925671643]]
J: 3.0930257965656063  Alpha: 0.01  Theta: Matrix[[0.009436400895523079],[26.892626149850567],[-126.92472924]]
J: 2.7493567080605392  Alpha: 0.01  Theta: Matrix[[0.007257365074627634],[23.13644550388053],[-126.8386038647761]]
J: 2.508788325211366  Alpha: 0.01  Theta: Matrix[[0.005466380895523164],[19.99261048238799],[-126.62851089164178]]
J: 2.405687589704577  Alpha: 0.01  Theta: Matrix[[0.004152999104478391],[17.61296913194023],[-126.28907722179103]]
J: 2.268219942362192  Alpha: 0.01  Theta: Matrix[[0.002959017910448543],[15.415473392238736],[-125.92224111492536]]
J: 2.1307522353180164  Alpha: 0.01  Theta: Matrix[[0.002093389253732125],[13.751072827761122],[-125.48597339134326]]
J: 2.027651529662123  Alpha: 0.01  Theta: Matrix[[0.0014367116417918252],[12.436814710149182],[-125.00961691402983]]
J: 1.9589177059909308  Alpha: 0.01  Theta: Matrix[[0.0009889847761201823],[11.44908667850739],[-124.49911195194028]]
J: 1.8558169406332465  Alpha: 0.01  Theta: Matrix[[0.0006606582089560022],[10.652638055522315],[-123.97004023522386]]
J: 1.8214500586485458  Alpha: 0.01  Theta: Matrix[[0.0004218823880604789],[9.988664770447688],[-123.42914782925371]]
J: 1.8214500884994413  Alpha: 0.01  Theta: Matrix[[0.0002428068653197179],[9.416182220312082],[-122.88082274064425]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00023086931308091184],[9.369775500013574],[-122.82513353589798]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00021893176084210577],[9.323368779715066],[-122.7694443311517]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0002069942086032997],[9.276962059416558],[-122.71375512640543]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00019505665636449364],[9.23055533911805],[-122.65806592165916]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00018311910412568757],[9.184148618819542],[-122.60237671691289]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0001711815518868815],[9.137741898521034],[-122.54668751216661]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00015924399964807544],[9.091335178222526],[-122.49099830742034]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00014730641755852312],[9.04492840598372],[-122.43530910670393]]
J: 1.8677695240029366  Alpha: 0.001  Theta: Matrix[[0.0001353688354689708],[8.998521633744915],[-122.37961990598751]]
J: 1.8462563443835032  Alpha: 0.0001  Theta: Matrix[[0.0001341750742749415],[8.993880951437452],[-122.374050986289]]
J: 1.8247430163841476  Alpha: 0.0001  Theta: Matrix[[0.00013298131308164604],[8.98924026913124],[-122.3684820665904]]
J: 1.803243007740144  Alpha: 0.0001  Theta: Matrix[[0.0001317875528781551],[8.984599588510665],[-122.36291314676808]]
J: 1.7875423426167685  Alpha: 0.0001  Theta: Matrix[[0.00013059512176735966],[8.979961171334951],[-122.35734406080917]]
J: 1.7870839229503594  Alpha: 0.0001  Theta: Matrix[[0.0001296573060241053],[8.97575636413016],[-122.35174314792931]]
J: 1.7870831481868632  Alpha: 0.0001  Theta: Matrix[[0.00012876197468911015],[8.971623907872633],[-122.34613692449842]]
J: 1.7870831468153818  Alpha: 0.0001  Theta: Matrix[[0.00012786672082037553],[8.967491583540149],[-122.34053069138426]]
J: 1.7870831468129538  Alpha: 0.0001  Theta: Matrix[[0.000126971467088789],[8.963359259441226],[-122.33492445825294]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001260762133574453],[8.959226935342718],[-122.3293182251216]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012518095962610202],[8.95509461124421],[-122.32371199199025]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012428570589475874],[8.950962287145702],[-122.3181057588589]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012339045216341546],[8.946829963047193],[-122.31249952572756]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012249519843207218],[8.942697638948685],[-122.30689329259621]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012159994470072888],[8.938565314850177],[-122.30128705946487]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012070469096938559],[8.934432990751668],[-122.29568082633352]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001198094372380423],[8.93030066665316],[-122.29007459320218]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.000118914183506699],[8.926168342554652],[-122.28446836007083]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00011801892977535571],[8.922036018456144],[-122.27886212693949]]
......

编辑:我注意到假设的第一次迭代总是预测0.5,因为theta全是0.但之后它总是预测1或0(0.00001或0.99999以避免我的代码中不存在的对数).这对我来说似乎不对 – 方式太自信了 – 而且可能是为什么这不起作用的关键.

解决方法

有一些关于您的实现的东西有点不标准.

>首先,逻辑回归目标通常作为最小化问题给出

lr(x [n],y [n])= log(1 exp(-y [n] * dot(w [n],x [n])))
其中y [n]为1或-1

你似乎正在使用等效的最大化问题公式

lr(x [n],y [n])= – y [n] * log(1 exp(-dot(w [n],x [n])))(1-y [n])*( – dot(w [n],x [n]) – log(1 exp(-dot(w [n],x [n])))

其中y [n]为0或1(该配方中的y [n] = 0等于第一配方中的y [n] = 1).

因此,您应确保在数据集中,标签为0或1而不是1或-1.
>接下来,LR目标通常不除以m(数据集的大小).将逻辑回归视为概率模型时,此缩放因子不正确.
>最后,您的实现可能存在一些数字问题(您尝试在g函数中进行更正). Leon Bottou的sgd代码(http://leon.bottou.org/projects/sgd)对损失函数和导数进行了一些更稳定的计算,如下所示(在C代码中 – 他使用我提到的第一个LR公式):

/* logloss(a,y) = log(1+exp(-a*y)) */
double loss(double a,double y)
{
  double z = a * y;
  if (z > 18) {
    return exp(-z);
  }
  if (z < -18) {
    return -z;
  }
  return log(1 + exp(-z));
}

/*  -dloss(a,y)/da */
double dloss(double a,double y)
{
  double z = a * y;
  if (z > 18) {
    return y * exp(-z);
  }
  if (z < -18){
    return y;
  }
  return y / (1 + exp(z));
}

您还应该考虑运行stock l-bfgs例程(我不熟悉Ruby实现),这样您就可以专注于使目标和梯度计算正确,而不必担心学习率等问题.

猜你在找的Ruby相关文章