不依赖Python第三方库实现梯度下降

前端之家收集整理的这篇文章主要介绍了不依赖Python第三方库实现梯度下降前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

认识

梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模),我感觉,其实就是偏导数向量方向呗,沿着这个向量方向可以找到局部的极值.

from random import random

def gradient_down(func,part_df_func,var_num,rate=0.1,max_iter=10000,tolerance=1e-10):
    """
    不依赖第三库实现梯度下降
    :param func: 损失(误差)函数
    :param part_df_func: 损失函数的偏导数向量
    :param var_num: 变量个数
    :param rate: 学习率(参数的每次变化的幅度)
    :param max_iter: 最大计算次数
    :param tolerance: 误差的精度
    :return: theta,y_current:  权重参数值列表,损失函数最小值
    """

    theta = [random() for _ in range(var_num)]  # 随机给定参数的初始值
    y_current = func(*theta)  # 参数解包

    for i in range(max_iter):
        # 计算当前参数的梯度(偏导数导数向量值)
        gradient = [f(*theta) for f in part_df_func]
        # 根据梯度更新参数 theta
        for j in range(var_num):

            theta[j] -= gradient[j] * rate  # [0.3,0.6,0.7] ==> [0.3-0.3*lr,0.6-0.6*lr,0.7-0.7*lr]

            y_current,y_predict = func(*theta),y_current
            print(f"正在进行第{i}次迭代,误差精度为{abs(y_predict - y_current)}")

            if abs(y_predict - y_current) < tolerance:   # 判断是否收敛,(误差值的精度)

                print(); print(f"ok,在第{i}次迭代,收敛到可以了哦!")

                return theta,y_current


def f(x,y):
    """原函数"""
    return (x + y - 3) ** 2 + (x + 2 * y - 5) ** 2 + 2


def df_dx(x,y):
    """对x求偏导数"""
    return 2 * (x + y - 3) + 2 * (x + 2 * y - 5)


def df_dy(x,y):
    """对y求偏导数,注意求导的链式法则哦"""
    return 2 * (x + y - 3) + 2 * (x + 2 * y - 5) * 2


def main():
    """主函数"""
    print("用梯度下降的方式求解函数的最小值哦:")
    theta,f_theta = gradient_down(f,[df_dx,df_dy],var_num=2)

    theta,f_theta = [round(i,3) for i in theta],round(f_theta,3)  # 保留3位小数

    print("该函数最优解是: 当theta取:{}时,f(theta)取到最小值:{}".format(theta,f_theta))


if __name__ == '__main__':
    main()
...
...
正在进行第248次迭代,误差精度为1.6640999689343516e-10
正在进行第249次迭代,误差精度为1.5684031851037616e-10
正在进行第250次迭代,误差精度为1.478208666583214e-10
正在进行第251次迭代,误差精度为1.3931966691416164e-10
正在进行第252次迭代,误差精度为1.3130829756846651e-10
正在进行第253次迭代,误差精度为1.2375700464417605e-10
正在进行第254次迭代,误差精度为1.166395868779091e-10
正在进行第255次迭代,误差精度为1.0993206345233375e-10
正在进行第256次迭代,误差精度为1.0361000946090826e-10
正在进行第257次迭代,误差精度为9.765166453234997e-11

ok,在第257次迭代,收敛到可以了哦!
该函数最优解是: 当theta取:[1.0,2.0]时,f(theta)取到最小值:2.0
[Finished in 0.0s]

猜你在找的设计模式相关文章