| import os |
| import sys |
| from haystack.document_stores.in_memory import InMemoryDocumentStore |
| from datasets import load_from_disk |
| from haystack import Document |
| from haystack.components.writers import DocumentWriter |
| from haystack.components.embedders import SentenceTransformersDocumentEmbedder |
| from haystack.components.preprocessors.document_splitter import DocumentSplitter |
| from haystack import Pipeline |
| from haystack.components.retrievers.in_memory import ( |
| InMemoryBM25Retriever, |
| InMemoryEmbeddingRetriever, |
| ) |
| from haystack.components.embedders import SentenceTransformersTextEmbedder |
| from haystack.components.joiners import DocumentJoiner |
|
|
| |
| from haystack.components.rankers import SentenceTransformersSimilarityRanker |
| from haystack.document_stores.types import DuplicatePolicy |
| from haystack.components.converters import PyPDFToDocument |
| from haystack.components.preprocessors import DocumentCleaner |
| from haystack.components.builders import PromptBuilder |
| from pathlib import Path |
| from haystack.components.converters import DOCXToDocument |
| import re |
| import argparse |
|
|
|
|
| """ |
| python hybrid.py -c newstore.store |
| python hybrid.py -r newstore.store -q "who is pufendorf" |
| """ |
|
|
# Sentence-transformers checkpoint shared by the document embedder (indexing)
# and the text embedder (querying) built further down in this file.
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"

# Checkpoint handed to every SentenceTransformersSimilarityRanker in this file.
reranker_model = "BAAI/bge-reranker-base"
|
|
|
|
def build_store_from_dir(dir_path: str) -> "list[Document]":
    """Convert every PDF and DOCX file found under ``dir_path`` into documents.

    Despite its name, this function does not create a document store: it only
    walks the directory tree, runs the matching converter on the files it
    finds and returns the resulting documents for the caller to index.

    Args:
        dir_path: Directory that is searched recursively for ``*.pdf`` and
            ``*.docx`` files.

    Returns:
        The documents produced by the PDF converter followed by those produced
        by the DOCX converter. The path of each input file is handed to the
        converter as ``source`` metadata. The list is empty when no matching
        file exists.
    """
    root = Path(dir_path)
    pdfs = sorted(str(p) for p in root.rglob("*.pdf"))
    docxs = sorted(str(p) for p in root.rglob("*.docx"))

    # Echo the discovered files so the user can see what is about to be indexed.
    print(pdfs)
    print(docxs)

    docs = []
    # A converter is only instantiated when there is something for it to do.
    if pdfs:
        out = PyPDFToDocument().run(
            sources=pdfs, meta=[{"source": p} for p in pdfs]
        )
        docs.extend(out["documents"])
    if docxs:
        out = DOCXToDocument().run(
            sources=docxs, meta=[{"source": p} for p in docxs]
        )
        docs.extend(out["documents"])

    return docs
|
|
|
|
| |
| |
| |
|
|
|
|
| |
def create_index_split(docs, doc_store, split_length=5, split_overlap=1):
    """Split, embed and store ``docs``, then return a hybrid retrieval pipeline.

    An indexing pipeline (sentence splitter -> document embedder -> writer) is
    assembled and run once over ``docs``; the writer targets ``doc_store`` and
    skips duplicates.

    Args:
        docs: Documents to index.
        doc_store: Document store the chunks are written into.
        split_length: Number of sentences per chunk.
        split_overlap: Number of sentences shared by consecutive chunks.

    Returns:
        The pipeline produced by ``create_hybrid_retriever(doc_store)``.
    """
    pipeline = Pipeline()
    pipeline.add_component(
        "document_splitter",
        DocumentSplitter(
            split_by="sentence",
            split_length=split_length,
            split_overlap=split_overlap,
        ),
    )
    pipeline.add_component(
        "document_embedder",
        SentenceTransformersDocumentEmbedder(model=embedding_model),
    )
    pipeline.add_component(
        "document_writer",
        DocumentWriter(doc_store, policy=DuplicatePolicy.SKIP),
    )

    # Wire the stages in order: splitter -> embedder -> writer.
    pipeline.connect("document_splitter", "document_embedder")
    pipeline.connect("document_embedder", "document_writer")

    pipeline.run({"document_splitter": {"documents": docs}})

    return create_hybrid_retriever(doc_store)
