StarRocks 函数就像预设于数据库中的公式,允许用户调用现有的函数以完成特定功能。函数可以很方便地实现业务逻辑的重用,因此正确使用函数会让读者在编写 SQL 语句时起到事半功倍的效果。
StarRocks 提供了多种内置函数,包括标量函数、聚合函数、窗口函数、Table 函数和 Lambda 函数等,可帮助用户更加便捷地处理表中的数据。此外,StarRocks 还允许用户自定义函数以适应实际的业务操作。本文将以标量函数和聚合函数为例,介绍 StarRocks 常见的两种函数实现原理,希望读者能够借鉴其设计思路,并按需实现所需的函数。同时,我们也欢迎社区小伙伴一起贡献力量,共同完善 StarRocks 的功能,具体的函数任务认领方式请见文末。
01
如何为 StarRocks 添加标量函数
1-1
标量函数介绍
1-2
标量函数的实现原理
1-2
标量函数的实现原理
标量函数的函数签名定义在
gensrc/script/functions.py,在编译阶段我们会根据 Python 文件中的内容生成对应的 Java 和 C++ 代码,供 FE 和 BE 使用。
[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>]or[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>, <be_prepare_function>, <be_close_function>]
function_id:函数唯一标识,是唯一一串数字,function_id 遵循如下约定,前两位表示 function_type,中间两位表示 function_group,余下的表示具体的 sub_function,后面我们会举例说明 function_name:函数名称 return_type:返回值类型 arg_type:入参类型,如果有多个入参,需要在数组中描述每个入参的类型 be_scalar_function:BE 中负责实现该函数计算逻辑的函数 be_prepare_function/be_close_function:可选参数,有些函数在执行的过程中可能会传递一些状态,be_prepare_function 和 be_close_function 就是 BE 中负责实现创建状态和回收状态的函数
function_id: 10
代表它们都属于 math function,04
代表它们都属于 abs 这个 function group,余下的数字用来区分具体的 sub-functionfunction_name:函数名称都是 abs return_type:返回值类型,同入参类型一致 arg_type:该函数只接受一个入参,所以第四项的数组中只有一个元素。 be_eval_function:BE 中实现计算逻辑的函数,StarRocks 针对每种数据类型做了特殊处理,所以每个签名中的函数名也不一样
[10040, "abs", "DOUBLE", ["DOUBLE"], "MathFunctions::abs_double"],[10041, "abs", "FLOAT", ["FLOAT"], "MathFunctions::abs_float"],[10042, "abs", "LARGEINT", ["LARGEINT"], "MathFunctions::abs_largeint"],[10043, "abs", "LARGEINT", ["BIGINT"], "MathFunctions::abs_bigint"],[10044, "abs", "BIGINT", ["INT"], "MathFunctions::abs_int"],[10045, "abs", "INT", ["SMALLINT"], "MathFunctions::abs_smallint"],[10046, "abs", "SMALLINT", ["TINYINT"], "MathFunctions::abs_tinyint"],[10047, "abs", "DECIMALV2", ["DECIMALV2"], "MathFunctions::abs_decimalv2val"],[100470, "abs", "DECIMAL32", ["DECIMAL32"], "MathFunctions::abs_decimal32"],[100471, "abs", "DECIMAL64", ["DECIMAL64"], "MathFunctions::abs_decimal64"],[100472, "abs", "DECIMAL128", ["DECIMAL128"], "MathFunctions::abs_decimal128"],
在编译阶段,根据 gensrc/script/functions.py
中的内容生成代码供 FE 和 BE 使用。Java 代码在 fe/fe-core/target/generated-sources/build/com/starrocks/builtins/VectorizedBuiltinFunctions.java
,FunctionSet[1]保存了所有的函数签名,初始化阶段会调用VectorizedBuiltinFunctions::initBuiltins
来添加标量函数的函数签名。SQL analyze 阶段,会利用 FunctionSet 提供的信息进行校验,如果找不到函数签名会直接返回错误,这部分实现在 ExpressionAnalyzer.Visitor [2]的 visitFunctionCall[3]方法中。
C++ 代码在 ./gensrc/build/gen_C++/opcode/builtin_functions.cpp
,BE 标量函数的函数签名保存在 BuiltinFunctions::_fn_tables[4] ,生成的代码用于初始化_fn_tables。在 SQL 执行阶段,VectorizedFunctionCallExpr 会根据 fid(函数唯一标识)从 _fn_tables 中找到执行该函数所需要的信息,包括输入参数的个数,执行函数的函数指针(ScalarFunction),以及执行前后的 PrepareFunction 和 CloseFunction,这部分定义在 FunctionDescriptor[5]在 BE 实现函数的计算逻辑
1-3
添加标量函数示例
1-3

