博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
MXNET:权重衰减-gluon实现
阅读量:5930 次
发布时间:2019-06-19

本文共 2098 字,大约阅读时间需要 6 分钟。

构建数据集

# -*- coding: utf-8 -*-from mxnet import initfrom mxnet import ndarray as ndfrom mxnet.gluon import loss as glossimport gbn_train = 20n_test = 100num_inputs = 200true_w = nd.ones((num_inputs, 1)) * 0.01true_b = 0.05features = nd.random.normal(shape=(n_train+n_test, num_inputs))labels = nd.dot(features, true_w) + true_blabels += nd.random.normal(scale=0.01, shape=labels.shape)train_features, test_features = features[:n_train, :], features[n_train:, :]train_labels, test_labels = labels[:n_train], labels[n_train:]

数据迭代器

from mxnet import autogradfrom mxnet.gluon import data as gdatabatch_size = 1num_epochs = 10learning_rate = 0.003train_iter = gdata.DataLoader(gdata.ArrayDataset(    train_features, train_labels), batch_size, shuffle=True)loss = gloss.L2Loss()

训练并展示结果

gb.semilogy函数:绘制训练和测试数据的loss

from mxnet import gluonfrom mxnet.gluon import nndef fit_and_plot(weight_decay):    net = nn.Sequential()    net.add(nn.Dense(1))    net.initialize(init.Normal(sigma=1))    # 对权重参数做 L2 范数正则化,即权重衰减。    trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {        'learning_rate': learning_rate, 'wd': weight_decay})    # 不对偏差参数做 L2 范数正则化。    trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {        'learning_rate': learning_rate})    train_ls = []    test_ls = []    for _ in range(num_epochs):        for X, y in train_iter:            with autograd.record():                l = loss(net(X), y)            l.backward()            # 对两个 Trainer 实例分别调用 step 函数。            trainer_w.step(batch_size)            trainer_b.step(batch_size)        train_ls.append(loss(net(train_features),                             train_labels).mean().asscalar())        test_ls.append(loss(net(test_features),                            test_labels).mean().asscalar())    gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',                range(1, num_epochs + 1), test_ls, ['train', 'test'])    return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()print fit_and_plot(5)
  • 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
  • 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。

转载地址:http://xuutx.baihongyu.com/

你可能感兴趣的文章
Linux备份ifcfg-eth0文件导致的网络故障问题
查看>>
2018年尾总结——稳中成长
查看>>
行列式的乘法定理
查看>>
JFreeChart开发_用JFreeChart增强JSP报表的用户体验
查看>>
度量时间差
查看>>
MySQL 5.6为什么关闭元数据统计信息自动更新&统计信息收集源代码探索
查看>>
apache prefork模式优化错误
查看>>
jmeter高级用法例子,如何扩展自定义函数
查看>>
通过jsp请求Servlet来操作HBASE
查看>>
JS页面刷新保持数据不丢失
查看>>
清橙A1202&Bzoj2201:彩色圆环
查看>>
使用data pump工具的准备
查看>>
springMVC---级联属性
查看>>
get和post区别
查看>>
crontab执行shell脚本日志中出现乱码
查看>>
Floodlight 在 ChannelPipeline 图
查看>>
做移动互联网App,你的测试用例足够吗?
查看>>
cmd.exe启动参数说明
查看>>
《随笔记录》20170310
查看>>
网站分析系统
查看>>