import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from handyspark.util import get_buckets, none2zero
from operator import add, itemgetter
from pyspark.ml.feature import Bucketizer
from pyspark.ml.pipeline import Pipeline
from pyspark.mllib.stat import Statistics
from pyspark.sql import functions as F
from matplotlib.artist import setp
import matplotlib as mpl
# Module-level side effect: importing this module changes Matplotlib's global
# "lines" rc group, setting the default marker edge width to 0.5.
mpl.rc("lines", markeredgewidth=0.5)
def title_fom_clause(clause):
    """Turn a filter clause into a compact, multi-line plot title.

    Every ``' and '`` becomes a line break, every ``' == '`` becomes ``'='``
    and all double quotes are dropped, so ``'a == "x" and b == 2'`` becomes
    ``'a=x\\nb=2'``.

    Parameters
    ----------
    clause : str
        Filter expression to be shortened.

    Returns
    -------
    str
        The shortened text.
    """
    return clause.replace(' and ', '\n').replace(' == ', '=').replace('"', '')
def _grid_position(ax):
    """Return ``(row, col, n_rows)`` of a subplot within its grid."""
    # The Axes attributes rowNum / colNum / numRows are not available in
    # current Matplotlib; the subplot spec exposes the same information.
    n_rows, n_cols, start, _ = ax.get_subplotspec().get_geometry()
    row, col = divmod(start, n_cols)
    return row, col, n_rows


def consolidate_plots(fig, axs, title, clauses):
    """Give a group of subplots a common title, common limits and outer ticks.

    With a single axes only its title is set. With several axes, each one is
    titled after its clause, all of them are set to the widest x/y limits
    found among them, the y axis is hidden outside the first column, the x
    axis is hidden outside the last row, and ``title`` becomes the figure's
    super title.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure holding the axes.
    axs : sequence of matplotlib Axes
        Subplots to consolidate.
    title : str or list of str
        Overall title; a list is joined with ``', '`` for the super title.
    clauses : sequence of str
        One clause per axes, used to build each subplot's title.

    Returns
    -------
    tuple
        The same ``(fig, axs)`` that were passed in.

    Raises
    ------
    AssertionError
        If there are several axes and their number differs from the number
        of clauses.
    """
    axs[0].set_title(title)
    fig.tight_layout()
    if len(axs) > 1:
        assert len(axs) == len(clauses), 'Mismatched number of plots and clauses!'
        xlims = [ax.get_xlim() for ax in axs]
        ylims = [ax.get_ylim() for ax in axs]
        xlim = [np.min([lo for lo, _ in xlims]), np.max([hi for _, hi in xlims])]
        ylim = [np.min([lo for lo, _ in ylims]), np.max([hi for _, hi in ylims])]
        for ax, clause in zip(axs, clauses):
            ax.set_title(title_fom_clause(clause), fontdict={'fontsize': 10})
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            row, col, n_rows = _grid_position(ax)
            # Keep tick labels only on the left-most column and bottom row.
            if col > 0:
                ax.get_yaxis().set_visible(False)
            if row < (n_rows - 1):
                ax.get_xaxis().set_visible(False)
        if isinstance(title, list):
            title = ', '.join(title)
        fig.suptitle(title)
        fig.tight_layout()
        # Leave room at the top for the super title.
        fig.subplots_adjust(top=0.85)
    return fig, axs
### Correlations
def correlations(sdf, colnames, method='pearson', ax=None, plot=True):
    """Compute the correlation matrix of the given columns.

    Rows containing nulls in the selected columns are dropped before the
    matrix is computed with ``Statistics.corr``.

    Parameters
    ----------
    sdf : object exposing ``notHandy()``
        Presumably a HandyFrame; ``notHandy()`` is called on it first.
    colnames : list of str
        Columns to correlate; also used as row and column labels.
    method : str, default 'pearson'
        Passed through to ``Statistics.corr``.
    ax : matplotlib Axes, optional
        Target axes for the heatmap; a new one is created when omitted.
    plot : bool, default True
        Draw a heatmap when True, otherwise return the matrix.

    Returns
    -------
    The value returned by ``sns.heatmap`` when ``plot`` is True, otherwise
    the correlation matrix as a pandas DataFrame.
    """
    plain_sdf = sdf.notHandy()
    complete_rows = plain_sdf.select(colnames).dropna().rdd.map(lambda row: row[0:])
    corr_matrix = Statistics.corr(complete_rows, method=method)
    pdf = pd.DataFrame(corr_matrix, columns=colnames, index=colnames)
    if not plot:
        return pdf
    if ax is None:
        _, ax = plt.subplots(1, 1)
    return sns.heatmap(pdf.round(2), annot=True, cmap="coolwarm", fmt='.2f',
                       linewidths=.05, ax=ax)