生成函数签名
[120160, "sha2", "VARCHAR", ["VARCHAR", "INT"], "EncryptionFunctions::sha2", "EncryptionFunctions::sha2_prepare", "EncryptionFunctions::sha2_close"],
EncryptionFunctions::sha2_prepare
和 EncryptionFunctions::sha2_close,用来实现状态的创建和回收。
实现函数的计算逻辑
** Called by sha2 to the corresponding part*/DEFINE_VECTORIZED_FN(sha224);DEFINE_VECTORIZED_FN(sha256);DEFINE_VECTORIZED_FN(sha384);DEFINE_VECTORIZED_FN(sha512);DEFINE_VECTORIZED_FN(invalid_sha);*** @param: [json_string, tagged_value]* @paramType: [BinaryColumn, BinaryColumn]* @return: Int32Column*/DEFINE_VECTORIZED_FN(sha2);static Status sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);static Status sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);
PrepareFunction
Status EncryptionFunctions::sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {if (scope != FunctionContext::FRAGMENT_LOCAL) {return Status::OK();}if (!context->is_notnull_constant_column(1)) {return Status::OK();}ColumnPtr column = context->get_constant_column(1);auto hash_length = ColumnHelper::get_const_value<TYPE_INT>(column);ScalarFunction function;if (hash_length == 224) {function = &EncryptionFunctions::sha224;} else if (hash_length == 256 || hash_length == 0) {function = &EncryptionFunctions::sha256;} else if (hash_length == 384) {function = &EncryptionFunctions::sha384;} else if (hash_length == 512) {function = &EncryptionFunctions::sha512;} else {function = EncryptionFunctions::invalid_sha;}auto fc = new EncryptionFunctions::SHA2Ctx();fc->function = function;context->set_function_state(scope, fc);return Status::OK();}
ScalarFunction
function_state 就可以派上用场了。具体代码如下:
StatusOr<ColumnPtr> EncryptionFunctions::sha2(FunctionContext* ctx, const Columns& columns) {if (!ctx->is_notnull_constant_column(1)) {auto src_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);auto length_viewer = ColumnViewer<TYPE_INT>(columns[1]);auto size = columns[0]->size();ColumnBuilder<TYPE_VARCHAR> result(size);for (int row = 0; row < size; row++) {if (src_viewer.is_null(row) || length_viewer.is_null(row)) {result.append_null();continue;}auto src_value = src_viewer.value(row);auto length = length_viewer.value(row);if (length == 224) {SHA224Digest digest;digest.update(src_value.data, src_value.size);digest.digest();result.append(Slice(digest.hex().c_str(), digest.hex().size()));} else if (length == 0 || length == 256) {SHA256Digest digest;digest.update(src_value.data, src_value.size);digest.digest();result.append(Slice(digest.hex().c_str(), digest.hex().size()));} else if (length == 384) {SHA384Digest digest;digest.update(src_value.data, src_value.size);digest.digest();result.append(Slice(digest.hex().c_str(), digest.hex().size()));} else if (length == 512) {SHA512Digest digest;digest.update(src_value.data, src_value.size);digest.digest();result.append(Slice(digest.hex().c_str(), digest.hex().size()));} else {result.append_null();}}return result.build(ColumnHelper::is_all_const(columns));}auto ctc = reinterpret_cast<SHA2Ctx*>(ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));return ctc->function(ctx, columns);}
CloseFunction
Status EncryptionFunctions::sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope) {if (scope == FunctionContext::FRAGMENT_LOCAL) {auto fc = reinterpret_cast<SHA2Ctx*>(context->get_function_state(scope));delete fc;}return Status::OK();}
增加对应的单元测试
TEST_P(ShaTestFixture, test_sha2) {auto [str, len, expected] = GetParam();std::unique_ptr<FunctionContext> ctx(FunctionContext::create_test_context());Columns columns;auto plain = BinaryColumn::create();plain->append(str);ColumnPtr hash_length =len == -1 ? ColumnHelper::create_const_null_column(1) : ColumnHelper::create_const_column<TYPE_INT>(len, 1);if (str == "NULL") {columns.emplace_back(ColumnHelper::create_const_null_column(1));} else {columns.emplace_back(plain);}columns.emplace_back(hash_length);ctx->set_constant_columns(columns);ASSERT_TRUE(EncryptionFunctions::sha2_prepare(ctx.get(), FunctionContext::FunctionStateScope::FRAGMENT_LOCAL).ok());if (len != -1) {ASSERT_NE(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));} else {ASSERT_EQ(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));}ColumnPtr result = EncryptionFunctions::sha2(ctx.get(), columns).value();if (expected == "NULL") {std::cerr << result->debug_string() << std::endl;EXPECT_TRUE(result->is_null(0));} else {auto v = ColumnHelper::cast_to<TYPE_VARCHAR>(result);EXPECT_EQ(expected, v->get_data()[0].to_string());}ASSERT_TRUE(EncryptionFunctions::sha2_close(ctx.get(),FunctionContext::FunctionContext::FunctionStateScope::FRAGMENT_LOCAL).ok());}
02
如何为 StarRocks 添加聚合函数
2-1
聚合函数介绍
2-1
聚合函数介绍
2-2
聚合函数的实现原理
2-2
在查询执行阶段,Pipeline 引擎的聚合算子通过 Aggregator 完成聚合计算,聚合算子的实现原理可参见文末《StarRocks 聚合算子源码解析》[9],本文主要关注聚合函数的实现原理。
Aggregator 在 prepare 阶段会根据函数名找到对应的 AggregateFunction 并保存下来,AggregateFunction 是最重要的抽象,封装了聚合计算过程中需要的各个接口,每个聚合函数都需要继承 AggregateFunction 实现自己的逻辑。计算的中间结果保存在 AggDataPtr 中,AggDataPrt 是一个指针,指向描述中间结果的数据结构。每种聚合函数的中间结果都不相同,比如求和函数,只需要保存 sum 即可,而平均值函数,除了保存 sum 之外,还需要记录 count。
// 逐行读取数据,不断更新 state 中保存的中间结果。void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, size_t row_num)// 通常用在多阶段聚合中,读取已经算好的部分中间结果,合并计算,更新 state 中的数据。void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num)// 多阶段的聚合可能会通过多个节点执行,计算的中间结果需要跨网络传输,这个方法用来实现序列化的逻辑。void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)// 把中间结果转成最终对用户返回的结果。比如求和函数,直接返回中间结果保存的 sum 即可,而平均值函数,需要返回 sum/count。void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)// 重置 state 的状态,比如在 window aggregate 中,我们会用一个的 state 保存中间结果,每次遇到新的 group时,需要通过 reset 重置,然后才能进行接下来的计算。void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state)
除了上述内容之外,为了减少函数调用的开销,AggregateFunction 还封装了批量操作的接口,具体的细节这里就不展开讲解了,可以参考 be/src/exprs/agg/aggregate.h
。
2-3
添加聚合函数示例
2-3

