长期盘踞热榜,微软官方AutoML库教你三步学会20+炼金基本功
机器之心报道
有了 AutoML,特征工程、神经架构和超参搜索这些炼金基本功再也不用担心了。作为科技巨头,微软也在 AutoML 上开源了自己的 NNI 库,这个库在 GitHub 上非常流行,长期盘踞在每日项目 Trending 榜。
易于使用:NNI 可通过 pip 安装,只需要在代码中添加几行,就可以利用 NNI 来调优超参数与模型架构。
可扩展:调优超参或网络结构通常需要大量的计算资源。NNI 在设计时就支持了多种不同的计算资源,如远程服务器组、OpenPAI 和 Kubernetes 等训练平台。
灵活:除了内置的算法,NNI 中还可以轻松集成自定义的超参调优算法、神经网络架构搜索算法、提前终止算法等等。还可以将 NNI 连接到更多的训练平台上,如云计算虚拟机集群、Kubernetes 服务等等。
高效:NNI 在系统及算法级别上不停地优化,例如可通过 Trial 早期的反馈来加速调优过程。
def run_trial(params):
# 输入数据
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
# 构建网络
mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'], channel_2_num=params['channel_2_num'], conv_size=params['conv_size'], hidden_size=params['hidden_size'], pool_size=params['pool_size'], learning_rate=params['learning_rate'])
mnist_network.build_network()
test_acc = 0.0
with tf.Session() as sess:
# 训练网络
mnist_network.train(sess, mnist)
# 评估网络
test_acc = mnist_network.evaluate(mnist)
if __name__ == '__main__':
params = {'data_dir': '/tmp/tensorflow/mnist/input_data', 'dropout_rate': 0.5, 'channel_1_num': 32, 'channel_2_num': 64, 'conv_size': 5, 'pool_size': 2, 'hidden_size': 1024, 'learning_rate': 1e-4, 'batch_num': 2000, 'batch_size': 32}
run_trial(params)
- params = {'data_dir': '/tmp/tensorflow/mnist/input_data', 'dropout_rate': 0.5, 'channel_1_num': 32, 'channel_2_num': 64,
- 'conv_size': 5, 'pool_size': 2, 'hidden_size': 1024, 'learning_rate': 1e-4, 'batch_num': 2000, 'batch_size': 32}
+ {
+ "dropout_rate":{"_type":"uniform","_value":[0.5, 0.9]},
+ "conv_size":{"_type":"choice","_value":[2,3,5,7]},
+ "hidden_size":{"_type":"choice","_value":[124, 512, 1024]},
+ "batch_size": {"_type":"choice", "_value": [1, 4, 8, 16, 32]},
+ "learning_rate":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]}
+ }
+ import nni
def run_trial(params):
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'], channel_2_num=params['channel_2_num'], conv_size=params['conv_size'], hidden_size=params['hidden_size'], pool_size=params['pool_size'], learning_rate=params['learning_rate'])
mnist_network.build_network()
with tf.Session() as sess:
mnist_network.train(sess, mnist)
test_acc = mnist_network.evaluate(mnist)
+ nni.report_final_result(test_acc)
if __name__ == '__main__':
- params = {'data_dir': '/tmp/tensorflow/mnist/input_data', 'dropout_rate': 0.5, 'channel_1_num': 32, 'channel_2_num': 64,
- 'conv_size': 5, 'pool_size': 2, 'hidden_size': 1024, 'learning_rate': 1e-4, 'batch_num': 2000, 'batch_size': 32}
+ params = nni.get_next_parameter()
run_trial(params)
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: local
# 搜索空间文件
searchSpacePath: search_space.json
useAnnotation: false
tuner:
builtinTunerName: TPE
# 运行的命令,以及 Trial 代码的路径
trial:
command: python3 mnist.py
codeDir: .
gpuNum: 0
# 进入ENAS的代码目录
cd examples/nas/enas
# 在 Macro 搜索空间中搜索
python3 search.py --search-for macro
# 在 Micro 搜索空间中搜索
python3 search.py --search-for micro
# 查看更多选项
python3 search.py -h
关注公众号:拾黑(shiheibook)了解更多
[广告]赞助链接:
四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/

随时掌握互联网精彩
赞助链接
排名
热点
搜索指数
- 1 中拉团结合作再启新程 7904718
- 2 国台办回应特朗普突然提到“统一” 7809425
- 3 外交部:中方对美芬太尼反制仍然有效 7713544
- 4 持续增加投资的信心从何而来 7616183
- 5 林志炫:不是直播我不会回《歌手》 7522368
- 6 印巴谁赢了?发布会这7秒说明了很多 7426761
- 7 #赵丽颖和赵德胤恋情是真的吗# 7331792
- 8 多名在英国中国公民失踪失联 7237643
- 9 12岁女孩被动欠款百万成老赖 7137520
- 10 知名女演员重病归来 曾四登春晚 7039604