### Scatterplot
def strat_scatterplot(sdf, col1, col2, n=30):
    """Fit the bucketizing pipeline used by ``scatterplot``.

    One ``Bucketizer`` is built per column, with splits obtained from
    ``get_buckets`` for ``n`` buckets; rows with invalid values are skipped.

    Parameters
    ----------
    sdf : Spark DataFrame
        Data holding both columns.
    col1, col2 : str
        Columns to bucketize.
    n : int, default 30
        Bucket count handed to ``get_buckets``.

    Returns
    -------
    tuple
        ``(fitted pipeline model, sdf.count())``.
    """
    def make_bucketizer(col):
        splits = get_buckets(sdf.select(col).rdd.map(itemgetter(0)), n)
        return Bucketizer(splits=splits,
                          inputCol=col,
                          outputCol="__{}_bucket".format(col),
                          handleInvalid="skip")

    stages = [make_bucketizer(col) for col in (col1, col2)]
    model = Pipeline(stages=stages).fit(sdf)
    return model, sdf.count()
def scatterplot(sdf, col1, col2, n=30, ax=None):
    """Draw a bucketized scatterplot of two columns.

    Both columns are bucketized, rows are counted per bucket pair, and one
    marker is drawn at the midpoint of each non-empty pair, sized by the
    share of rows that fall into it.

    Parameters
    ----------
    sdf : object exposing ``_get_strata()`` and ``notHandy()``
        Presumably a HandyFrame.
    col1, col2 : str
        Columns for the x and y axes.
    n : int, default 30
        Buckets per column, used only when no strata data is provided.
    ax : matplotlib Axes, optional
        Target axes. Replaced by the strata axes when ``_get_strata()``
        provides data; created when still None.

    Returns
    -------
    The value returned by ``sns.scatterplot``.
    """
    strat_ax, data = sdf._get_strata()
    sdf = sdf.notHandy()
    if data is None:
        data = strat_scatterplot(sdf, col1, col2, n)
    else:
        # Reuse the model and axes prepared for the stratified plot.
        ax = strat_ax
    model, total = data
    if ax is None:
        _, ax = plt.subplots(1, 1)

    bucket_cols = ["__{}_bucket".format(col) for col in (col1, col2)]
    counts = (model
              .transform(sdf.select(col1, col2).dropna())
              .select(*bucket_cols)
              .rdd
              .map(lambda row: (row[0:], 1))
              .reduceByKey(add)
              .collect())

    # Place each bucket at the midpoint between its two edges.
    midpoints = [list(map(np.mean, zip(split[1:], split[:-1])))
                 for split in (bucket.getSplits() for bucket in model.stages)]
    df_counts = pd.DataFrame([(midpoints[0][int(key[0])],
                               midpoints[1][int(key[1])],
                               count) for key, count in counts],
                             columns=[col1, col2, 'Proportion'])
    # Replace the column outright: writing floats into the integer column
    # through .loc would be an in-place upcast, which pandas deprecates.
    df_counts['Proportion'] = df_counts.Proportion.apply(lambda p: round(p / total, 4))
    return sns.scatterplot(data=df_counts,
                           x=col1,
                           y=col2,
                           size='Proportion',
                           ax=ax,
                           legend=False)
### Histogram
def strat_histogram(sdf, colname, bins=10, categorical=False):
    """Compute the data behind a histogram of one column.

    Parameters
    ----------
    sdf : Spark DataFrame
        Data holding the column.
    colname : str
        Column to summarize.
    bins : default 10
        Passed to ``RDD.histogram``; ignored for categorical columns.
    categorical : bool, default False
        Count distinct values instead of binning.

    Returns
    -------
    tuple
        ``(start_values, counts)``. For categorical columns these are the
        distinct values and their frequencies, most frequent first;
        otherwise they are whatever ``RDD.histogram(bins)`` returns.
    """
    column_rdd = sdf.select(colname).rdd
    if not categorical:
        start_values, counts = column_rdd.map(itemgetter(0)).histogram(bins)
        return start_values, counts

    frequencies = (column_rdd
                   .map(lambda row: (row[0], 1))
                   .reduceByKey(add)
                   .sortBy(itemgetter(1), ascending=False)
                   .collect())
    counts = [count for _, count in frequencies]
    start_values = [value for value, _ in frequencies]
    return start_values, counts