在 FE 创建函数签名
// ANY_VALUEaddBuiltin(AggregateFunction.createBuiltin(ANY_VALUE,Lists.newArrayList(t), t, t, true, false, false));
在 BE 实现函数的计算逻辑
ANY_VALUE的语义很简单,在每个 group 中选择一行返回。
AnyValueAggregateData描述,只需要记录当前是否已经有结果以及对应的数据是什么即可,
AnyValueAggregateData为每种数据类型进行了特化,实现上几乎一致。具体代码如下:
template <LogicalType LT>struct AnyValueAggregateData {using T = AggDataValueType<LT>;T result;bool has_value = false;void reset() {result = T{};has_value = false;}};
template <LogicalType LT, typename State>struct AnyValueElement {using RefType = AggDataRefType<LT>;void operator()(State& state, RefType right) const {if (UNLIKELY(!state.has_value)) {AggDataTypeTraits<LT>::assign_value(state.result, right);state.has_value = true;}}};
template <LogicalType LT, typename State, class OP, typename T = RunTimeC++Type<LT>, typename = guard::Guard>class AnyValueAggregateFunction final: public AggregateFunctionBatchHelper<State, AnyValueAggregateFunction<LT, State, OP, T>> {public:using InputColumnType = RunTimeColumnType<LT>;void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override {this->data(state).reset();}void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,size_t row_num) const override {DCHECK(!columns[0]->is_nullable());const auto& column = down_cast<const InputColumnType&>(*columns[0]);OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(column, row_num));}void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,AggDataPtr __restrict state) const override {update(ctx, columns, state, 0);}void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {DCHECK(!column->is_nullable());const auto& input_column = down_cast<const InputColumnType&>(*column);OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(input_column, row_num));}void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {DCHECK(!to->is_nullable());AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);}void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,ColumnPtr* dst) const override {*dst = src[0];}void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {DCHECK(!to->is_nullable());AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);}void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,size_t end) const override {DCHECK_GT(end, start);InputColumnType* column = down_cast<InputColumnType*>(dst);for (size_t i = start; i < end; ++i) {AggDataTypeTraits<LT>::append_value(column, this->data(state).result);}}std::string get_name() const override { return "any_value"; }};
be/src/exprs/agg/any_value.h在 AggregateFactory 中注册
MakeAnyValueAggregateFunction实现,相关的改动可以在
aggregate_factory.hpp[13]中 grep MakeAnyValueAggregateFunction
看到,比较简单,这里不再过多赘述,具体示例如下:
template <LogicalType LT>AggregateFunctionPtr AggregateFactory::MakeAnyValueAggregateFunction() {return std::make_shared<AnyValueAggregateFunction<LT, AnyValueAggregateData<LT>, AnyValueElement<LT, AnyValueAggregateData<LT>>>>();}
添加单元测试
test/exprs/agg/aggregate_test.cpp[14] 添加单测,比如:
TEST_F(AggregateTest, test_any_value) {const AggregateFunction* func = get_aggregate_function("any_value", TYPE_SMALLINT, TYPE_SMALLINT, false);test_non_deterministic_agg_function<int16_t, int16_t>(ctx, func);func = get_aggregate_function("any_value", TYPE_INT, TYPE_INT, false);test_non_deterministic_agg_function<int32_t, int32_t>(ctx, func);func = get_aggregate_function("any_value", TYPE_BIGINT, TYPE_BIGINT, false);test_non_deterministic_agg_function<int64_t, int64_t>(ctx, func);func = get_aggregate_function("any_value", TYPE_LARGEINT, TYPE_LARGEINT, false);test_non_deterministic_agg_function<int128_t, int128_t>(ctx, func);func = get_aggregate_function("any_value", TYPE_FLOAT, TYPE_FLOAT, false);test_non_deterministic_agg_function<float, float>(ctx, func);func = get_aggregate_function("any_value", TYPE_DOUBLE, TYPE_DOUBLE, false);test_non_deterministic_agg_function<double, double>(ctx, func);func = get_aggregate_function("any_value", TYPE_VARCHAR, TYPE_VARCHAR, false);test_non_deterministic_agg_function<Slice, Slice>(ctx, func);func = get_aggregate_function("any_value", TYPE_DECIMALV2, TYPE_DECIMALV2, false);test_non_deterministic_agg_function<DecimalV2Value, DecimalV2Value>(ctx, func);func = get_aggregate_function("any_value", TYPE_DATETIME, TYPE_DATETIME, false);test_non_deterministic_agg_function<TimestampValue, TimestampValue>(ctx, func);func = get_aggregate_function("any_value", TYPE_DATE, TYPE_DATE, false);test_non_deterministic_agg_function<DateValue, DateValue>(ctx, func);}
03
总结
be/src/exprs/ 目录下。若想查看某个函数的实现,可以在函数签名中找到对应的 be function,然后在该目录下使用 grep 进行查找。
be/src/exprs/agg目录下查找。

相关链接:
[5]FunctionDescriptor:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/builtin_functions.h#L32
[6]sha2 函数:https://docs.starrocks.io/zh-cn/latest/sql-reference/sql-functions/crytographic-functions/sha2#%E5%8A%9F%E8%83%BD
[7]EncryptionFunctions:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/encryption_functions.h
[8]EntryptionFunctionTest:https://github.com/StarRocks/starrocks/blob/main/be/test/exprs/encryption_functions_test.cpp
[9]《StarRocks 聚合算子源码解析》:https://zhuanlan.zhihu.com/p/592058276
[10]ANY_VALUE 功能:https://docs.starrocks.io/zh-cn/latest/sql-reference/sql-functions/aggregate-functions/any_value#%E5%8A%9F%E8%83%BD
[12]initAggregateBuiltins:https://github.com/StarRocks/starrocks/blob/main/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java#L742
[13]aggregate_factory.cpp:https://github.com/StarRocks/starrocks/blob/main/be/src/exprs/agg/factory/aggregate_factory.hpp

关于 StarRocks






