A Hands-On PySpark Machine Learning Project: Regression
Author | hecongqing
Source | AI算法之心 (ID: AIHeartForYou)
[Introduction] PySpark is a tool widely used in industry for big-data processing and distributed computing, and it plays an especially large role in model building. So how do you build a model with PySpark? This article walks you through getting started with PySpark and gives you an early taste of the industrial modeling workflow!
```python
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("test") \
    .config("spark.some.config.option", "setting") \
    .getOrCreate()

train = spark.read.csv('./BlackFriday/train.csv', header=True, inferSchema=True)
test = spark.read.csv('./BlackFriday/test.csv', header=True, inferSchema=True)
```
```python
train.printSchema()
"""
root
 |-- User_ID: integer (nullable = true)
 |-- Product_ID: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Occupation: integer (nullable = true)
 |-- City_Category: string (nullable = true)
 |-- Stay_In_Current_City_Years: string (nullable = true)
 |-- Marital_Status: integer (nullable = true)
 |-- Product_Category_1: integer (nullable = true)
 |-- Product_Category_2: integer (nullable = true)
 |-- Product_Category_3: integer (nullable = true)
 |-- Purchase: integer (nullable = true)
"""
```
```python
train.head(5)
"""
[Row(User_ID=1000001, Product_ID='P00069042', Gender='F', Age='0-17', Occupation=10, City_Category='A', Stay_In_Current_City_Years='2', Marital_Status=0, Product_Category_1=3, Product_Category_2=None, Product_Category_3=None, Purchase=8370),
 Row(User_ID=1000001, Product_ID='P00248942', Gender='F', Age='0-17', Occupation=10, City_Category='A', Stay_In_Current_City_Years='2', Marital_Status=0, Product_Category_1=1, Product_Category_2=6, Product_Category_3=14, Purchase=15200),
 Row(User_ID=1000001, Product_ID='P00087842', Gender='F', Age='0-17', Occupation=10, City_Category='A', Stay_In_Current_City_Years='2', Marital_Status=0, Product_Category_1=12, Product_Category_2=None, Product_Category_3=None, Purchase=1422),
 Row(User_ID=1000001, Product_ID='P00085442', Gender='F', Age='0-17', Occupation=10, City_Category='A', Stay_In_Current_City_Years='2', Marital_Status=0, Product_Category_1=12, Product_Category_2=14, Product_Category_3=None, Purchase=1057),
 Row(User_ID=1000002, Product_ID='P00285442', Gender='M', Age='55+', Occupation=16, City_Category='C', Stay_In_Current_City_Years='4+', Marital_Status=0, Product_Category_1=8, Product_Category_2=None, Product_Category_3=None, Purchase=7969)]
"""
```
```python
train.na.drop('any').count(), test.na.drop('any').count()
"""
(166821, 71037)
"""
```
```python
train = train.fillna(-1)
test = test.fillna(-1)
```
```python
train.describe().show()
"""
+-------+------------------+----------+------+------+------------------+-------------+--------------------------+-------------------+------------------+------------------+------------------+-----------------+
|summary|           User_ID|Product_ID|Gender|   Age|        Occupation|City_Category|Stay_In_Current_City_Years|     Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|         Purchase|
+-------+------------------+----------+------+------+------------------+-------------+--------------------------+-------------------+------------------+------------------+------------------+-----------------+
|  count|            550068|    550068|550068|550068|            550068|       550068|                    550068|             550068|            550068|            550068|            550068|           550068|
|   mean|1003028.8424013031|      null|  null|  null| 8.076706879876669|         null|         1.468494139793958|0.40965298835780306| 5.404270017525106| 6.419769919355425| 3.145214773446192|9263.968712959126|
| stddev| 1727.591585530871|      null|  null|  null|6.5226604873418115|         null|         0.989086680757309| 0.4917701263173259| 3.936211369201324| 6.565109781181374| 6.681038828257864|5023.065393820593|
|    min|           1000001| P00000142|     F|  0-17|                 0|            A|                         0|                  0|                 1|                -1|                -1|               12|
|    max|           1006040|  P0099942|     M|   55+|                20|            C|                        4+|                  1|                20|                18|                18|            23961|
+-------+------------------+----------+------+------+------------------+-------------+--------------------------+-------------------+------------------+------------------+------------------+-----------------+
"""
```
```python
train.select('User_ID', 'Age').show(5)
"""
+-------+----+
|User_ID| Age|
+-------+----+
|1000001|0-17|
|1000001|0-17|
|1000001|0-17|
|1000001|0-17|
|1000002| 55+|
+-------+----+
only showing top 5 rows
"""
```
6. Analyzing the categorical features
```python
train.select('Product_ID').distinct().count(), test.select('Product_ID').distinct().count()
"""
(3631, 3491)
"""
```
```python
diff_cat_in_train_test = test.select('Product_ID').subtract(train.select('Product_ID'))
diff_cat_in_train_test.distinct().count()
"""
46
"""
diff_cat_in_train_test.distinct().show(5)
"""
+----------+
|Product_ID|
+----------+
| P00322642|
| P00300142|
| P00077642|
| P00249942|
| P00294942|
+----------+
only showing top 5 rows
"""
```
7. Converting categorical variables into labels
```python
from pyspark.ml.feature import StringIndexer

plan_indexer = StringIndexer(inputCol='Product_ID', outputCol='product_id_trans')
labeller = plan_indexer.fit(train)
```
```python
Train1 = labeller.transform(train)
Test1 = labeller.transform(test)
Train1.show(2)
"""
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+
|User_ID|Product_ID|Gender| Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|product_id_trans|
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+
|1000001| P00069042|     F|0-17|        10|            A|                         2|             0|                 3|                -1|                -1|    8370|           766.0|
|1000001| P00248942|     F|0-17|        10|            A|                         2|             0|                 1|                 6|                14|   15200|           183.0|
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+
only showing top 2 rows
"""
Train1.select('product_id_trans').show(2)
"""
+----------------+
|product_id_trans|
+----------------+
|           766.0|
|           183.0|
+----------------+
only showing top 2 rows
"""
```
```python
from pyspark.ml.feature import RFormula

formula = RFormula(
    formula="Purchase ~ Age + Occupation + City_Category + Stay_In_Current_City_Years + Product_Category_1 + Product_Category_2 + Gender",
    featuresCol="features",
    labelCol="label")
```
```python
t1 = formula.fit(Train1)
train1 = t1.transform(Train1)
test1 = t1.transform(Test1)
train1.show(2)
"""
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+--------------------+-------+
|User_ID|Product_ID|Gender| Age|Occupation|City_Category|Stay_In_Current_City_Years|Marital_Status|Product_Category_1|Product_Category_2|Product_Category_3|Purchase|product_id_trans|            features|  label|
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+--------------------+-------+
|1000001| P00069042|     F|0-17|        10|            A|                         2|             0|                 3|                -1|                -1|    8370|           766.0|(16,[6,10,13,14],...| 8370.0|
|1000001| P00248942|     F|0-17|        10|            A|                         2|             0|                 1|                 6|                14|   15200|           183.0|(16,[6,10,13,14],...|15200.0|
+-------+----------+------+----+----------+-------------+--------------------------+--------------+------------------+------------------+------------------+--------+----------------+--------------------+-------+
only showing top 2 rows
"""
```
```python
train1.select('features').show(2)
"""
+--------------------+
|            features|
+--------------------+
|(16,[6,10,13,14],...|
|(16,[6,10,13,14],...|
+--------------------+
only showing top 2 rows
"""
train1.select('label').show(2)
"""
+-------+
|  label|
+-------+
| 8370.0|
|15200.0|
+-------+
only showing top 2 rows
"""
```
```python
from pyspark.ml.regression import RandomForestRegressor

rf = RandomForestRegressor()

(train_cv, test_cv) = train1.randomSplit([0.7, 0.3])

model1 = rf.fit(train_cv)
predictions = model1.transform(test_cv)
```
```python
from pyspark.ml.evaluation import RegressionEvaluator
import numpy as np

evaluator = RegressionEvaluator()
mse = evaluator.evaluate(predictions, {evaluator.metricName: "mse"})
np.sqrt(mse), mse
"""
(3832.4796474051345, 14687900.247774584)
"""
```
```python
model = rf.fit(train1)
predictions1 = model.transform(test1)

df = predictions1.selectExpr("User_ID as User_ID", "Product_ID as Product_ID", 'prediction as Purchase')
df.toPandas().to_csv('./BlackFriday/submission.csv')
```
(This article is reposted by AI科技大本营; please contact the original author for reposting permission.)