|
|
|
|
| |
| |
def create_hybrid_retriever(doc_store):
    """Build a retrieval pipeline combining BM25 and embedding search.

    The query is embedded for the embedding retriever while the BM25 retriever
    works on the raw query; both result lists go through a document joiner and
    the joined list is handed to a similarity ranker.

    Args:
        doc_store: In-memory document store both retrievers read from.

    Returns:
        The assembled (not yet run) pipeline.
    """
    stages = {
        "text_embedder": SentenceTransformersTextEmbedder(model=embedding_model),
        "embedding_retriever": InMemoryEmbeddingRetriever(doc_store),
        "bm25_retriever": InMemoryBM25Retriever(doc_store),
        "document_joiner": DocumentJoiner(),
        "ranker": SentenceTransformersSimilarityRanker(model=reranker_model),
    }
    links = (
        ("text_embedder", "embedding_retriever"),
        ("bm25_retriever", "document_joiner"),
        ("embedding_retriever", "document_joiner"),
        ("document_joiner", "ranker"),
    )

    pipeline = Pipeline()
    for stage_name, stage in stages.items():
        pipeline.add_component(stage_name, stage)
    for sender, receiver in links:
        pipeline.connect(sender, receiver)

    return pipeline
|
|
|
|
def create_embedding_retriever(doc_store):
    """Build an embedding-only retrieval pipeline with a final re-ranking step.

    Stages: text embedder -> in-memory embedding retriever -> similarity ranker.

    Args:
        doc_store: In-memory document store the retriever reads from.

    Returns:
        The assembled (not yet run) pipeline.
    """
    stages = [
        ("text_embedder", SentenceTransformersTextEmbedder(model=embedding_model)),
        ("embedding_retriever", InMemoryEmbeddingRetriever(doc_store)),
        ("ranker", SentenceTransformersSimilarityRanker(model=reranker_model)),
    ]

    pipeline = Pipeline()
    for stage_name, stage in stages:
        pipeline.add_component(stage_name, stage)

    # Linear chain: embed the query, retrieve by vector, then re-rank.
    pipeline.connect("text_embedder", "embedding_retriever")
    pipeline.connect("embedding_retriever", "ranker")

    return pipeline
