123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- package com.example.xiaoshiweixinback.service.common;
- import co.elastic.clients.elasticsearch.ElasticsearchClient;
- import co.elastic.clients.elasticsearch._types.InlineScript;
- import co.elastic.clients.elasticsearch._types.Script;
- import co.elastic.clients.elasticsearch._types.aggregations.Aggregate;
- import co.elastic.clients.elasticsearch._types.aggregations.Aggregation;
- import co.elastic.clients.elasticsearch._types.aggregations.AggregationBuilders;
- import co.elastic.clients.elasticsearch._types.query_dsl.Query;
- import co.elastic.clients.elasticsearch._types.query_dsl.QueryBuilders;
- import co.elastic.clients.elasticsearch.core.SearchRequest;
- import co.elastic.clients.elasticsearch.core.SearchResponse;
- import co.elastic.clients.elasticsearch.core.search.FieldCollapse;
- import co.elastic.clients.elasticsearch.core.search.Hit;
- import co.elastic.clients.json.JsonData;
- import com.example.xiaoshiweixinback.business.common.base.Records;
- import com.example.xiaoshiweixinback.business.utils.BeanUtil;
- import com.example.xiaoshiweixinback.business.utils.ToolUtil;
- import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.expressManager;
- import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.operateNode;
- import com.example.xiaoshiweixinback.business.utils.parseQueryToTree.treeNode;
- import com.example.xiaoshiweixinback.domain.Product;
- import com.example.xiaoshiweixinback.domain.es.PatentVector;
- import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPictureNoDTO;
- import com.example.xiaoshiweixinback.entity.dto.esPicture.EsPatentVectorDTO;
- import com.example.xiaoshiweixinback.entity.dto.searchRecord.AddSearchRecordDTO;
- import com.example.xiaoshiweixinback.entity.product.ProductIdDTO;
- import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPictureNoVo;
- import com.example.xiaoshiweixinback.entity.vo.esPicture.EsPatentVectorVo;
- import com.example.xiaoshiweixinback.mapper.ProductMapper;
- import com.example.xiaoshiweixinback.service.SearchRecordService;
- import com.example.xiaoshiweixinback.service.importPatent.FormatQueryService;
- import lombok.RequiredArgsConstructor;
- import org.apache.commons.lang3.StringUtils;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.context.annotation.Lazy;
- import org.springframework.stereotype.Service;
- import org.springframework.transaction.annotation.Propagation;
- import org.springframework.transaction.annotation.Transactional;
- import org.springframework.util.CollectionUtils;
- import java.io.File;
- import java.io.IOException;
- import java.text.ParseException;
- import java.text.SimpleDateFormat;
- import java.util.*;
- import java.util.stream.Collectors;
- @Service
- @RequiredArgsConstructor(onConstructor_ = {@Lazy})
- public class EsDenseVectorService {
- private final ElasticsearchClient client;
- @Autowired
- private FormatQueryService formatQueryService;
- @Autowired
- private GetVectorService getVectorService;
- @Autowired
- private SearchRecordService searchRecordService;
- @Autowired
- private ProductMapper productMapper;
- /**
- * 根据图片排序获取列表
- *
- * @param dto
- * @return
- * @throws IOException
- */
- @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Throwable.class)
- public Records getPatentVectors(EsPatentVectorDTO dto,File file) throws Exception {
- Long pageNum = dto.getPageNum();
- Long pageSize = dto.getPageSize();
- SearchRequest.Builder builder = new SearchRequest.Builder();
- //设置查询索引
- builder.index("patent_vector");
- Query q = null;
- String condition = this.appendCondition(dto.getProductId(), dto.getKey());
- if (StringUtils.isNotEmpty(condition)) {
- //1. 解析检索条件
- treeNode tree = expressManager.getInstance().Parse(condition, false);
- //2. 从es中检索数据
- q = formatQueryService.EsQueryToQuery((operateNode) tree, "patentVector", null);
- }
- //获取图片向量
- List<Float> imageList = new ArrayList<>();
- List<String> stringList = new ArrayList<>();
- if (file != null && file.exists() && file.length() != 0) {
- stringList = getVectorService.getVectorByFile(file);
- } else if (StringUtils.isNotEmpty(dto.getDescription())) {
- stringList = getVectorService.getVectorByText(dto.getDescription());
- }
- stringList.forEach(item -> {
- Float a = Float.parseFloat(item);
- imageList.add(a);
- });
- if (!CollectionUtils.isEmpty(imageList)) {
- String source = "cosineSimilarity(params.queryVector, 'my_vector') + 1.0";
- InlineScript inlineScript = InlineScript.of(i -> i.lang("painless").params("queryVector", JsonData.of(imageList)).source(source));
- Script script = Script.of(i -> i.inline(inlineScript));
- Query query = null;
- if (q != null) {
- Query finalQ = q;
- query = QueryBuilders.scriptScore(i -> i.script(script)
- .query(finalQ));
- } else {
- query = QueryBuilders.scriptScore(i -> i.script(script)
- .query(org.springframework.data.elasticsearch.client.elc.QueryBuilders.matchAllQueryAsQuery()));
- }
- builder.query(query);
- } else {
- builder.query(q);
- }
- //根据申请号去重
- FieldCollapse collapse = FieldCollapse.of(i -> i.field("app_no"));
- builder.collapse(collapse);
- //统计总数
- Aggregation aggregation = AggregationBuilders.cardinality(i -> i.field("app_no"));
- builder.aggregations("count", aggregation);
- //分页
- if (pageNum != null && pageSize != null && pageNum > 0 && pageSize > 0) {
- builder.from((pageNum.intValue() - 1) * pageSize.intValue()).size(pageSize.intValue());
- }
- //解除最大条数限制
- builder.trackTotalHits(i -> i.enabled(true));
- SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
- List<Hit<PatentVector>> hits = response.hits().hits();
- List<EsPatentVectorVo> vectorVos = new ArrayList<>();
- Double fixedScore = 1.7d;
- for (Hit<PatentVector> hit : hits) {
- Double score = hit.score();
- if (score > fixedScore) {
- PatentVector vector = hit.source();
- EsPatentVectorVo vectorVo = new EsPatentVectorVo();
- BeanUtil.copy(vector, vectorVo);
- vectorVos.add(vectorVo);
- }
- }
- Aggregate aggregate = response.aggregations().get("count");
- long total = aggregate.cardinality().value();
- Records records = new Records();
- records.setCurrent(pageNum);
- records.setSize(pageSize);
- records.setData(vectorVos);
- long count = total <= vectorVos.size() ? total : vectorVos.size();
- records.setTotal(count);
- //添加检索历史
- AddSearchRecordDTO recordDTO = new AddSearchRecordDTO();
- recordDTO.setProductId(dto.getProductId());
- recordDTO.setDescription(dto.getDescription());
- recordDTO.setGuid(file.getPath());
- recordDTO.setSearchCondition(condition);
- recordDTO.setAllNum(Integer.parseInt(String.valueOf(count)));
- recordDTO.setSearchTime(new Date());
- searchRecordService.addSearchRecord(recordDTO);
- return records;
- }
- /**
- * 根据专利号获取相关图片
- * @param noDTO
- * @return
- * @throws IOException
- */
- public List<EsPictureNoVo> getPictureByNo(EsPictureNoDTO noDTO) throws Exception {
- List<EsPictureNoVo> pictureNoVos = new ArrayList<>();
- SearchRequest.Builder builder = new SearchRequest.Builder();
- //设置查询索引
- builder.index("patent_vector");
- Query query = QueryBuilders.term(i -> i.field("app_no").value(noDTO.getAppNo()));
- builder.query(query);
- builder.size(100);
- SearchResponse<PatentVector> response = client.search(builder.build(), PatentVector.class);
- List<Hit<PatentVector>> hits = response.hits().hits();
- for (Hit<PatentVector> hit : hits) {
- PatentVector vector = hit.source();
- EsPictureNoVo noVo = new EsPictureNoVo();
- noVo.setGuid(vector.getGuid());
- noVo.setImageIndex(vector.getImageIndex());
- pictureNoVos.add(noVo);
- }
- return pictureNoVos.stream().sorted(Comparator.comparing(EsPictureNoVo::getImageIndex)).collect(Collectors.toList());
- }
- /**
- * 拼接检索条件
- *
- * @param productId
- * @param keyword
- * @return
- * @throws Exception
- */
- public String appendCondition(Integer productId, String keyword) {
- SimpleDateFormat format = new SimpleDateFormat("yyyyMMdd");
- SimpleDateFormat format1 = new SimpleDateFormat("yyyy");
- String condition = "";
- String searchCondition = "";
- if (productId != null) {
- Product product = productMapper.selectById(productId);
- searchCondition = product.getSearchCondition();
- Date now = new Date();
- String nowFormat = format.format(now);
- Calendar calendar = Calendar.getInstance();
- calendar.setTime(now);
- calendar.add(Calendar.YEAR, -3);
- Date beforeDate = calendar.getTime();
- String agoFormat = format1.format(beforeDate) + "0101";
- String s = agoFormat + "~" + nowFormat;
- searchCondition = searchCondition + " AND " + "AD = " + "(" + s + ")";
- }
- if (StringUtils.isNotEmpty(searchCondition)) {
- if (StringUtils.isNotEmpty(keyword)) {
- String key = keyword.replaceAll("[,。、;,./;]", " OR ");
- condition = "TI = " + "(" + key + ")" + " AND " + searchCondition;
- } else {
- condition = searchCondition;
- }
- } else {
- //获取关键词
- if (StringUtils.isNotEmpty(keyword)) {
- // String key = dto.getKey().replaceAll("[,。、;,./;\\s]"," OR ");
- String key = keyword.replaceAll("[,。、;,./;]", " OR ");
- condition = "TI = " + "(" + key + ")";
- }
- }
- return condition;
- }
- }
|