暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

进一寸有进一寸的欢喜,谈谈如何优化 Milvus 数据库的向量查询功能

ZILLIZ 2021-11-04
469

✏️ 编者按

每年暑期,Milvus 社区都会携手中科院软件所,在「开源之夏」活动中为高校学生们准备丰富的工程项目,并安排导师答疑解惑。张煜旻同学在「开源之夏」活动中表现优秀,相信进一寸有进一寸的欢喜,尝试在贡献开源的过程中超越自我。

他的项目为 Milvus 数据库的向量查询操作提供精度控制,能让开发者自定义返回精度,在减少内存消耗的同时,提高了返回结果的可读性。

想要了解更多优质开源项目和项目经验分享?请戳:有哪些值得参与的开源项目?

项目简介

项目名称:支持指定搜索时返回的距离精度
学生简介:张煜旻,中国科学院大学电子信息软件工程专业硕士在读
项目导师:Zilliz 软件工程师张财
导师评语:张煜旻同学优化了 Milvus 数据库的查询功能,使其在搜索时可以用指定精度去进行查询,使搜索过程更灵活,用户可以根据自己的需求用不同的精度进行查询,给用户带来了便利。

支持指定搜索时返回的距离精度


  任务简介
在进行向量查询时,搜索请求返回 id 和 distance 字段,其中的 distance 字段类型是浮点数。Milvus 数据库所计算的距离是一个 32 位浮点数,但是 Python SDK 返回并以 64 位浮点显示它,导致某些精度无效。本项目的贡献是,支持指定搜索时返回的距离精度,解决了在 Python 端显示时部分精度无效的情况,并减少部分内存开销。
  项目目标
  • 解决计算结果和显示精度不匹配的问题

  • 支持搜索时返回指定的距离精度

  • 补充相关文档

  项目步骤
  • 前期调研,理解 Milvus 整体框架

  • 明确各模块之间的调用关系

  • 设计解决方案和确认结果

  项目综述
什么是 Milvus 数据库?
Milvus 是一款开源向量数据库,赋能 AI 应用和向量相似度搜索。在系统设计上, Milvus 数据库的前端有方便用户使用的 Python SDK(Client);在 Milvus 数据库的后端,整个系统分为了接入层(Access Layer)、协调服务(Coordinator Server)、执行节点(Worker Node)和存储服务(Storge)四个层面:
(1)接入层(Access Layer):系统的门面,包含了一组对等的 Proxy 节点。接入层是暴露给用户的统一 endpoint,负责转发请求并收集执行结果。
(2)协调服务(Coordinator Service):系统的大脑,负责分配任务给执行节点。共有四类协调者角色:root 协调者、data 协调者、query 协调者和 index 协调者。
(3)执行节点(Worker Node):系统的四肢,执行节点只负责被动执行协调服务发起的读写请求。目前有三类执行节点:data 节点、query 节点和 index 节点。
(4)存储服务(Storage):系统的骨骼,是所有其他功能实现的基础。Milvus  数据库依赖三类存储:元数据存储、消息存储(log broker)和对象存储。从语言角度来看,则可以看作三个语言层,分别是 Python 构成的 SDK 层、Go 构成的中间层和 C++ 构成的核心计算层。
Milvus 数据库的架构图
向量查询 Search 时,到底发生了什么?
在 Python SDK 端,当用户发起一个 Search API 调用时,这个调用会被封装成 gRPC 请求并发送给 Milvus 后端,同时 SDK 开始等待。而在后端,Proxy 节点首先接受了从 Python SDK 发送过来的请求,然后会对接受的请求进行处理,最后将其封装成 message,经由 Producer 发送到消费队列中。当消息被发送到消费队列后,Coordinator 将会对其进行协调,将信息发送到合适的 query node 中进行消费。而当 query node 接收到消息后,则会对消息进行进一步的处理,最后将信息传递给由 C++ 构成的计算层。在计算层,则会根据不同的情形,调用不同的计算函数对向量间的距离进行计算。当计算完成后,结果则会依次向上传递,直到到达 SDK 端。
  解决方案设计
