EsDenseVectorService.java 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package com.example.xiaoshiweixinback.service.common;
  2. import co.elastic.clients.elasticsearch.ElasticsearchClient;
  3. import co.elastic.clients.elasticsearch._types.InlineScript;
  4. import co.elastic.clients.elasticsearch._types.Script;
  5. import co.elastic.clients.elasticsearch._types.query_dsl.Query;
  6. import co.elastic.clients.elasticsearch._types.query_dsl.QueryBuilders;
  7. import co.elastic.clients.elasticsearch.core.SearchRequest;
  8. import co.elastic.clients.elasticsearch.core.SearchResponse;
  9. import co.elastic.clients.elasticsearch.core.search.Hit;
  10. import co.elastic.clients.json.JsonData;
  11. import com.example.xiaoshiweixinback.business.utils.BeanUtil;
  12. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.expressManager;
  13. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.operateNode;
  14. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.treeNode;
  15. import com.example.xiaoshiweixinback.domain.es.PatentVector;
  16. import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPatentSearchDTO;
  17. import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPictureNoDTO;
  18. import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPictureNoVo;
  19. import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPictureVectorVo;
  20. import com.example.xiaoshiweixinback.service.importPatent.FormatQueryService;
  21. import lombok.RequiredArgsConstructor;
  22. import org.apache.commons.lang3.StringUtils;
  23. import org.springframework.beans.factory.annotation.Autowired;
  24. import org.springframework.context.annotation.Lazy;
  25. import org.springframework.stereotype.Service;
  26. import java.io.File;
  27. import java.io.IOException;
  28. import java.util.ArrayList;
  29. import java.util.List;
  30. @Service
  31. @RequiredArgsConstructor(onConstructor_ = {@Lazy})
  32. public class EsDenseVectorService {
  33. private final ElasticsearchClient client;
  34. @Autowired
  35. private FormatQueryService formatQueryService;
  36. @Autowired
  37. private GetVectorService getVectorService;
  38. public List<EsPictureVectorVo> getPatentList(EsPatentSearchDTO dto) throws Exception {
  39. Long pageNum = dto.getPageNum();
  40. Long pageSize = dto.getPageSize();
  41. // String key = dto.getKey().replaceAll("[,。、;,./;\\s]"," OR ");
  42. String key = dto.getKey().replaceAll("[,。、;,./;]"," OR ");
  43. String s = "TI = " + "(" + key + ")";
  44. System.out.println(s);
  45. SearchRequest.Builder builder = new SearchRequest.Builder();
  46. //设置查询索引
  47. builder.index("patent_vector");
  48. //1. 解析检索条件
  49. treeNode tree = expressManager.getInstance().Parse(s, false);
  50. //3. 从es中检索数据
  51. Query query = formatQueryService.EsQueryToQuery((operateNode) tree, "patentVector", null);
  52. builder.query(query);
  53. //分页
  54. if (pageNum != null && pageSize != null && pageNum > 0 && pageSize > 0) {
  55. builder.from((pageNum.intValue() - 1) * pageSize.intValue()).size(pageSize.intValue());
  56. }
  57. // else {
  58. // builder.from(0).size(9);
  59. // }
  60. SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
  61. List<EsPictureVectorVo> vectorVos = new ArrayList<>();
  62. List<Hit<PatentVector>> hits = response.hits().hits();
  63. for (Hit<PatentVector> hit : hits) {
  64. PatentVector vector = hit.source();
  65. EsPictureVectorVo vectorVo = new EsPictureVectorVo();
  66. BeanUtil.copy(vector,vectorVo);
  67. vectorVos.add(vectorVo);
  68. }
  69. return vectorVos;
  70. }
  71. public List<EsPictureVectorVo> getPatentVectorSort(File file, String description) throws IOException {
  72. List<Float> imageList = new ArrayList<>();
  73. List<String> stringList = new ArrayList<>();
  74. if (file != null) {
  75. stringList = getVectorService.getVectorByFile(file);
  76. } else if (StringUtils.isNotEmpty(description)) {
  77. stringList = getVectorService.getVectorByText(description);
  78. }
  79. stringList.forEach(item -> {
  80. Float a = Float.parseFloat(item);
  81. imageList.add(a);
  82. });
  83. List<EsPictureVectorVo> list = new ArrayList<>();
  84. SearchRequest.Builder builder = new SearchRequest.Builder();
  85. //设置查询索引
  86. builder.index("patent_vector");
  87. String source = "cosineSimilarity(params.queryVector, 'my_vector') + 1.0";
  88. InlineScript inlineScript = InlineScript.of(i -> i.lang("painless").params("queryVector", JsonData.of(imageList)).source(source));
  89. Script script = Script.of(i -> i.inline(inlineScript));
  90. Query query = QueryBuilders.scriptScore(i -> i.script(script).query(org.springframework.data.elasticsearch.client.elc.QueryBuilders.matchAllQueryAsQuery()));
  91. builder.query(query);
  92. builder.size(100);
  93. SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
  94. List<Hit<PatentVector>> hits = response.hits().hits();
  95. for (Hit<PatentVector> hit : hits) {
  96. PatentVector vector = hit.source();
  97. EsPictureVectorVo vectorVo = new EsPictureVectorVo();
  98. BeanUtil.copy(vector,vectorVo);
  99. list.add(vectorVo);
  100. }
  101. return list;
  102. }
  103. public List<EsPictureNoVo> getPictureByNo(EsPictureNoDTO noDTO) throws IOException {
  104. List<EsPictureNoVo> pictureNoVos = new ArrayList<>();
  105. SearchRequest.Builder builder = new SearchRequest.Builder();
  106. //设置查询索引
  107. builder.index("patent_vector");
  108. Query query = QueryBuilders.term(i -> i.field("app_no.keyword").value(noDTO.getAppNo()));
  109. builder.query(query);
  110. builder.size(100);
  111. SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
  112. List<Hit<PatentVector>> hits = response.hits().hits();
  113. for (Hit<PatentVector> hit : hits) {
  114. PatentVector vector = hit.source();
  115. EsPictureNoVo noVo = new EsPictureNoVo();
  116. noVo.setGuid(vector.getGuid());
  117. noVo.setImageIndex(vector.getImageIndex());
  118. pictureNoVos.add(noVo);
  119. }
  120. return pictureNoVos;
  121. }
  122. }