def histogram(sdf, colname, bins=10, categorical=False, ax=None):
    """Plot a histogram (numeric) or bar chart (categorical) of one column.

    Parameters
    ----------
    sdf : object exposing ``_get_strata()`` and ``notHandy()``
        Presumably a HandyFrame.
    colname : str
        Column to plot; also used as the plot title.
    bins : default 10
        Passed to ``strat_histogram``; for categorical columns it also caps
        the number of bars, taken after sorting by category label.
    categorical : bool, default False
        Treat the column as categorical.
    ax : matplotlib Axes, optional
        Target axes. Replaced by the strata axes when ``_get_strata()``
        provides data; created when still None.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    strat_ax, data = sdf._get_strata()
    sdf = sdf.notHandy()
    stratified = data is not None
    if stratified:
        ax = strat_ax
    else:
        data = strat_histogram(sdf, colname, bins, categorical)
    start_values, counts = data
    if ax is None:
        _, ax = plt.subplots(1, 1)

    if categorical:
        if stratified:
            # The shared categories come from the strata data; the
            # frequencies must be counted again on this frame.
            frequencies = dict(sdf.select(colname)
                               .rdd
                               .map(lambda row: (row[0], 1))
                               .reduceByKey(add)
                               .collect())
        else:
            # Counts were just computed on this same frame: reuse them
            # instead of launching an identical Spark job.
            frequencies = dict(zip(start_values, counts))
        pdf = (pd.Series([frequencies.get(key, 0) for key in start_values],
                         index=list(start_values),
                         name=colname)
               .sort_index()
               .to_frame()
               .iloc[:bins])
        pdf.plot(kind='bar', color='C0', legend=False, rot=0, ax=ax, title=colname)
    else:
        if stratified:
            # Recount this frame's values using the shared bin edges.
            _, counts = sdf.select(colname).rdd.map(itemgetter(0)).histogram(start_values)
        # One weighted sample per bin, placed on the bin's left edge.
        left_edges = start_values[:-1]
        ax.hist(left_edges, bins=start_values, weights=counts)
        ax.set_title(colname)
    return ax
### Stratified Histogram
def stratified_histogram(sdf, colname, strat_colname, strat_values, ax=None):
    """Overlay normalized histograms of one column, one per stratum value.

    All strata share the same bucket edges, obtained from ``get_buckets``
    for 20 buckets over the whole column.

    Parameters
    ----------
    sdf : Spark DataFrame
        Data holding both columns.
    colname : str
        Column whose distribution is plotted.
    strat_colname : str
        Column used to split the data.
    strat_values : iterable
        Values of ``strat_colname`` to plot. They are inserted unquoted into
        the filter expression and used as ``value - 1`` to pick a colour, so
        they are expected to be numeric.
    ax : matplotlib Axes, optional
        Target axes; a new one is created when omitted.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on, with a legend added.
    """
    if ax is None:
        # Without an axes the legend call below had nothing to act on.
        _, ax = plt.subplots(1, 1)
    buckets = get_buckets(sdf.select(colname).rdd.map(itemgetter(0)), 20)
    for value in strat_values:
        start_values, counts = (sdf
                                .select(colname)
                                .filter('{} == {}'.format(strat_colname, value))
                                .rdd
                                .map(itemgetter(0))
                                .histogram(buckets))
        # NOTE(review): sns.distplot is deprecated in recent seaborn
        # releases; kept here to preserve the current rendering.
        sns.distplot(start_values[:len(counts)],
                     bins=start_values,
                     color='C{}'.format(value - 1),
                     norm_hist=True,
                     kde=False,
                     hist_kws={"weights": counts},
                     label='{}'.format(value),
                     ax=ax)
    # Axes has no set_legend(); legend() builds it from the labels above.
    ax.legend()
    return ax
### Boxplot
def _gen_dict(rc_name, properties):
    """ Builds a dictionary of boxplot properties read from rcParams.

    Each name in ``properties`` is mapped to the value stored under
    ``boxplot.<rc_name>.<property>`` in ``plt.rcParams``."""
    return {prop: plt.rcParams['boxplot.{0}.{1}'.format(rc_name, prop)]
            for prop in properties}
def draw_boxplot(ax, stats):
    """Draw boxplots from precomputed statistics.

    Artist properties are read from the ``boxplot.*`` rcParams, the boxes
    are drawn with ``ax.bxp``, the grid is switched on, and the boxes,
    whiskers and medians are then recoloured.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    stats : list of dict
        Box statistics, in the format accepted by ``Axes.bxp``.

    Returns
    -------
    matplotlib Axes
        The same ``ax``.
    """
    line_props = ['color', 'linewidth', 'linestyle']
    marker_props = ['color', 'marker', 'markerfacecolor', 'markeredgecolor',
                    'markersize', 'linestyle', 'linewidth']
    line_groups = ('boxprops', 'whiskerprops', 'capprops', 'medianprops', 'meanprops')

    props = {group: _gen_dict(group, line_props) for group in line_groups}
    props['flierprops'] = _gen_dict('flierprops', marker_props)

    box_color = '#1f77b4'
    median_color = '#2ca02c'

    artists = ax.bxp(stats, **props)
    ax.grid(True)
    for part in ('boxes', 'whiskers'):
        setp(artists[part], color=box_color, alpha=1)
    setp(artists['medians'], color=median_color, alpha=1)
    return ax
def _calc_tukey(col_summ, k=1.5):
    """Return the ``(lower, upper)`` fences: Q1 - k*IQR and Q3 + k*IQR.

    The quartiles are read from the ``'25%'`` and ``'75%'`` entries of
    ``col_summ``, passed through ``none2zero`` and converted to float.
    """
    lower_quartile = float(none2zero(col_summ['25%']))
    upper_quartile = float(none2zero(col_summ['75%']))
    reach = k * (upper_quartile - lower_quartile)
    return lower_quartile - reach, upper_quartile + reach
def boxplot(sdf, colnames, ax=None, showfliers=True, k=1.5):
    """Compute boxplot statistics for the given columns and optionally draw them.

    Quartiles, median and mean come from ``DataFrame.summary()``. The
    whiskers are the smallest and largest values inside the fences returned
    by ``_calc_tukey``; values outside the fences are the fliers.

    Parameters
    ----------
    sdf : object exposing ``_get_strata()`` and ``notHandy()``
        Presumably a HandyFrame.
    colnames : list of str
        Columns to summarize, one box each.
    ax : matplotlib Axes, optional
        Target axes. Created when omitted, unless ``_get_strata()`` provides
        data, in which case nothing is drawn here.
    showfliers : bool, default True
        Collect up to 1000 fliers per column, largest absolute value first.
    k : float, default 1.5
        Fence multiplier passed to ``_calc_tukey``.

    Returns
    -------
    The axes returned by ``draw_boxplot`` when an axes is available,
    otherwise the list of per-column statistics dictionaries.
    """
    _, data = sdf._get_strata()
    sdf = sdf.notHandy()
    if data is None and ax is None:
        _, ax = plt.subplots(1, 1)

    pdf = sdf.select(colnames).summary().toPandas().set_index('summary')

    # faster than stats()
    def minmax(a, b):
        return min(a[0], b[0]), max(a[1], b[1])

    stats = []
    for colname in colnames:
        col_summ = pdf[colname]
        # Compute the fences directly rather than storing tuples in the frame.
        lfence, ufence = _calc_tukey(col_summ, k)

        flag = '__{}_outlier'.format(colname)
        outlier = sdf.withColumn(flag, ~F.col(colname).between(lfence, ufence))

        fliers = []
        try:
            minv, maxv = (outlier
                          .filter('not {}'.format(flag))
                          .select(colname)
                          .rdd
                          .map(lambda x: (x[0], x[0]))
                          .reduce(minmax))
        except ValueError:
            # reduce() raises ValueError on an empty RDD, i.e. when no
            # value lies inside the fences.
            minv = 0.
            maxv = 0.
        else:
            if showfliers:
                fliers = (outlier
                          .filter(flag)
                          .select(colname)
                          .rdd
                          .map(itemgetter(0))
                          .sortBy(lambda v: -abs(v))
                          .take(1000))

        stats.append({'label': colname,
                      'mean': float(none2zero(col_summ['mean'])),
                      'med': float(none2zero(col_summ['50%'])),
                      'q1': float(none2zero(col_summ['25%'])),
                      'q3': float(none2zero(col_summ['75%'])),
                      'whislo': minv,
                      'whishi': maxv,
                      'fliers': fliers})

    if ax is not None:
        return draw_boxplot(ax, stats)
    return stats
def post_boxplot(axs, stats, clauses):
    """Draw boxplot statistics collected for several strata.

    When there is one axes per entry of ``stats``, each entry is drawn on
    its own axes. Otherwise the first item of every entry is relabelled
    with its clause (the dictionary is updated in place) and all of them
    are drawn together on the first axes.

    Parameters
    ----------
    axs : sequence of matplotlib Axes
        Available axes.
    stats : sequence of list of dict
        One list of box statistics per stratum.
    clauses : sequence of str
        Clause for each stratum, used as the box label in the shared plot.

    Returns
    -------
    list
        The axes that were drawn on.
    """
    if len(axs) == len(stats):
        return [draw_boxplot(ax, stratum_stats)
                for ax, stratum_stats in zip(axs, stats)]

    items = []
    for clause, stratum_stats in zip(clauses, stats):
        first_box = stratum_stats[0]
        first_box.update({'label': title_fom_clause(clause)})
        items.append(first_box)
    return [draw_boxplot(axs[0], items)]