Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames
I have data in the following format (either an RDD or a Spark DataFrame):
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)
What I would like to do is to 'reshape' the data, converting certain rows in Country (specifically US, UK and CA) into columns:
ID     Age  US  UK  CA
'X01'  41   3   1   2
'X02'  72   4   6   7
Essentially, I need something along the lines of Python's (pandas) pivot workflow:
categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index='ID',
                                                  columns='Country',
                                                  values='Score')
My dataset is rather large, so I can't really collect() and ingest the data into memory to do the reshaping in Python itself. Is there a way to convert Python's .pivot() into an invokable function while mapping either an RDD or a Spark DataFrame? Any help would be appreciated!
First up, this is probably not a good idea, because you are not getting any extra information, but you are binding yourself to a fixed schema (i.e. you must know how many countries you are expecting, and, of course, an additional country means a change in code; a possible mitigation is sketched after Solution 1's code below).

Having said that, this is a SQL problem, which is shown below. But in case you feel it is not too "software like" (seriously, I have heard this!!), then you can refer to the first solution.

Solution 1:
def reshape(t):
    # Turn one (id, age, country, score) record into
    # ((id, age), (score_CA, score_UK, score_US, score_XX)),
    # following the order of the broadcast country list
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0], out[1]), (out[2], out[3], out[4], out[5])

def cntryFilter(t):
    # Keep only records whose country is in the broadcast list
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1, t2):
    # Element-wise addition of two equal-length tuples
    j = ()
    for k, v in enumerate(t1):
        j = j + (t1[k] + t2[k],)
    return j

def seq(tIntrm, tNext):
    return addtup(tIntrm, tNext)

def comb(tP, tF):
    return addtup(tP, tF)

countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0, 0, 0, 0), seq, comb, 1)
for i in pivot.collect():
    print(i)
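As noted above, the country list here is hard-coded. A minimal sketch of one possible mitigation (my addition, not part of the original answer): derive the list from the data itself before broadcasting, at the cost of an extra pass over the RDD. For this to work fully, reshape and the zero value handed to aggregateByKey would also need to be generalized to the discovered length (e.g. return tuple(out[2:]) and build the zero tuple dynamically):

# Sketch (assumption): discover the country list instead of hard-coding it;
# this costs one extra pass over the data before broadcasting.
countries = sorted(calls.map(lambda t: t[2]).distinct().collect())
brc = sc.broadcast(countries)

# Zero value sized to however many countries were found
zeros = tuple(0 for _ in countries)
pivot = reshaped.aggregateByKey(zeros, seq, comb, 1)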
Now, Solution 2: of course better, as SQL is the right tool for this.
from pyspark.sql import Row

# ssc is the SQLContext (e.g. ssc = SQLContext(sc))
callRow = calls.map(lambda t: Row(userid=t[0], age=int(t[1]),
                                  country=t[2], nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")

res = ssc.sql("""select userid, age, max(ca), max(uk), max(us), max(xx)
                 from (select userid, age,
                              case when country='CA' then nbrCalls else 0 end ca,
                              case when country='UK' then nbrCalls else 0 end uk,
                              case when country='US' then nbrCalls else 0 end us,
                              case when country='XX' then nbrCalls else 0 end xx
                       from calls) x
                 group by userid, age""")
res.show()
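For comparison, and as my addition rather than part of the original answer: Spark 1.6 and later expose a built-in pivot on grouped DataFrames, which expresses roughly the same query without writing out the case/when branches. A minimal sketch, assuming the same callsDF as above:

from pyspark.sql import functions as F

# Assumes Spark 1.6+. Passing the value list to pivot() avoids an extra pass
# over the data to discover the distinct countries.
pivoted = (callsDF
           .groupBy("userid", "age")
           .pivot("country", ["CA", "UK", "US", "XX"])
           .agg(F.max("nbrCalls")))

# Combinations with no data come back as null; fill with 0 to match the SQL output.
pivoted.na.fill(0).show()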
Data set up:
data = [('X01', 41, 'US', 3), ('X01', 41, 'UK', 1), ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4), ('X02', 72, 'UK', 6), ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]
calls = sc.parallelize(data, 1)
countries = ['CA', 'UK', 'US', 'XX']
Result:
From the 1st solution:
(('X02', 72), (7, 6, 4, 8))
(('X01', 41), (2, 1, 3, 0))
From the 2nd solution:
root
 |-- age: long (nullable = true)
 |-- country: string (nullable = true)
 |-- nbrCalls: long (nullable = true)
 |-- userid: string (nullable = true)

userid  age  ca  uk  us  xx
X02     72   7   6   4   8
X01     41   2   1   3   0
Kindly let me know if this works, or not :)

Best,
Ayan