Explorar o código

fixed es search

zero hai 1 ano
pai
achega
715888c851

+ 21 - 5
src/main/java/com/example/xiaoshiweixinback/service/common/EsDenseVectorService.java

@@ -3,6 +3,9 @@ package com.example.xiaoshiweixinback.service.common;
 import co.elastic.clients.elasticsearch.ElasticsearchClient;
 import co.elastic.clients.elasticsearch._types.InlineScript;
 import co.elastic.clients.elasticsearch._types.Script;
+import co.elastic.clients.elasticsearch._types.aggregations.Aggregate;
+import co.elastic.clients.elasticsearch._types.aggregations.Aggregation;
+import co.elastic.clients.elasticsearch._types.aggregations.AggregationBuilders;
 import co.elastic.clients.elasticsearch._types.query_dsl.Query;
 import co.elastic.clients.elasticsearch._types.query_dsl.QueryBuilders;
 import co.elastic.clients.elasticsearch.core.SearchRequest;
@@ -88,7 +91,7 @@ public class EsDenseVectorService {
             Float a = Float.parseFloat(item);
             imageList.add(a);
         });
-
+        System.out.println(imageList);
         if (!CollectionUtils.isEmpty(imageList)) {
             String source = "cosineSimilarity(params.queryVector, 'my_vector') + 1.0";
             InlineScript inlineScript = InlineScript.of(i -> i.lang("painless").params("queryVector", JsonData.of(imageList)).source(source));
@@ -109,6 +112,9 @@ public class EsDenseVectorService {
         //根据申请号去重
         FieldCollapse collapse = FieldCollapse.of(i -> i.field("app_no"));
         builder.collapse(collapse);
+        //统计总数
+        Aggregation aggregation = AggregationBuilders.cardinality(i -> i.field("app_no"));
+        builder.aggregations("count", aggregation);
 
         //分页
         if (pageNum != null && pageSize != null && pageNum > 0 && pageSize > 0) {
@@ -120,18 +126,28 @@ public class EsDenseVectorService {
         SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
         List<Hit<PatentVector>> hits = response.hits().hits();
         List<EsPatentVectorVo> vectorVos = new ArrayList<>();
-        long total = response.hits().total().value();
         Double fixedScore = 1.8d;
-        for (Hit<PatentVector> hit : hits) {
-            Double score = hit.score();
-            if (score > fixedScore) {
+        if (hits.size() < 10) {
+            for (Hit<PatentVector> hit : hits) {
                 PatentVector vector = hit.source();
                 EsPatentVectorVo vectorVo = new EsPatentVectorVo();
                 BeanUtil.copy(vector, vectorVo);
                 vectorVos.add(vectorVo);
             }
+        } else {
+            for (Hit<PatentVector> hit : hits) {
+                Double score = hit.score();
+                if (score > fixedScore) {
+                    PatentVector vector = hit.source();
+                    EsPatentVectorVo vectorVo = new EsPatentVectorVo();
+                    BeanUtil.copy(vector, vectorVo);
+                    vectorVos.add(vectorVo);
+                }
+            }
         }
 
+        Aggregate aggregate = response.aggregations().get("count");
+        long total = aggregate.cardinality().value();
         Records records = new Records();
         records.setCurrent(pageNum);
         records.setSize(pageSize);