
Spark registers UDFs through spark.sessionState.udfRegistration; its register method looks like this:
def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
  def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
  functionRegistry.createOrReplaceTempFunction(name, builder)
  udf
}

The builder closure calls UserDefinedFunction's apply method:

def apply(exprs: Column*): Column = {
  // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
  // and `nullableTypes` is always set.
  if (nullableTypes.isEmpty) {
    nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
  }
  if (inputTypes.isDefined) {
    assert(inputTypes.get.length == nullableTypes.get.length)
  }
  Column(ScalaUDF(
    f,
    dataType,
    exprs.map(_.expr),
    nullableTypes.get,
    inputTypes.getOrElse(Nil),
    udfName = _nameOption,
    nullable = _nullable,
    udfDeterministic = _deterministic))
}
/**
 * User-defined function.
 * @param function  The user defined scala function to run.
 *                  Note that if you use primitive parameters, you are not able to check if it is
 *                  null or not, and the UDF will return null for you if the primitive input is
 *                  null. Use boxed type or [[Option]] if you wanna do the null-handling yourself.
 * @param dataType  Return type of function.
 * @param children  The input expressions of this UDF.
 */
case class ScalaUDF(
    function: AnyRef,
    dataType: DataType,
    children: Seq[Expression],
    inputsNullSafe: Seq[Boolean],
    inputTypes: Seq[DataType] = Nil,
    udfName: Option[String] = None,
    nullable: Boolean = true,
    udfDeterministic: Boolean = true)
  extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
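Before turning to scripts, it helps to see the ordinary path through this machinery. The sketch below registers a plain Scala closure via the public spark.udf API, which goes through the same builder mechanism; the function name "plusOne" and its body are just for illustration.

// Ordinary (non-script) UDF registration: spark.udf.register wraps the closure
// in a UserDefinedFunction, and each SQL call site builds a ScalaUDF expression.
val plusOne = spark.udf.register("plusOne", (x: Int) => x + 1)

// Usable from SQL, or from the DataFrame API via the returned handle.
spark.sql("select plusOne(1)").show()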
The user's input is a script, so how do we compile Scala source code at runtime? A quick search turns up mkToolBox for dynamically compiling Scala; you can also browse https://docs.scala-lang.org/overviews/reflection, where symbols-trees-types.html covers the relevant machinery.
import scala.reflect.runtime.universe
import scala.tools.reflect.ToolBox

private object RuntimeCompilationTest {
  val tb = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
  val classDef = tb.parse {
    """
      |private class MyParser extends Function[String,String]{
      |  override def apply(v1: String): String = v1 + "123"
      |}
      |
      |scala.reflect.classTag[MyParser].runtimeClass
    """.stripMargin
  }
  val clazz = tb.compile(classDef).apply().asInstanceOf[Class[Function[String, String]]]
  val instance = clazz.getConstructor().newInstance()
  println(instance.apply("asdf"))
}
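Running this prints asdf123: the toolbox parses the source into a tree, compiles it, and the trailing classTag expression hands back the runtime Class, which can then be instantiated and called like any other JVM class.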
object ScriptCodeCompiler extends Logging {

  def newInstance(clazz: Class[_]): Any = {
    val constructor = clazz.getDeclaredConstructors.head
    constructor.setAccessible(true)
    constructor.newInstance()
  }

  def getMethod(clazz: Class[_], method: String) = {
    val candidate = clazz.getDeclaredMethods.filter(_.getName == method).filterNot(_.isBridge)
    if (candidate.isEmpty) {
      throw new Exception(s"No method $method found in class ${clazz.getCanonicalName}")
    } else if (candidate.length > 1) {
      throw new Exception(s"Multiple method $method found in class ${clazz.getCanonicalName}")
    } else {
      candidate.head
    }
  }

  def compileScala(src: String): Class[_] = {
    import scala.reflect.runtime.universe
    import scala.tools.reflect.ToolBox
    val classLoader = scala.reflect.runtime.universe.getClass.getClassLoader
    val tb = universe.runtimeMirror(classLoader).mkToolBox()
    val tree = tb.parse(src)
    val clazz = tb.compile(tree).apply().asInstanceOf[Class[_]]
    clazz
  }

  def prepareScala(src: String, className: String): String = {
    src + "\n" + s"scala.reflect.classTag[$className].runtimeClass"
  }
}
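To make the compiler's contract concrete, here is a minimal usage sketch; the Greeter class and its greet method are made up for illustration.

// Hypothetical source: a class with a no-arg greet() method.
val greeterSrc =
  """
    |class Greeter {
    |  def greet(): String = "hello"
    |}
  """.stripMargin

// prepareScala appends "scala.reflect.classTag[Greeter].runtimeClass" so the
// toolbox returns the Class object; then instantiate and invoke via reflection.
val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(greeterSrc, "Greeter"))
val instance = ScriptCodeCompiler.newInstance(clazz)
println(ScriptCodeCompiler.getMethod(clazz, "greet").invoke(instance)) // hello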
object ScalaSourceUDF {

  // The input is a bare script body such as:
  //   def apply(a:Double,b:Double)={
  //     a + b
  //   }
  // so it has to be wrapped in a class before compilation.
  def wrapClass(function: String) = {
    val className = s"WowUDF_${UUID.randomUUID().toString.replaceAll("-", "")}"
    val newfun =
      s"""
         |class ${className}{
         |
         |${function}
         |
         |}
       """.stripMargin
    (className, newfun)
  }

  def apply(src: String, className: String, methodName: Option[String]): (AnyRef, DataType) = {
    val (argumentNum, returnType) = getFunctionReturnType(src, className, methodName)
    (generateFunction(src, className, methodName, argumentNum), returnType)
  }

  // Use reflection to obtain the method's return type and parameter count;
  // the parameter count decides which FunctionN to instantiate.
  private def getFunctionReturnType(src: String, className: String, methodName: Option[String]): (Int, DataType) = {
    val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
    val method = ScriptCodeCompiler.getMethod(clazz, methodName.getOrElse("apply"))
    val dataType: (DataType, Boolean) = JavaTypeInference.inferDataType(method.getReturnType)
    (method.getParameterCount, dataType._1)
  }

  // Build the function object and invoke the real method through reflection.
  // Scala functions go up to Function22, so 22 cases are needed here, which
  // also caps the number of UDF parameters at 22.
  def generateFunction(src: String, className: String, methodName: Option[String], argumentNum: Int): AnyRef = {
    lazy val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
    lazy val instance = ScriptCodeCompiler.newInstance(clazz)
    lazy val method = ScriptCodeCompiler.getMethod(clazz, methodName.getOrElse("apply"))
    argumentNum match {
      case 0 => new Function0[Any] with Serializable {
        override def apply(): Any = {
          method.invoke(instance)
        }
      }
      case 1 => new Function1[Object, Any] with Serializable {
        override def apply(v1: Object): Any = {
          method.invoke(instance, v1)
        }
      }
      // ... cases 2 through 22 follow the same pattern
    }
  }
}
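To make the wrapping step concrete: for the sample script above, wrapClass returns a generated class name plus source along these lines (the UUID suffix is random, shown here as a placeholder):

class WowUDF_<uuid> {

  def apply(a: Double, b: Double) = {
    a + b
  }

}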
val scalaScript =
  """
    |def apply(a:Double,b:Double)={
    |  a + b
    |}
  """.stripMargin

val (className, src) = ScalaSourceUDF.wrapClass(scalaScript)
val (func, returnType) = ScalaSourceUDF(src, className, Some("apply"))
spark.sessionState.udfRegistration.register("wowPlus", UserDefinedFunction(func, returnType, None))

import spark.implicits._
Seq(1, 2, 3).toDF("a").createOrReplaceTempView("test")
spark.sql("select wowPlus(a, 1) a from test").show(false)

---------------Result---------------
+---+
|a  |
+---+
|2.0|
|3.0|
|4.0|
+---+
UDAFs are registered through the same UDFRegistration; here register wraps the UserDefinedAggregateFunction in a ScalaUDAF expression:

def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
  def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
  functionRegistry.createOrReplaceTempFunction(name, builder)
  udaf
}
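Note that the public spark.udf entry point delegates to this same method, so a pre-compiled (non-script) UDAF instance registers the same way. A minimal sketch, where SumAggregation is the aggregate class defined in the example further below and "mySum" is an illustrative name:

// spark.udf delegates to sessionState.udfRegistration; the builder wraps the
// instance in a ScalaUDAF expression at every SQL call site.
val udaf: UserDefinedAggregateFunction = new SumAggregation()
spark.udf.register("mySum", udaf)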
object ScalaSourceUDAF {

  def apply(src: String, className: String): UserDefinedAggregateFunction = {
    generateAggregateFunction(src, className)
  }

  private def generateAggregateFunction(src: String, className: String): UserDefinedAggregateFunction = {
    new UserDefinedAggregateFunction with Serializable {

      // Compiled eagerly on the driver; @transient so the compiled class and
      // instance are not serialized out to executors.
      @transient val clazzUsingInDriver = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
      @transient val instanceUsingInDriver = ScriptCodeCompiler.newInstance(clazzUsingInDriver)

      // Compiled lazily on each executor after deserialization.
      lazy val clazzUsingInExecutor = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
      lazy val instanceUsingInExecutor = ScriptCodeCompiler.newInstance(clazzUsingInExecutor)

      def invokeMethod[T: ClassTag](clazz: Class[_], instance: Any, method: String): T = {
        ScriptCodeCompiler.getMethod(clazz, method).invoke(instance).asInstanceOf[T]
      }

      // Schema-level metadata is only needed on the driver.
      val _inputSchema = invokeMethod[StructType](clazzUsingInDriver, instanceUsingInDriver, "inputSchema")
      val _dataType = invokeMethod[DataType](clazzUsingInDriver, instanceUsingInDriver, "dataType")
      val _bufferSchema = invokeMethod[StructType](clazzUsingInDriver, instanceUsingInDriver, "bufferSchema")
      val _deterministic = invokeMethod[Boolean](clazzUsingInDriver, instanceUsingInDriver, "deterministic")

      override def inputSchema: StructType = _inputSchema

      override def dataType: DataType = _dataType

      override def bufferSchema: StructType = _bufferSchema

      override def deterministic: Boolean = _deterministic

      // The aggregation callbacks run on executors, against the lazily compiled instance.
      lazy val _update = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "update")
      lazy val _merge = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "merge")
      lazy val _initialize = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "initialize")
      lazy val _evaluate = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "evaluate")

      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        _update.invoke(instanceUsingInExecutor, buffer, input)
      }

      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        _merge.invoke(instanceUsingInExecutor, buffer1, buffer2)
      }

      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        _initialize.invoke(instanceUsingInExecutor, buffer)
      }

      override def evaluate(buffer: Row): Any = {
        _evaluate.invoke(instanceUsingInExecutor, buffer)
      }
    }
  }
}
val scalaScript =
  """
    |import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    |import org.apache.spark.sql.types._
    |import org.apache.spark.sql.Row
    |class SumAggregation extends UserDefinedAggregateFunction with Serializable{
    |  def inputSchema: StructType = new StructType().add("a", LongType)
    |  def bufferSchema: StructType = new StructType().add("total", LongType)
    |  def dataType: DataType = LongType
    |  def deterministic: Boolean = true
    |  def initialize(buffer: MutableAggregationBuffer): Unit = {
    |    buffer.update(0, 0L)
    |  }
    |  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    |    val sum = buffer.getLong(0)
    |    val newitem = input.getLong(0)
    |    buffer.update(0, sum + newitem)
    |  }
    |  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    |    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    |  }
    |  def evaluate(buffer: Row): Any = {
    |    buffer.getLong(0)
    |  }
    |}
  """.stripMargin

spark.sessionState.udfRegistration.register("wowSum", ScalaSourceUDAF(scalaScript, "SumAggregation"))

import spark.implicits._
Seq(1, 2, 3).toDF("a").createOrReplaceTempView("test")
spark.sql("select wowSum(a) a from test").show(false)

---------------Result---------------
+---+
|a  |
+---+
|6  |
+---+
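One caveat: every call to compileScala spins up a fresh toolbox compilation, which is expensive, and the UDAF above recompiles the script once per executor. A possible mitigation, not part of the original code, is a small source-keyed cache; a minimal sketch:

import java.util.concurrent.ConcurrentHashMap

// Hypothetical helper: compile each distinct source string at most once per JVM.
object CompiledClassCache {
  private val cache = new ConcurrentHashMap[String, Class[_]]()

  def getOrCompile(src: String): Class[_] =
    cache.computeIfAbsent(src, new java.util.function.Function[String, Class[_]] {
      override def apply(s: String): Class[_] = ScriptCodeCompiler.compileScala(s)
    })
}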

Source code: https://github.com/latincross/mlsqlwechat (c10-udf-udaf)




