✏️ 编者按
每年暑期，Milvus 社区都会携手中科院软件所，在「开源之夏」活动中为高校学生们准备丰富的工程项目，并安排导师答疑解惑。张煜旻同学在「开源之夏」活动中表现优秀，他相信“进一寸有进一寸的欢喜”，努力在为开源社区做贡献的过程中超越自我。
他的项目为 Milvus 数据库的向量查询操作提供精度控制，让开发者可以自定义返回距离的精度，在减少内存消耗的同时提高了返回结果的可读性。
想要了解更多优质开源项目和项目经验分享?请戳:有哪些值得参与的开源项目?
项目简介
支持指定搜索时返回的距离精度
解决计算结果和显示精度不匹配的问题
支持搜索时返回指定的距离精度
补充相关文档
前期调研,理解 Milvus 整体框架
明确各模块之间的调用关系
设计解决方案和确认结果

# Validate the requested result precision (number of decimal places kept in
# returned distances) and assemble the search parameters. The precision is
# read from the copied search param dict and defaults to 3 here.
round_decimal = param_copy.get("round_decimal", 3)
# bool is a subclass of int, so reject it explicitly before the type check.
if isinstance(round_decimal, bool) or not isinstance(round_decimal, (int, str)):
    raise ParamError("round_decimal must be int or str")
try:
    # int() on an int never fails; on a str it raises ValueError when the
    # text is not an integer literal.
    round_decimal = int(round_decimal)
except ValueError as e:
    raise ParamError("round_decimal is illegal") from e
if round_decimal < 0 or round_decimal > 6:
    raise ParamError("round_decimal must be between 0 and 6")
if not isinstance(params, dict):
    raise ParamError("Search params must be a dict")
