10 - How to Dynamically Register Spark UDFs and UDAFs

MLSQL之道 2021-09-23
Today I'll walk through how MLSQL dynamically registers UDFs and UDAFs (the Scala version), building up the solution step by step from the initial analysis. Following the reasoning matters as much as the result: a good problem-solving mindset keeps paying off.


Spark registers both UDFs and UDAFs through the same interface:

    spark.sessionState.udfRegistration

Let's start with UDFs. Stepping into udfRegistration, the register method looks like this:

    def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
      def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
      functionRegistry.createOrReplaceTempFunction(name, builder)
      udf
    }
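
For orientation: this is the same entry point behind the familiar public API. spark.udf simply returns sessionState.udfRegistration, and the typed register overloads build the UserDefinedFunction for you. A minimal sketch:

    // The public API goes through this same UDFRegistration:
    // SparkSession.udf is defined as sessionState.udfRegistration.
    spark.udf.register("plus", (a: Double, b: Double) => a + b)
    spark.sql("select plus(1.0, 2.0)").show()

Our dynamic version has to reach the same endpoint, but starting from a function whose source code only arrives as a string at runtime.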


UserDefinedFunction's apply method:

    def apply(exprs: Column*): Column = {
      // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
      // and `nullableTypes` is always set.
      if (nullableTypes.isEmpty) {
        nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
      }
      if (inputTypes.isDefined) {
        assert(inputTypes.get.length == nullableTypes.get.length)
      }

      Column(ScalaUDF(
        f,
        dataType,
        exprs.map(_.expr),
        nullableTypes.get,
        inputTypes.getOrElse(Nil),
        udfName = _nameOption,
        nullable = _nullable,
        udfDeterministic = _deterministic))
    }
So all we need is to construct a UserDefinedFunction. Its apply method wraps the function in a ScalaUDF; let's look at that constructor:
    /**
     * User-defined function.
     * @param function  The user defined scala function to run.
     *                  Note that if you use primitive parameters, you are not able to check if it is
     *                  null or not, and the UDF will return null for you if the primitive input is
     *                  null. Use boxed type or [[Option]] if you wanna do the null-handling yourself.
     * @param dataType  Return type of function.
     * @param children  The input expressions of this UDF.
     */
    case class ScalaUDF(
        function: AnyRef,
        dataType: DataType,
        children: Seq[Expression],
        inputsNullSafe: Seq[Boolean],
        inputTypes: Seq[DataType] = Nil,
        udfName: Option[String] = None,
        nullable: Boolean = true,
        udfDeterministic: Boolean = true)
      extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {


Three parameters matter most here: function is the Scala function that processes the data, dataType is its return type, and children are the input expressions of the UDF (for example, in udf(max(age)), max(age) is the input expression).
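
To make children concrete, here is a small sketch using the standard functions.udf helper (twice is an illustrative name):

    import org.apache.spark.sql.functions.{col, max, udf}

    // Applying `twice` to max(col("age")) means the resulting ScalaUDF has
    // children == Seq(max(col("age")).expr).
    val twice = udf((x: Long) => x * 2)
    val c = twice(max(col("age")))
    println(c.expr)  // prints the ScalaUDF expression tree, with max(age) as its child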


But the user's input is a script string, so how do we compile Scala source at runtime? A quick search turns up mkToolBox for dynamic Scala compilation; alternatively, browse https://docs.scala-lang.org/overviews/reflection and look at symbols-trees-types.html.

An example found online (with the `private` modifier dropped from MyParser, since a private class cannot be reflectively instantiated from outside the toolbox's wrapper package):

    import scala.reflect.runtime.universe
    import scala.tools.reflect.ToolBox

    private object RuntimeCompilationTest {
      val tb = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()

      // Parse the class definition plus a trailing expression that evaluates
      // to the compiled Class object.
      val classDef = tb.parse {
        """
          |class MyParser extends Function[String, String] {
          |  override def apply(v1: String): String = v1 + "123"
          |}
          |
          |scala.reflect.classTag[MyParser].runtimeClass
        """.stripMargin
      }

      val clazz = tb.compile(classDef).apply().asInstanceOf[Class[Function[String, String]]]

      val instance = clazz.getConstructor().newInstance()
      println(instance.apply("asdf"))  // prints: asdf123
    }
Based on this example, let's build a utility object:
    object ScriptCodeCompiler extends Logging {

      def newInstance(clazz: Class[_]): Any = {
        val constructor = clazz.getDeclaredConstructors.head
        constructor.setAccessible(true)
        constructor.newInstance()
      }

      def getMethod(clazz: Class[_], method: String) = {
        val candidate = clazz.getDeclaredMethods.filter(_.getName == method).filterNot(_.isBridge)
        if (candidate.isEmpty) {
          throw new Exception(s"No method $method found in class ${clazz.getCanonicalName}")
        } else if (candidate.length > 1) {
          throw new Exception(s"Multiple method $method found in class ${clazz.getCanonicalName}")
        } else {
          candidate.head
        }
      }

      def compileScala(src: String): Class[_] = {
        import scala.reflect.runtime.universe
        import scala.tools.reflect.ToolBox
        val classLoader = scala.reflect.runtime.universe.getClass.getClassLoader
        val tb = universe.runtimeMirror(classLoader).mkToolBox()
        val tree = tb.parse(src)
        val clazz = tb.compile(tree).apply().asInstanceOf[Class[_]]
        clazz
      }

      // Append a trailing expression so that compiling the source evaluates
      // to the Class object of the target class.
      def prepareScala(src: String, className: String): String = {
        src + "\n" + s"scala.reflect.classTag[$className].runtimeClass"
      }
    }
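
Before wiring this into Spark, a quick sanity check of the utility (a sketch; Echo is just an illustrative script):

    val src =
      """
        |class Echo {
        |  def apply(s: String): String = "echo: " + s
        |}
      """.stripMargin

    val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, "Echo"))
    val instance = ScriptCodeCompiler.newInstance(clazz)
    val method = ScriptCodeCompiler.getMethod(clazz, "apply")
    println(method.invoke(instance, "hi"))  // prints: echo: hi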
Next, let's construct the two ScalaUDF parameters we must supply ourselves, the function and its return type (children are filled in by the analyzer at call time):
    import java.util.UUID

    object ScalaSourceUDF {

      // The input is a bare script, e.g.
      //   def apply(a: Double, b: Double) = {
      //     a + b
      //   }
      // so we first wrap it in a class with a randomly generated name.
      def wrapClass(function: String) = {
        val className = s"WowUDF_${UUID.randomUUID().toString.replaceAll("-", "")}"
        val newfun =
          s"""
             |class ${className} {
             |
             |${function}
             |
             |}
           """.stripMargin
        (className, newfun)
      }

      def apply(src: String, className: String, methodName: Option[String]): (AnyRef, DataType) = {
        val (argumentNum, returnType) = getFunctionReturnType(src, className, methodName)
        (generateFunction(src, className, methodName, argumentNum), returnType)
      }

      // Use reflection to obtain the method's return type and parameter count;
      // the parameter count decides which FunctionN to instantiate below.
      // Note: JavaTypeInference is Spark-internal, so this code may need to live
      // under an org.apache.spark.sql package to compile against it.
      private def getFunctionReturnType(src: String, className: String, methodName: Option[String]): (Int, DataType) = {
        val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
        val method = ScriptCodeCompiler.getMethod(clazz, methodName.getOrElse("apply"))
        val dataType: (DataType, Boolean) = JavaTypeInference.inferDataType(method.getReturnType)
        (method.getParameterCount, dataType._1)
      }

      // Build a FunctionN that invokes the real method via reflection. Scala
      // functions only go up to Function22, so 22 cases are needed, which also
      // caps a UDF at 22 arguments. Everything is lazy so that no Class or
      // Method object (neither is serializable) is captured eagerly: executors
      // recompile the script on first use after deserialization.
      def generateFunction(src: String, className: String, methodName: Option[String], argumentNum: Int): AnyRef = {
        lazy val clazz = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
        lazy val instance = ScriptCodeCompiler.newInstance(clazz)
        lazy val method = ScriptCodeCompiler.getMethod(clazz, methodName.getOrElse("apply"))
        argumentNum match {
          case 0 => new Function0[Any] with Serializable {
            override def apply(): Any = method.invoke(instance)
          }
          case 1 => new Function1[Object, Any] with Serializable {
            override def apply(v1: Object): Any = method.invoke(instance, v1)
          }
          // ... cases 2 through 22 follow the same pattern, up to Function22.
        }
      }
    }
Now let's write a test program:

    val scalaScript =
      """
        |def apply(a: Double, b: Double) = {
        |  a + b
        |}
      """.stripMargin

    val (className, src) = ScalaSourceUDF.wrapClass(scalaScript)
    val (func, returnType) = ScalaSourceUDF(src, className, Some("apply"))

    spark.sessionState.udfRegistration.register("wowPlus", UserDefinedFunction(func, returnType, None))

    import spark.implicits._

    Seq(1, 2, 3).toDF("a").createOrReplaceTempView("test")

    spark.sql("select wowPlus(a, 1) a from test").show(false)

--------------- Result ---------------

    +---+
    |a  |
    +---+
    |2.0|
    |3.0|
    |4.0|
    +---+
Part two: implementing a UDAF. Registration goes through a sibling register overload:

    def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
      def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
      functionRegistry.createOrReplaceTempFunction(name, builder)
      udaf
    }
UserDefinedAggregateFunction is an abstract class, which makes this case simpler: we implement the class once, look up the script's methods via reflection, and delegate each override to them. Note the driver/executor split in the code below: the schema-related members are computed eagerly on the driver and marked @transient, while the execution-time Method handles are lazy vals, so each executor recompiles the script after deserialization rather than trying to serialize Class or Method objects:
    import scala.reflect.ClassTag

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, StructType}

    object ScalaSourceUDAF {
      def apply(src: String, className: String): UserDefinedAggregateFunction = {
        generateAggregateFunction(src, className)
      }

      private def generateAggregateFunction(src: String, className: String): UserDefinedAggregateFunction = {
        new UserDefinedAggregateFunction with Serializable {

          // Compiled eagerly on the driver; @transient so they are never serialized.
          @transient val clazzUsingInDriver = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
          @transient val instanceUsingInDriver = ScriptCodeCompiler.newInstance(clazzUsingInDriver)

          // Recompiled lazily on each executor after deserialization.
          lazy val clazzUsingInExecutor = ScriptCodeCompiler.compileScala(ScriptCodeCompiler.prepareScala(src, className))
          lazy val instanceUsingInExecutor = ScriptCodeCompiler.newInstance(clazzUsingInExecutor)

          def invokeMethod[T: ClassTag](clazz: Class[_], instance: Any, method: String): T = {
            ScriptCodeCompiler.getMethod(clazz, method).invoke(instance).asInstanceOf[T]
          }

          val _inputSchema = invokeMethod[StructType](clazzUsingInDriver, instanceUsingInDriver, "inputSchema")
          val _dataType = invokeMethod[DataType](clazzUsingInDriver, instanceUsingInDriver, "dataType")
          val _bufferSchema = invokeMethod[StructType](clazzUsingInDriver, instanceUsingInDriver, "bufferSchema")
          val _deterministic = invokeMethod[Boolean](clazzUsingInDriver, instanceUsingInDriver, "deterministic")

          override def inputSchema: StructType = _inputSchema

          override def dataType: DataType = _dataType

          override def bufferSchema: StructType = _bufferSchema

          override def deterministic: Boolean = _deterministic

          lazy val _update = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "update")
          lazy val _merge = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "merge")
          lazy val _initialize = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "initialize")
          lazy val _evaluate = ScriptCodeCompiler.getMethod(clazzUsingInExecutor, "evaluate")

          override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
            _update.invoke(instanceUsingInExecutor, buffer, input)
          }

          override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
            _merge.invoke(instanceUsingInExecutor, buffer1, buffer2)
          }

          override def initialize(buffer: MutableAggregationBuffer): Unit = {
            _initialize.invoke(instanceUsingInExecutor, buffer)
          }

          override def evaluate(buffer: Row): Any = {
            _evaluate.invoke(instanceUsingInExecutor, buffer)
          }
        }
      }
    }


And the test program:
    val scalaScript =
      """
        |import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
        |import org.apache.spark.sql.types._
        |import org.apache.spark.sql.Row
        |class SumAggregation extends UserDefinedAggregateFunction with Serializable {
        |  def inputSchema: StructType = new StructType().add("a", LongType)
        |  def bufferSchema: StructType = new StructType().add("total", LongType)
        |  def dataType: DataType = LongType
        |  def deterministic: Boolean = true
        |  def initialize(buffer: MutableAggregationBuffer): Unit = {
        |    buffer.update(0, 0L)
        |  }
        |  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        |    val sum = buffer.getLong(0)
        |    val newitem = input.getLong(0)
        |    buffer.update(0, sum + newitem)
        |  }
        |  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        |    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
        |  }
        |  def evaluate(buffer: Row): Any = {
        |    buffer.getLong(0)
        |  }
        |}
      """.stripMargin

    spark.sessionState.udfRegistration.register("wowSum", ScalaSourceUDAF(scalaScript, "SumAggregation"))

    import spark.implicits._

    Seq(1, 2, 3).toDF("a").createOrReplaceTempView("test")

    spark.sql("select wowSum(a) a from test").show(false)

--------------- Result ---------------

    +---+
    |a  |
    +---+
    |6  |
    +---+
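
The registered UDAF is equally usable from the DataFrame API via the standard callUDF helper:

    import org.apache.spark.sql.functions.{callUDF, col}

    // Same aggregation, expressed on the DataFrame rather than in SQL.
    spark.table("test").agg(callUDF("wowSum", col("a")).as("a")).show(false)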


This implementation follows MLSQL's streaming.udf.ScalaRuntimeCompileUDAF and ScalaRuntimeCompileUDF, simplified and lightly modified. The article takes a developer's viewpoint on analyzing and solving the problem; I hope the approach is as useful to you as the result. For the Java and Python versions, see the streaming.udf package in MLSQL: https://github.com/allwefantasy/mlsql. The Python implementation runs on Jython, so Python package support is limited to what Jython can load.


Previous posts:
1 - An Introduction to MLSQL
2 - A Deep Dive into How MLSQL Loads JDBC Data Sources
3 - MLSQL DSL: Are You Ready to Build Your Own DSL?
4 - How to Implement Hive Column-Level Access Control
5 - How to Implement JDBC Column-Level Access Control
6 - How to Run Python Scripts Distributed over Spark
7 - How to Read the MySQL binlog
8 - The Right Way to Parse Canal binlog into a Spark DataFrame
9 - Dynamic Extension with the Register-and-Discover Pattern

If you like the series, tap [ MLSQL之道 ] at the top to follow!

Source code: https://github.com/latincross/mlsqlwechat (c10-udf-udaf)