|
|
|
|
def create_bm25_retriever(doc_store):
    """Build a BM25-only retrieval pipeline with a final re-ranking step.

    Stages: in-memory BM25 retriever -> similarity ranker. No document joiner
    is involved because there is only a single retriever feeding the ranker.

    Args:
        doc_store: In-memory document store the retriever reads from.

    Returns:
        The assembled (not yet run) pipeline.
    """
    bm25_retriever = InMemoryBM25Retriever(doc_store)
    ranker = SentenceTransformersSimilarityRanker(model=reranker_model)

    bm25_retrieval = Pipeline()
    bm25_retrieval.add_component("bm25_retriever", bm25_retriever)
    bm25_retrieval.add_component("ranker", ranker)
    bm25_retrieval.connect("bm25_retriever", "ranker")

    return bm25_retrieval
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def retrieve(retriever, query, top_k=8, scale=True):
    """Run the hybrid pipeline for ``query`` and return the re-ranked documents.

    Args:
        retriever: Pipeline exposing ``text_embedder``, ``bm25_retriever``,
            ``embedding_retriever`` and ``ranker`` components (as built by
            ``create_hybrid_retriever``).
        query: Query text, sent to the embedder, the BM25 retriever and the
            ranker.
        top_k: Maximum number of documents requested from each retriever and
            from the ranker.
        scale: Forwarded as ``scale_score`` to both retrievers and to the
            ranker, so that ``scale=False`` consistently disables score
            scaling in every stage.

    Returns:
        The ``documents`` list produced by the ranker.
    """
    result = retriever.run(
        {
            "text_embedder": {"text": query},
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "embedding_retriever": {"top_k": top_k, "scale_score": scale},
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
|
|
|
|
def retrieve_embedded(retriever, query, top_k=8, scale=True):
    """Run the embedding-only pipeline for ``query`` and return ranked documents.

    Args:
        retriever: Pipeline exposing ``text_embedder``, ``embedding_retriever``
            and ``ranker`` components.
        query: Query text, sent to the embedder and to the ranker.
        top_k: Maximum number of documents requested from retriever and ranker.
        scale: Forwarded as ``scale_score`` to the retriever and the ranker.

    Returns:
        The ``documents`` list produced by the ranker.
    """
    retriever_args = {"top_k": top_k, "scale_score": scale}
    ranker_args = {"query": query, "top_k": top_k, "scale_score": scale}
    pipeline_inputs = {
        "text_embedder": {"text": query},
        "embedding_retriever": retriever_args,
        "ranker": ranker_args,
    }
    ranked = retriever.run(pipeline_inputs)["ranker"]
    return ranked["documents"]
|
|
|
|
def retrieve_bm25(retriever, query, top_k=8, scale=True):
    """Run the BM25-only pipeline for ``query`` and return ranked documents.

    Args:
        retriever: Pipeline exposing ``bm25_retriever`` and ``ranker``
            components (as built by ``create_bm25_retriever``).
        query: Query text, sent to the BM25 retriever and to the ranker.
        top_k: Maximum number of documents requested from retriever and ranker.
        scale: Forwarded as ``scale_score`` to both the retriever and the
            ranker, so that ``scale=False`` also affects the scores attached
            to the returned documents.

    Returns:
        The ``documents`` list produced by the ranker.
    """
    result = retriever.run(
        {
            "bm25_retriever": {
                "query": query,
                "top_k": top_k,
                "scale_score": scale,
            },
            "ranker": {"query": query, "top_k": top_k, "scale_score": scale},
        }
    )
    return result["ranker"]["documents"]
|
|
|
|
def print_res(doc, width=0):
    """Print one retrieved document as ``<score> <text>`` on a single line.

    The text is the document content with every run of whitespace collapsed
    to one space, prefixed by ``<researcher_name>:`` when the document's meta
    has a ``researcher_name`` entry.

    Args:
        doc: Object with ``score``, ``content`` and ``meta`` attributes.
        width: Target line width. When positive, the text is cut so that
            score, text and a trailing ``...`` fit within it; the ellipsis is
            only added when text was actually removed. ``0`` (the default)
            prints the full text.
    """
    # Collapse newlines/tabs/multiple spaces so the result stays on one line;
    # a missing content is treated as empty text.
    txt = " ".join((doc.content or "").split())
    name = doc.meta.get("researcher_name")
    if name is not None:
        txt = name + ":" + txt

    if width > 0:
        # Budget: 8 columns for "<score> ", 3 for "...", 1 spare column.
        # Clamp at 0 so a very narrow width never yields a negative slice bound.
        txt_width = max(width - 8 - 3 - 1, 0)
        if len(txt) > txt_width:
            txt = txt[:txt_width] + "..."

    print("{:.5f}".format(doc.score), txt)
|
|
|
|
def _print_section(title, documents, width):
    """Print an 80-column banner carrying ``title``, then one line per document."""
    print("=" * 80)
    print(f"== {title}")
    print("=" * 80)
    for doc in documents:
        print_res(doc, width)


def main():
    """Command-line entry point: optionally build a store, then query it.

    With ``-c PATH`` the documents under ``../Gradio/docs`` are indexed into a
    new in-memory store that is saved to ``PATH``. With ``-q QUERY`` a store is
    loaded (``-r PATH``, else the store just created, else
    ``research_docs_ns.store``) and the query is run through the hybrid,
    embedding-only and BM25-only pipelines, printing each result list.
    """
    # Local import: only the CLI entry point needs it.
    import shutil

    # shutil's variant falls back to a default size instead of raising
    # OSError when stdout is not attached to a terminal (e.g. piped output).
    terminal_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--create_store", help="Create a new data store.", default=None
    )
    # NOTE: --dataset is parsed but not used anywhere in this script.
    parser.add_argument("-d", "--dataset", help="Dataset filename.", default=None)
    parser.add_argument("-r", "--read_store", help="Read a data store.", default=None)
    # store_false: passing -s turns score scaling OFF.
    parser.add_argument(
        "-s",
        "--scale",
        action="store_false",
        help="Do not scale retrieved scores.",
        default=True,
    )
    parser.add_argument("--top_k", type=int, help="Retriever top_k.", default=8)
    parser.add_argument("-q", "--query", help="Query DBs.", default=None)
    args = parser.parse_args()
    query = args.query

    if args.create_store:
        docs = build_store_from_dir("../Gradio/docs")
        rs_doc_store = InMemoryDocumentStore()
        print("Starting create_index_split()")
        create_index_split(docs, rs_doc_store)
        rs_doc_store.save_to_disk(args.create_store)
        print("Ready create_index_split()")

    # Nothing to search for: creating the store (if requested) was the whole job.
    if not query:
        return

    # Store to query: explicit -r, else the one just created, else the default.
    if not args.read_store:
        args.read_store = args.create_store or "research_docs_ns.store"
    print(f"Loading document store {args.read_store}...")
    doc_store = InMemoryDocumentStore.load_from_disk(args.read_store)
    print(f"Number of documents: {doc_store.count_documents()}.")

    hybrid_retrieval = create_hybrid_retriever(doc_store)
    documents = retrieve(hybrid_retrieval, query, top_k=args.top_k, scale=args.scale)
    _print_section("Hybrid", documents, terminal_width)

    embedding_retrieval = create_embedding_retriever(doc_store)
    documents = retrieve_embedded(
        embedding_retrieval, query, top_k=args.top_k, scale=args.scale
    )
    _print_section("Embedding", documents, terminal_width)

    bm25_retrieval = create_bm25_retriever(doc_store)
    documents = retrieve_bm25(bm25_retrieval, query, top_k=args.top_k, scale=args.scale)
    _print_section("bm25", documents, terminal_width)


if __name__ == "__main__":
    main()
|
|