search_params = {
    "anns_field": anns_field,
    "topk": limit,
    "metric_type": metric_type,
    "params": params,
    "round_decimal": round_decimal,
}
const (InsertTaskName = "InsertTask"CreateCollectionTaskName = "CreateCollectionTask"DropCollectionTaskName = "DropCollectionTask"SearchTaskName = "SearchTask"RetrieveTaskName = "RetrieveTask"QueryTaskName = "QueryTask"AnnsFieldKey = "anns_field"TopKKey = "topk"MetricTypeKey = "metric_type"SearchParamsKey = "params"RoundDecimalKey = "round_decimal"HasCollectionTaskName = "HasCollectionTask"DescribeCollectionTaskName = "DescribeCollectionTask"
// Extract the index search params and the requested result precision
// (round_decimal) from the request's key/value search params, then build the
// QueryInfo that is forwarded in the search plan.
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
if err != nil {
	return errors.New(SearchParamsKey + " not found in search_params")
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
if err != nil {
	return errors.New(RoundDecimalKey + " not found in search_params")
}
// round_decimal arrives as a string and must parse as a base-10 integer.
roundDecimal, err := strconv.Atoi(roundDecimalStr)
if err != nil {
	return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is invalid")
}
queryInfo := &planpb.QueryInfo{
	Topk:         int64(topK),
	MetricType:   metricType,
	SearchParams: searchParams,
	RoundDecimal: int64(roundDecimal),
}
// Per-search parameters carried in the search plan. Built by the Go search
// task (planpb.QueryInfo) and read on the C++ side by
// ProtoParser::PlanNodeFromProto into a SearchInfo.
message QueryInfo {
  // Number of results requested per query.
  int64 topk = 1;
  // NOTE(review): field number 2 is not used in this message — presumably a
  // removed field; confirm it is reserved before reusing the number.
  string metric_type = 3;
  // Search parameters as a string; the C++ side parses it with json::parse.
  string search_params = 4;
  // Number of decimal places kept in returned distances; on the C++ side a
  // value of -1 skips rounding.
  int64 round_decimal = 5;
}
// Parameters of a single vector (ANNS) search. Populated by
// Parser::ParseVecNode (JSON DSL) and ProtoParser::PlanNodeFromProto
// (protobuf QueryInfo).
//
// The integer members carry default initializers so that a default-constructed
// SearchInfo never exposes indeterminate values.
struct SearchInfo {
    int64_t topk_ = 0;              // number of results requested per query
    int64_t round_decimal_ = -1;    // decimal places kept in distances; -1 = no rounding
    FieldOffset field_offset_;      // offset of the searched vector field in the schema
    MetricType metric_type_;        // distance metric used for the search
    nlohmann::json search_params_;  // the request's "params" object
};
std::unique_ptr<VectorPlanNode>Parser::ParseVecNode(const Json& out_body) {Assert(out_body.is_object());Assert(out_body.size() == 1);auto iter = out_body.begin();auto field_name = FieldName(iter.key());auto& vec_info = iter.value();Assert(vec_info.is_object());auto topk = vec_info["topk"];AssertInfo(topk > 0, "topk must greater than 0");AssertInfo(topk < 16384, "topk is too large");auto field_offset = schema.get_offset(field_name);auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {auto& field_meta = schema.operator[](field_name);auto data_type = field_meta.get_data_type();if (data_type == DataType::VECTOR_FLOAT) {return std::make_unique<FloatVectorANNS>();} else {return std::make_unique<BinaryVectorANNS>();}}();vec_node->search_info_.topk_ = topk;vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type"));vec_node->search_info_.search_params_ = vec_info.at("params");vec_node->search_info_.field_offset_ = field_offset;vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");vec_node->placeholder_tag_ = vec_info.at("query");auto tag = vec_node->placeholder_tag_;AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");tag2field_.emplace(tag, field_offset);return vec_node;}
std::unique_ptr<VectorPlanNode>ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {// TODO: add more buffsAssert(plan_node_proto.has_vector_anns());auto& anns_proto = plan_node_proto.vector_anns();auto expr_opt = [&]() -> std::optional<ExprPtr> {if (!anns_proto.has_predicates()) {return std::nullopt;} else {return ParseExpr(anns_proto.predicates());}}();auto& query_info_proto = anns_proto.query_info();SearchInfo search_info;auto field_id = FieldId(anns_proto.field_id());auto field_offset = schema.get_offset(field_id);search_info.field_offset_ = field_offset;search_info.metric_type_ = GetMetricType(query_info_proto.metric_type());search_info.topk_ = query_info_proto.topk();search_info.round_decimal_ = query_info_proto.round_decimal();search_info.search_params_ = json::parse(query_info_proto.search_params());auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {if (anns_proto.is_binary()) {return std::make_unique<BinaryVectorANNS>();} else {return std::make_unique<FloatVectorANNS>();}}();plan_node->placeholder_tag_ = anns_proto.placeholder_tag();plan_node->predicate_ = std::move(expr_opt);plan_node->search_info_ = std::move(search_info);return plan_node;}
class SubSearchResult {public:SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type): metric_type_(metric_type),num_queries_(num_queries),topk_(topk),labels_(num_queries * topk, -1),values_(num_queries * topk, init_value(metric_type)) {}
voidSubSearchResult::round_values() {if (round_decimal_ == -1)return;const float multiplier = pow(10.0, round_decimal_);for (auto it = this->values_.begin(); it != this->values_.end(); it++) {*it = round(*it * multiplier) multiplier;}}
// Arguments of one search call against a chunk of vectors. FloatSearch builds
// it with aggregate initialization. The default member initializers only take
// effect for fields that are not explicitly initialized; they keep a partially
// initialized SearchDataset from holding indeterminate values.
struct SearchDataset {
    MetricType metric_type;
    int64_t num_queries = 0;
    int64_t topk = 0;
    int64_t round_decimal = -1;  // decimal places kept in distances; -1 = no rounding
    int64_t dim = 0;
    // Query vectors; not owned by this struct (presumably the caller's
    // buffer — confirm lifetime at call sites).
    const void* query_data = nullptr;
};
StatusFloatSearch(const segcore::SegmentGrowingImpl& segment, const query::SearchInfo& info, const float* query_data, int64_t num_queries, int64_t ins_barrier, const BitsetView& bitset, SearchResult& results) { auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); auto& record = segment.get_insert_record(); // step 1: binary search to find the barrier of the snapshot auto del_barrier = get_barrier(deleted_record_, timestamp);#if 0 auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); Assert(bitmap_holder); auto bitmap = bitmap_holder->bitmap_ptr;#endif step 2.1: get meta step 2.2: get which vector field to search auto vecfield_offset = info.field_offset_; auto& field = schema[vecfield_offset]; AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT"); auto dim = field.get_dim(); auto topk = info.topk_; auto total_count = topk * num_queries; auto metric_type = info.metric_type_; auto round_decimal = info.round_decimal_; step 3: small indexing search / std::vector<int64_t> final_uids(total_count, -1); // std::vector<float> final_dis(total_count, std::numeric_limits<float>::max()); SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal); dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; auto vec_ptr = record.get_field_data<FloatVector>(vecfield_offset); int current_chunk_id = 0;
SubSearchResultBinarySearchBruteForceFast(MetricType metric_type,int64_t dim,const uint8_t* binary_chunk,int64_t size_per_chunk,int64_t topk,int64_t num_queries,int64_t round_decimal,const uint8_t* query_data,const faiss::BitsetView& bitset) {SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);float* result_distances = sub_result.get_values();idx_t* result_labels = sub_result.get_labels();int64_t code_size = dim / 8;const idx_t block_size = size_per_chunk;raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,result_labels, bitset);sub_result.round_values();return sub_result;}






最后修改时间:2021-11-04 08:23:23
文章转载自ZILLIZ,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。