通过前文简单介绍,我们对向量查询的过程有了一个大致的概念。同时,我们也可以清楚地认识到,为了完成查询目标,我们需要对 Python 构成的 SDK 层、Go 构成的中间层和 C++ 构成的计算层都进行修改,修改方案如下:
1. 在 Python 层中的修改步骤:
为向量查询 Search 请求添加一个 round_decimal 参数,从而确定返回的精度信息。同时,需要对参数进行一些合法性检查和异常处理,从而构建 gRPC 的请求:
    round_decimal = param_copy("round_decimal", 3)
    if not isinstance(round_decimal, (int, str))
    raise ParamError("round_decimal must be int or str")
    try:
    round_decimal = int(round_decimal)
    except Exception:
    raise ParamError("round_decimal is not illegal")

    if round_decimal < 0 or round_decimal > 6:
    raise ParamError("round_decimal must be greater than zero and less than seven")
    if not instance(params, dict):
    raise ParamError("Search params must be a dict")
    search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type, "params": params, "round_decimal": round_decimal}
    2. 在 Go 层中的修改步骤:
    在 task.go 文件中添加 RoundDecimalKey 这个常量,保持风格统一并方便后续调取:
      const (
      InsertTaskName = "InsertTask"
      CreateCollectionTaskName = "CreateCollectionTask"
      DropCollectionTaskName = "DropCollectionTask"
      SearchTaskName = "SearchTask"
      RetrieveTaskName = "RetrieveTask"
      QueryTaskName = "QueryTask"
      AnnsFieldKey = "anns_field"
      TopKKey = "topk"
      MetricTypeKey = "metric_type"
      SearchParamsKey = "params"
      RoundDecimalKey = "round_decimal"
      HasCollectionTaskName = "HasCollectionTask"
      DescribeCollectionTaskName = "DescribeCollectionTask"
      接着,修改 PreExecute 函数,获取 round_decimal 的值,构建 queryInfo 变量,并添加异常处理:
        searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
        if err != nil {
        return errors.New(SearchParamsKey + " not found in search_params")
        }
        roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
        if err != nil {
        return errors.New(RoundDecimalKey + "not found in search_params")
        }
        roundDeciaml, err := strconv.Atoi(roundDecimalStr)
        if err != nil {
        return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is not invalid")
        }

        queryInfo := &planpb.QueryInfo{
        Topk: int64(topK),
        MetricType: metricType,
        SearchParams: searchParams,
        RoundDecimal: int64(roundDeciaml),
        }
        同时,修改 query 的 proto 文件,为 QueryInfo 添加 round_decimal 变量:
          message QueryInfo {
          int64 topk = 1;
          string metric_type = 3;
          string search_params = 4;
          int64 round_decimal = 5;
          }
          3. 在 C++ 层中的修改步骤:
          在 SearchInfo 结构体中添加新的变量 round_decimal_ ,从而接受 Go 层传来的 round_decimal 值:
            struct SearchInfo {
            int64_t topk_;
            int64_t round_decimal_;
            FieldOffset field_offset_;
            MetricType metric_type_;
            nlohmann::json search_params_;
            };
            在 ParseVecNode 和 PlanNodeFromProto 函数中,SearchInfo 结构体需要接受 Go 层中 round_decimal 值:
              std::unique_ptr<VectorPlanNode>
              Parser::ParseVecNode(const Json& out_body) {
              Assert(out_body.is_object());
              Assert(out_body.size() == 1);
              auto iter = out_body.begin();
              auto field_name = FieldName(iter.key());

              auto& vec_info = iter.value();
              Assert(vec_info.is_object());
              auto topk = vec_info["topk"];
              AssertInfo(topk > 0, "topk must greater than 0");
              AssertInfo(topk < 16384, "topk is too large");

              auto field_offset = schema.get_offset(field_name);

              auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
              auto& field_meta = schema.operator[](field_name);
              auto data_type = field_meta.get_data_type();
              if (data_type == DataType::VECTOR_FLOAT) {
              return std::make_unique<FloatVectorANNS>();
              } else {
              return std::make_unique<BinaryVectorANNS>();
              }
              }();
              vec_node->search_info_.topk_ = topk;
              vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type"));
              vec_node->search_info_.search_params_ = vec_info.at("params");
              vec_node->search_info_.field_offset_ = field_offset;
              vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");
              vec_node->placeholder_tag_ = vec_info.at("query");
              auto tag = vec_node->placeholder_tag_;
              AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
              tag2field_.emplace(tag, field_offset);
              return vec_node;
              }
                std::unique_ptr<VectorPlanNode>
                ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
                // TODO: add more buffs
                Assert(plan_node_proto.has_vector_anns());
                auto& anns_proto = plan_node_proto.vector_anns();
                auto expr_opt = [&]() -> std::optional<ExprPtr> {
                if (!anns_proto.has_predicates()) {
                return std::nullopt;
                } else {
                return ParseExpr(anns_proto.predicates());
                }
                }();

                auto& query_info_proto = anns_proto.query_info();

                SearchInfo search_info;
                auto field_id = FieldId(anns_proto.field_id());
                auto field_offset = schema.get_offset(field_id);
                search_info.field_offset_ = field_offset;

                search_info.metric_type_ = GetMetricType(query_info_proto.metric_type());
                search_info.topk_ = query_info_proto.topk();
                search_info.round_decimal_ = query_info_proto.round_decimal();
                search_info.search_params_ = json::parse(query_info_proto.search_params());

                auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
                if (anns_proto.is_binary()) {
                return std::make_unique<BinaryVectorANNS>();
                } else {
                return std::make_unique<FloatVectorANNS>();
                }
                }();
                plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
                plan_node->predicate_ = std::move(expr_opt);
                plan_node->search_info_ = std::move(search_info);
                return plan_node;
                }
                在 SubSearchResult 类添加新的成员变量 round_decimal,同时修改每一处的 SubSearchResult 变量声明:
                  class SubSearchResult {
                  public:
                  SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type)
                  : metric_type_(metric_type),
                  num_queries_(num_queries),
                  topk_(topk),
                  labels_(num_queries * topk, -1),
                  values_(num_queries * topk, init_value(metric_type)) {
                      }
                  在 SubSearchResult 类添加一个新的成员函数,以便最后对每一个结果进行四舍五入精度控制:
                    void
                    SubSearchResult::round_values() {
                    if (round_decimal_ == -1)
                    return;
                    const float multiplier = pow(10.0, round_decimal_);
                    for (auto it = this->values_.begin(); it != this->values_.end(); it++) {
                    *it = round(*it * multiplier) multiplier;
                    }
                    }
                    为 SearchDataset 结构体添加新的变量 round_decimal,同时修改每一处的 SearchDataset 变量声明:
                      struct SearchDataset {
                      MetricType metric_type;
                      int64_t num_queries;
                      int64_t topk;
                      int64_t round_decimal;
                      int64_t dim;
                      const void* query_data;
                      };
                      修改 C++ 层中各个距离计算函数(FloatSearch、BinarySearchBruteForceFast 等等),使其接受 round_decomal 值:
                        StatusFloatSearch(const segcore::SegmentGrowingImpl& segment, const query::SearchInfo& info, const float* query_data, int64_t num_queries, int64_t ins_barrier, const BitsetView& bitset, SearchResult& results) { auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); auto& record = segment.get_insert_record(); // step 1: binary search to find the barrier of the snapshot auto del_barrier = get_barrier(deleted_record_, timestamp);#if 0 auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); Assert(bitmap_holder); auto bitmap = bitmap_holder->bitmap_ptr;#endif step 2.1: get meta step 2.2: get which vector field to search auto vecfield_offset = info.field_offset_; auto& field = schema[vecfield_offset]; AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT"); auto dim = field.get_dim(); auto topk = info.topk_; auto total_count = topk * num_queries; auto metric_type = info.metric_type_; auto round_decimal = info.round_decimal_; step 3: small indexing search / std::vector<int64_t> final_uids(total_count, -1); // std::vector<float> final_dis(total_count, std::numeric_limits<float>::max()); SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal); dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; auto vec_ptr = record.get_field_data<FloatVector>(vecfield_offset); int current_chunk_id = 0;
                          SubSearchResult
                          BinarySearchBruteForceFast(MetricType metric_type,
                          int64_t dim,
                          const uint8_t* binary_chunk,
                          int64_t size_per_chunk,
                          int64_t topk,
                          int64_t num_queries,
                          int64_t round_decimal,
                          const uint8_t* query_data,
                          const faiss::BitsetView& bitset) {
                          SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
                          float* result_distances = sub_result.get_values();
                          idx_t* result_labels = sub_result.get_labels();


                          int64_t code_size = dim / 8;
                          const idx_t block_size = size_per_chunk;


                          raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
                          result_labels, bitset);
                          sub_result.round_values();
                          return sub_result;
                          }

                            结果确认
                          1. 对 Milvus 数据库进行重新编译:
                          2. 启动环境容器:
                          3. 启动 Milvus 数据库:
                          4.构建向量查询请求:
                          5. 确认结果,默认保留 3 位小数,0 舍去:

                            总结和感想
                          参加这次的夏季开源活动,对我来说是非常宝贵的经历。在这次活动中,我第一次尝试阅读开源项目代码,第一次尝试接触多语言构成的项目,第一次接触到 Make、gRPc、pytest 等等。在编写代码和测试代码阶段,我也遇到来许多意想不到的问题,例如,「奇奇怪怪」的依赖问题、由于 Conda 环境导致的编译失败问题、测试无法通过等等。面对这些问题,我渐渐学会耐心细心地查看报错日志,积极思考、检查代码并进行测试,一步一步缩小错误范围,定位错误代码并尝试各种解决方案。
                          通过这次的活动,我吸取经验和教训,同时也十分感谢张财导师,感谢他在我开发过程中耐心地帮我答疑解惑、指导方向!同时,希望大家能多多关注 Milvus 社区相信一定能够有所收获!
                          最后,欢迎大家多多与我交流(📮 deepmin@mail.deepexplore.top ),我主要的研究方向是自然语言处理,平时喜欢看科幻小说、动画和折腾服务器个人网站,每日闲逛 Stack Overflow 和GitHub。我相信进一寸有进一寸的欢喜,希望能和你一起共同进步。


                          Zilliz 以重新定义数据科学为愿景,致力于打造一家全球领先的开源技术创新公司,并通过开源和云原生解决方案为企业解锁非结构化数据的隐藏价值。
                          Zilliz 构建了 Milvus 向量数据库,以加快下一代数据平台的发展。Milvus 数据库是 LF AI & Data 基金会的毕业项目,能够管理大量非结构化数据集,在新药发现、推荐系统、聊天机器人等方面具有广泛的应用。
                          解锁更多应用场景
                          最后修改时间:2021-11-04 08:23:23
                          文章转载自ZILLIZ,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

                          评论