EsDenseVectorService.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package com.example.xiaoshiweixinback.service.common;
  2. import co.elastic.clients.elasticsearch.ElasticsearchClient;
  3. import co.elastic.clients.elasticsearch._types.InlineScript;
  4. import co.elastic.clients.elasticsearch._types.Script;
  5. import co.elastic.clients.elasticsearch._types.aggregations.Aggregate;
  6. import co.elastic.clients.elasticsearch._types.aggregations.Aggregation;
  7. import co.elastic.clients.elasticsearch._types.aggregations.AggregationBuilders;
  8. import co.elastic.clients.elasticsearch._types.query_dsl.Query;
  9. import co.elastic.clients.elasticsearch._types.query_dsl.QueryBuilders;
  10. import co.elastic.clients.elasticsearch.core.SearchRequest;
  11. import co.elastic.clients.elasticsearch.core.SearchResponse;
  12. import co.elastic.clients.elasticsearch.core.search.FieldCollapse;
  13. import co.elastic.clients.elasticsearch.core.search.Hit;
  14. import co.elastic.clients.json.JsonData;
  15. import com.example.xiaoshiweixinback.business.common.base.Records;
  16. import com.example.xiaoshiweixinback.business.utils.BeanUtil;
  17. import com.example.xiaoshiweixinback.business.utils.ToolUtil;
  18. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.expressManager;
  19. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.operateNode;
  20. import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.treeNode;
  21. import com.example.xiaoshiweixinback.domain.Product;
  22. import com.example.xiaoshiweixinback.domain.es.PatentVector;
  23. import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPictureNoDTO;
  24. import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPatentVectorDTO;
  25. import com.example.xiaoshiweixinback.entity.dto.searchRecord.AddSearchRecordDTO;
  26. import com.example.xiaoshiweixinback.entity.product.ProductIdDTO;
  27. import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPictureNoVo;
  28. import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPatentVectorVo;
  29. import com.example.xiaoshiweixinback.mapper.ProductMapper;
  30. import com.example.xiaoshiweixinback.service.SearchRecordService;
  31. import com.example.xiaoshiweixinback.service.importPatent.FormatQueryService;
  32. import lombok.RequiredArgsConstructor;
  33. import org.apache.commons.lang3.StringUtils;
  34. import org.springframework.beans.factory.annotation.Autowired;
  35. import org.springframework.context.annotation.Lazy;
  36. import org.springframework.stereotype.Service;
  37. import org.springframework.transaction.annotation.Propagation;
  38. import org.springframework.transaction.annotation.Transactional;
  39. import org.springframework.util.CollectionUtils;
  40. import java.io.File;
  41. import java.io.IOException;
  42. import java.text.ParseException;
  43. import java.text.SimpleDateFormat;
  44. import java.util.*;
  45. import java.util.stream.Collectors;
  46. @Service
  47. @RequiredArgsConstructor(onConstructor_ = {@Lazy})
  48. public class EsDenseVectorService {
  49. private final ElasticsearchClient client;
  50. @Autowired
  51. private FormatQueryService formatQueryService;
  52. @Autowired
  53. private GetVectorService getVectorService;
  54. @Autowired
  55. private SearchRecordService searchRecordService;
  56. @Autowired
  57. private ProductMapper productMapper;
  58. /**
  59. * 根据图片排序获取列表
  60. *
  61. * @param dto
  62. * @return
  63. * @throws IOException
  64. */
  65. @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Throwable.class)
  66. public Records getPatentVectors(EsPatentVectorDTO dto,File file) throws Exception {
  67. Long pageNum = dto.getPageNum();
  68. Long pageSize = dto.getPageSize();
  69. SearchRequest.Builder builder = new SearchRequest.Builder();
  70. //设置查询索引
  71. builder.index("patent_vector");
  72. Query q = null;
  73. String condition = this.appendCondition(dto.getProductId(), dto.getKey());
  74. if (StringUtils.isNotEmpty(condition)) {
  75. //1. 解析检索条件
  76. treeNode tree = expressManager.getInstance().Parse(condition, false);
  77. //2. 从es中检索数据
  78. q = formatQueryService.EsQueryToQuery((operateNode) tree, "patentVector", null);
  79. }
  80. //获取图片向量
  81. List<Float> imageList = new ArrayList<>();
  82. List<String> stringList = new ArrayList<>();
  83. if (file != null && file.exists() && file.length() != 0) {
  84. stringList = getVectorService.getVectorByFile(file);
  85. } else if (StringUtils.isNotEmpty(dto.getDescription())) {
  86. stringList = getVectorService.getVectorByText(dto.getDescription());
  87. }
  88. stringList.forEach(item -> {
  89. Float a = Float.parseFloat(item);
  90. imageList.add(a);
  91. });
  92. if (!CollectionUtils.isEmpty(imageList)) {
  93. String source = "cosineSimilarity(params.queryVector, 'my_vector') + 1.0";
  94. InlineScript inlineScript = InlineScript.of(i -> i.lang("painless").params("queryVector", JsonData.of(imageList)).source(source));
  95. Script script = Script.of(i -> i.inline(inlineScript));
  96. Query query = null;
  97. if (q != null) {
  98. Query finalQ = q;
  99. query = QueryBuilders.scriptScore(i -> i.script(script)
  100. .query(finalQ));
  101. } else {
  102. query = QueryBuilders.scriptScore(i -> i.script(script)
  103. .query(org.springframework.data.elasticsearch.client.elc.QueryBuilders.matchAllQueryAsQuery()));
  104. }
  105. builder.query(query);
  106. } else {
  107. builder.query(q);
  108. }
  109. //根据申请号去重
  110. FieldCollapse collapse = FieldCollapse.of(i -> i.field("app_no"));
  111. builder.collapse(collapse);
  112. //统计总数
  113. Aggregation aggregation = AggregationBuilders.cardinality(i -> i.field("app_no"));
  114. builder.aggregations("count", aggregation);
  115. //分页
  116. if (pageNum != null && pageSize != null && pageNum > 0 && pageSize > 0) {
  117. builder.from((pageNum.intValue() - 1) * pageSize.intValue()).size(pageSize.intValue());
  118. }
  119. //解除最大条数限制
  120. builder.trackTotalHits(i -> i.enabled(true));
  121. SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
  122. List<Hit<PatentVector>> hits = response.hits().hits();
  123. List<EsPatentVectorVo> vectorVos = new ArrayList<>();
  124. Double fixedScore = 1.7d;
  125. for (Hit<PatentVector> hit : hits) {
  126. Double score = hit.score();
  127. if (score > fixedScore) {
  128. PatentVector vector = hit.source();
  129. EsPatentVectorVo vectorVo = new EsPatentVectorVo();
  130. BeanUtil.copy(vector, vectorVo);
  131. vectorVos.add(vectorVo);
  132. }
  133. }
  134. Aggregate aggregate = response.aggregations().get("count");
  135. long total = aggregate.cardinality().value();
  136. Records records = new Records();
  137. records.setCurrent(pageNum);
  138. records.setSize(pageSize);
  139. records.setData(vectorVos);
  140. long count = total <= vectorVos.size() ? total : vectorVos.size();
  141. records.setTotal(count);
  142. //添加检索历史
  143. AddSearchRecordDTO recordDTO = new AddSearchRecordDTO();
  144. recordDTO.setProductId(dto.getProductId());
  145. recordDTO.setDescription(dto.getDescription());
  146. recordDTO.setGuid(file.getPath());
  147. recordDTO.setSearchCondition(condition);
  148. recordDTO.setAllNum(Integer.parseInt(String.valueOf(count)));
  149. recordDTO.setSearchTime(new Date());
  150. searchRecordService.addSearchRecord(recordDTO);
  151. return records;
  152. }
  153. /**
  154. * 根据专利号获取相关图片
  155. * @param noDTO
  156. * @return
  157. * @throws IOException
  158. */
  159. public List<EsPictureNoVo> getPictureByNo(EsPictureNoDTO noDTO) throws Exception {
  160. List<EsPictureNoVo> pictureNoVos = new ArrayList<>();
  161. SearchRequest.Builder builder = new SearchRequest.Builder();
  162. //设置查询索引
  163. builder.index("patent_vector");
  164. Query query = QueryBuilders.term(i -> i.field("app_no").value(noDTO.getAppNo()));
  165. builder.query(query);
  166. builder.size(100);
  167. SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
  168. List<Hit<PatentVector>> hits = response.hits().hits();
  169. for (Hit<PatentVector> hit : hits) {
  170. PatentVector vector = hit.source();
  171. EsPictureNoVo noVo = new EsPictureNoVo();
  172. noVo.setGuid(vector.getGuid());
  173. noVo.setImageIndex(vector.getImageIndex());
  174. pictureNoVos.add(noVo);
  175. }
  176. return pictureNoVos.stream().sorted(Comparator.comparing(EsPictureNoVo::getImageIndex)).collect(Collectors.toList());
  177. }
  178. /**
  179. * 拼接检索条件
  180. *
  181. * @param productId
  182. * @param keyword
  183. * @return
  184. * @throws Exception
  185. */
  186. public String appendCondition(Integer productId, String keyword) {
  187. SimpleDateFormat format = new SimpleDateFormat("yyyyMMdd");
  188. SimpleDateFormat format1 = new SimpleDateFormat("yyyy");
  189. String condition = "";
  190. String searchCondition = "";
  191. if (productId != null) {
  192. Product product = productMapper.selectById(productId);
  193. searchCondition = product.getSearchCondition();
  194. Date now = new Date();
  195. String nowFormat = format.format(now);
  196. Calendar calendar = Calendar.getInstance();
  197. calendar.setTime(now);
  198. calendar.add(Calendar.YEAR, -3);
  199. Date beforeDate = calendar.getTime();
  200. String agoFormat = format1.format(beforeDate) + "0101";
  201. String s = agoFormat + "~" + nowFormat;
  202. searchCondition = searchCondition + " AND " + "AD = " + "(" + s + ")";
  203. }
  204. if (StringUtils.isNotEmpty(searchCondition)) {
  205. if (StringUtils.isNotEmpty(keyword)) {
  206. String key = keyword.replaceAll("[,。、;,./;]", " OR ");
  207. condition = "TI = " + "(" + key + ")" + " AND " + searchCondition;
  208. } else {
  209. condition = searchCondition;
  210. }
  211. } else {
  212. //获取关键词
  213. if (StringUtils.isNotEmpty(keyword)) {
  214. // String key = dto.getKey().replaceAll("[,。、;,./;\\s]"," OR ");
  215. String key = keyword.replaceAll("[,。、;,./;]", " OR ");
  216. condition = "TI = " + "(" + key + ")";
  217. }
  218. }
  219. return condition;
  220. }
  221. }