机器学习模型五花八门不知道怎么选?这份指南告诉你
【导读】在本文中,我们将探讨不同的机器学习模型,以及每个模型合理的使用场景。 一般来说,基于树形结构的模型在Kaggle竞赛中是表现最好的,而其它的模型可以用于融合模型。对于计算机视觉领域的挑战,CNNs (Convolutional Neural Network, 卷积神经网络)是最适合不过的。而对于NLP(Natural Language Processing,自然语言处理),LSTMs或GRUs是最好的选择。下面是一个不完全模型细目清单,同时列出了每个模型的一些优缺点。
1. 回归 — 预测连续值
A. 线性回归(Linear Regression)
I.Vanilla Linear Regression
优点
缺点
II. Lasso回归, Ridge回归, Elastic-Net回归
优点
缺点
B. 回归树(Regression Trees)
I.决策树(Decision Tree)
优点
缺点
II.融合模型(RandomForest,XGBoost, CatBoost, LightGBM)
优点
缺点
C. 深度学习(Deep Learning)
优点
缺点
D. 基于距离的K近邻算法(K Nearest Neighbors – Distance Based)
优点
缺点
2. 分类 — 预测一个或多个类别的概率
A. 逻辑回归算法(Logistic Regression)
优点
缺点
优点
缺点
优点
缺点
优点
缺点
E. 分类树(Classification Tree)
I. 决策树(Decision Tree)
优点
缺点
可能会出现过度拟合(见下面的融合模型)
优点
缺点
F. 深度学习(Deep Learning)
优点
缺点
3. 聚类 — 将数据分类以便最大化相似性
A. DBSCAN聚类算法(Density-Based Spatial Clustering of Applications with Noise)
优点
缺点
B. Kmeans算法
优点
缺点
4. Misc — 本文中未包含的模型
降维算法(Dimensionality Reduction Algorithms); 聚类算法(Clustering algorithms);
计算机视觉(CV);
自然语言处理(Natural Language Processing,NLP)
强化学习(Reinforcement Learning)
融合模型
# in order to make the final predictions more robust to overfitting
def blended_predictions(X):
return ((0.1 * ridge_model_full_data.predict(X)) + \
(0.2 * svr_model_full_data.predict(X)) + \
(0.1 * gbr_model_full_data.predict(X)) + \
(0.1 * xgb_model_full_data.predict(X)) + \
(0.1 * lgb_model_full_data.predict(X)) + \
(0.05 * rf_model_full_data.predict(X)) + \
(0.35 * stack_gen_model.predict(np.array(X))))
Bagging:使用随机选择的不同数据子集训练多个基础模型,并进行替换。让基础模型对最终的预测进行投票。常用于随机森林算法(RandomForests); Boosting:迭代地训练模型,并且在每次迭代之后更新获得每个训练示例的重要程度。常用于梯度增强算法(GradientBoosting); Blending:训练许多不同类型的基础模型,并在一个holdout set上进行预测。从它们的预测结果中再训练一个新的模型,并在测试集上进行预测(用一个holdout set堆叠); Stacking:训练多种不同类型的基础模型,并对数据集的k-folds进行预测。从它们的预测结果中再训练一个新的模型,并在测试集上进行预测;
模型对比
# WandB
import wandb
import tensorflow.keras
from wandb.keras import WandbCallback
from sklearn.model_selection import cross_val_score
# Import models (Step 1: add your models here)
from sklearn import svm
from sklearn.linear_model import Ridge, RidgeCV
from xgboost import XGBRegressor
# Model 1
# Initialize wandb run
# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.html
'allow', project="pick-a-model") =
# Initialize model (Step 2: add your classifier here)
clf = svm.SVR(C= 20, epsilon= 0.008, gamma=0.0003)
# Get CV scores
cv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scores
for cv_score in cv_scores:
cv_score}) :
# Model 2
# Initialize wandb run
# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.html
'allow', project="pick-a-model") =
# Initialize model (Step 2: add your classifier here)
clf = XGBRegressor(learning_rate=0.01,
n_estimators=6000,
max_depth=4,
min_child_weight=0,
gamma=0.6,
subsample=0.7,
colsample_bytree=0.7,
objective='reg:linear',
nthread=-1,
scale_pos_weight=1,
seed=27,
reg_alpha=0.00006,
random_state=42)
# Get CV scores
cv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scores
for cv_score in cv_scores:
cv_score}) :
# Model 3
# Initialize wandb run
# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.html
'allow', project="pick-a-model") =
# Initialize model (Step 2: add your classifier here)
ridge_alphas = [1e-15, 1e-10, 1e-8, 9e-4, 7e-4, 5e-4, 3e-4, 1e-4, 1e-3, 5e-2, 1e-2, 0.1, 0.3, 1, 3, 5, 10, 15, 18, 20, 30, 50, 75, 100]
clf = Ridge(alphas=ridge_alphas)
# Get CV scores
cv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scores
for cv_score in cv_scores:
cv_score}) :
◆
精彩推荐
◆
点击阅读原文,或扫描文首贴片二维码
所有CSDN 用户都可参与投票和抽奖活动
加入福利群,每周还有精选学习资料、技术图书等福利发送
GitHub标星1.5w+,从此我只用这款全能高速下载工具
中国工程师在美遭抢劫电脑遇害,数百人悼念
跟风 Google 只是东施效颦?!
召回→排序→重排:技术演进趋势的深度之旅,2020 必备!
如何写出让同事膜拜的漂亮代码?
同样是写代码,你和大神究竟差在哪里?
互联网公司=21世纪的国营大厂
详解CPU几个重点基础知识
DeFi行业2019全年呈爆炸式增长,8.5亿美元资产锁定在DeFi生态中;行业市值主要由头部项目瓜分 | 报告
你点的每个“在看”,我都认真当成了AI
关注公众号:拾黑(shiheibook)了解更多
[广告]赞助链接:
四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/
关注网络尖刀微信公众号
随时掌握互联网精彩
随时掌握互联网精彩
赞助链接
排名
热点
搜索指数
- 1 澳门是伟大祖国的一方宝地 7938471
- 2 女法官遇害案凶手被判死刑 7931081
- 3 日本火山喷发灰柱高达3400米 7869509
- 4 中国为全球经济增长添动能 7777234
- 5 肖战新片射雕英雄传郭靖造型曝光 7615970
- 6 大三女生练咏春一起手眼神骤变 7506103
- 7 #马斯克对特朗普政府影响有多大# 7441368
- 8 男子钓上一条自带“赎金”的鱼 7378118
- 9 赵丽颖带儿子探班 7258790
- 10 女子穿和服在南京景区拍照遭怒怼 7134733