Source code for handyspark.sql.transform

import datetime
import inspect
import numpy as np
from pyspark.sql import functions as F

# Spark SQL type name -> Python/NumPy type used to represent its values.
# HandyTransform builds its reverse lookup (type __name__ -> Spark name)
# from this table.
_MAPPING = {'string': str,
            'date': datetime.date,
            'timestamp': datetime.datetime,
            # Builtin bool: ``np.bool`` was merely an alias of it, deprecated in
            # NumPy 1.20 and removed in 1.24 (importing it raised AttributeError).
            'boolean': bool,
            'binary': np.byte,
            'byte': np.int8,
            'short': np.int16,
            'integer': np.int32,
            'long': np.int64,
            'float': np.float32,
            'double': np.float64,
            'array': np.ndarray,
            'map': dict}


class HandyTransform(object):
    """Turn plain Python functions into Spark column expressions.

    Functions are applied to DataFrame columns through (pandas) UDFs. Unless
    told otherwise, the input columns are the ones named after the function's
    parameters, and a generated output column name looks like ``'f(a, b)'``.
    """

    # Reverse lookup used to translate return annotations: type __name__ ->
    # Spark SQL type name (e.g. 'float64' -> 'double'). On duplicate names the
    # later _MAPPING entry wins.
    _mapping = {v.__name__: k for k, v in _MAPPING.items()}
    # Builtin names that appear in annotations such as List[float] or Dict[str, int].
    _mapping.update({'float': 'double', 'int': 'integer', 'list': 'array', 'bool': 'boolean'})

    @staticmethod
    def _default_name(f, args=None):
        """Build a column name such as ``'f(a, b)'`` from ``f`` and its inputs.

        ``args`` (column names) defaults to ``f``'s parameter names. Only
        parameters are used -- ``f.__code__.co_varnames`` would also include
        ``f``'s local variables.
        """
        if args is None:
            args = tuple(inspect.signature(f).parameters)
        return '{}{}'.format(f.__name__, str(args).replace("'", ""))

    @staticmethod
    def _get_return(sdf, f, args):
        """Return the dtype string of the first input column, or None.

        ``args`` defaults to ``f``'s parameter names. None is returned when
        there are no input columns at all.
        """
        if args is None:
            # Parameters only: co_varnames would also list locals, which would
            # then be looked up as (non-existent) columns.
            args = tuple(inspect.signature(f).parameters)
        if not len(args):
            return None
        return sdf.select(args[0]).dtypes[0][1]

    @staticmethod
    def _annotation_text(annotation):
        """Return a textual form of a return annotation, or None if there is none."""
        if annotation is inspect.Signature.empty:
            return None
        if isinstance(annotation, str):
            return annotation
        if isinstance(annotation, type) and not getattr(annotation, '__args__', None):
            # Plain classes (float, np.int64, datetime.date, ...): str() would give
            # "<class 'float'>", so use the bare class name instead.
            return annotation.__name__
        # Parameterized generics, e.g. str(typing.List[int]) == 'typing.List[int]'.
        return str(annotation)

    @staticmethod
    def _signatureType(sig):
        """Translate the return annotation of ``sig`` into a Spark type string.

        Examples: ``List[int]`` -> ``'array<integer>'``, ``Dict[str, float]`` ->
        ``'map<string,double>'``, ``float`` -> ``'double'``. At most three
        components are considered.

        Returns:
            The Spark type string, or None when there is no return annotation.

        Raises:
            AssertionError: if a component has no Spark equivalent.
        """
        text = HandyTransform._annotation_text(sig.return_annotation)
        if text is None:
            return None
        # 'Dict[str, float]' -> ['Dict', 'str', ' float']
        parts = text.replace(']', '').replace('[', ',').split(',')[:3]
        # Bare lower-cased names, dropping module prefixes: 'typing.List' -> 'list'.
        keys = [part.strip().rsplit('.', 1)[-1].lower() for part in parts]
        for key in keys:
            assert key in HandyTransform._mapping, "invalid returnType"
        types = [HandyTransform._mapping[key] for key in keys]
        if len(types) == 1:
            return types[0]
        return '{}<{}>'.format(types[0], ','.join(types[1:]))

    @staticmethod
    def gen_pandas_udf(f, args=None, returnType=None):
        """Wrap ``f`` as a UDF over columns ``args`` and return the resulting column.

        Args:
            f: Python function; its positional parameters receive the columns.
            args: list/tuple of column names; defaults to ``f``'s parameter names.
            returnType: Spark type of the result; defaults to the type derived
                from ``f``'s return annotation.

        Returns:
            The UDF applied to ``args``, aliased as e.g. ``'f(a, b)'``.

        A vectorized ``F.pandas_udf`` is used when ``pyarrow`` can be imported
        and pyspark accepts the UDF; otherwise a regular ``F.udf`` is used.
        """
        sig = inspect.signature(f)
        if args is None:
            args = tuple(sig.parameters.keys())
        assert isinstance(args, (list, tuple)), "args must be list or tuple"
        name = HandyTransform._default_name(f, args)
        if returnType is None:
            returnType = HandyTransform._signatureType(sig)
        try:
            import pyarrow  # noqa: F401 -- availability probe for Arrow-based UDFs

            @F.pandas_udf(returnType=returnType)
            def udf(*cols):
                return f(*cols)
        except Exception:
            # Not BaseException: KeyboardInterrupt/SystemExit must propagate.
            @F.udf(returnType=returnType)
            def udf(*cols):
                return f(*cols)
        return udf(*args).alias(name)

    @staticmethod
    def gen_grouped_pandas_udf(sdf, f, args=None, returnType=None):
        """Build a GROUPED_MAP pandas UDF that appends ``f``'s result as a column.

        The returned UDF receives a pandas DataFrame, calls ``f`` row by row with
        the values of the ``args`` columns (positionally), and returns the frame
        with the result added as a column named e.g. ``'f(a, b)'``. The declared
        output schema is ``args`` plus that column.
        """
        # TODO: test it properly!
        sig = inspect.signature(f)
        if args is None:
            args = tuple(sig.parameters.keys())
        assert isinstance(args, (list, tuple)), "args must be list or tuple"
        input_cols = tuple(args)
        name = HandyTransform._default_name(f)
        if returnType is None:
            returnType = HandyTransform._signatureType(sig)
        schema = sdf.notHandy().select(*input_cols).withColumn(name, F.lit(None).cast(returnType)).schema

        @F.pandas_udf(schema, F.PandasUDFType.GROUPED_MAP)
        def pudf(pdf):
            # Read exactly the input columns; f.__code__.co_varnames would also
            # include f's locals and raise KeyError on them.
            computed = pdf.apply(lambda row: f(*(row[c] for c in input_cols)), axis=1)
            return pdf.assign(__computed=computed).rename(columns={'__computed': name})
        return pudf

    @staticmethod
    def transform(sdf, f, name=None, args=None, returnType=None):
        """Return ``sdf`` with a new column ``name`` holding ``f`` applied to ``args``.

        Args:
            sdf: the DataFrame to extend.
            f: a function, or a ``(function, returnType)`` tuple.
            name: new column name; defaults to e.g. ``'f(a, b)'``.
            args: input column names; default to ``f``'s parameter names.
            returnType: Spark type of the result; defaults to the type of the
                first input column (without inputs: ``f``'s return annotation).
        """
        # Unpack the tuple form first: the default name needs the function itself.
        if isinstance(f, tuple):
            f, returnType = f
        if name is None:
            name = HandyTransform._default_name(f)
        if returnType is None:
            returnType = HandyTransform._get_return(sdf, f, args)
        return sdf.withColumn(name, HandyTransform.gen_pandas_udf(f, args, returnType))

    @staticmethod
    def apply(sdf, f, name=None, args=None, returnType=None):
        """Like ``transform``, but select only the computed column (named ``name``)."""
        # Unpack the tuple form first: the default name needs the function itself.
        if isinstance(f, tuple):
            f, returnType = f
        if name is None:
            name = HandyTransform._default_name(f)
        if returnType is None:
            returnType = HandyTransform._get_return(sdf, f, args)
        return sdf.select(HandyTransform.gen_pandas_udf(f, args, returnType).alias(name))

    @staticmethod
    def assign(sdf, **kwargs):
        """Add one column per keyword argument (column name -> value).

        Each value may be:
          * a callable, or a ``(callable, typename)`` tuple: applied through
            ``sdf.transform`` to the columns named by its parameters. If no
            typename is given and it has no parameters, it is called once and
            its result is stored as a literal column;
          * anything else: stored as a literal column.
        """
        for col, value in kwargs.items():
            typename = None
            if isinstance(value, tuple):
                value, typename = value
            if callable(value):
                if typename is None:
                    typename = HandyTransform._get_return(sdf, value, None)
                if typename is not None:
                    sdf = sdf.transform(value, name=col, returnType=typename)
                else:
                    sdf = sdf.withColumn(col, F.lit(value()))
            else:
                sdf = sdf.withColumn(col, F.lit(value))
        return sdf