| import os |
| import shutil |
| from typing import Optional |
|
|
| from langchain.document_loaders import UnstructuredFileLoader |
| from langchain.embeddings import OpenAIEmbeddings |
| from langchain.schema import Document |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain.vectorstores import FAISS |
| from loguru import logger |
| from tqdm import tqdm |
|
|
| from .parser import parse_pdf |
|
|
# Default prompt used by generate_prompt() for document question answering.
# The Chinese text presents {context} as the known information, asks for a
# concise and professional answer based on it, instructs the model to reply
# that the question cannot be answered from the known information (or that not
# enough relevant information was provided) instead of inventing content,
# requires the answer to be written in Chinese, and ends with the user's
# question in {question}.
| PROMPT_TEMPLATE = """已知信息: |
| {context} |
| |
| 根据上述已知信息,简洁和专业的来回答用户的问题。 |
| 如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 |
| 问题是:{question}""" |
|
|
|
|
def _get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load a single file and split its content into chunks.

    The loader is chosen from the file extension (compared case-insensitively):

    * ``.pdf``  - ``parse_pdf`` first, PyPDF2 page-by-page extraction as fallback.
    * ``.docx`` / ``.pptx`` / ``.epub`` - the matching langchain
      "unstructured" loader.
    * ``.md``   - ``UnstructuredFileLoader`` in "elements" mode; the elements
      are returned as loaded, without running the text splitter.
    * anything else - ``UnstructuredFileLoader`` in "elements" mode, split
      with the text splitter.

    Args:
        filepath: Path of the file to load.
        chunk_size: ``chunk_size`` given to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` given to the splitter.
        two_column: Forwarded unchanged to ``parse_pdf`` (presumably marks a
            two-column page layout - confirm against ``parse_pdf``).

    Returns:
        The list of document chunks. If loading raises, the error is logged,
        the traceback is printed, and the splitter is run on an empty
        placeholder document for ``filepath`` instead of propagating.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # Lower-case the extension so "REPORT.PDF" takes the same branch as
    # "report.pdf" instead of falling through to the generic loader.
    file_type = os.path.splitext(filepath)[1].lower()

    logger.info(f"Loading file: {filepath}")
    # Always a flat list of documents; this placeholder is what gets split
    # when loading fails below.
    loaded = [Document(page_content="", metadata={"source": filepath})]
    try:
        if file_type == ".pdf":
            logger.debug("Loading PDF...")
            try:
                pdftext = parse_pdf(filepath, two_column).text
            except Exception as parse_error:
                # Only ordinary errors trigger the fallback; KeyboardInterrupt
                # and SystemExit are no longer swallowed here.
                logger.warning(
                    f"parse_pdf failed for {filepath} ({parse_error!r}); "
                    "falling back to PyPDF2"
                )
                from PyPDF2 import PdfReader

                with open(filepath, "rb") as pdf_file:
                    pdf_reader = PdfReader(pdf_file)
                    # `or ""` keeps the join safe if a page yields no text.
                    pdftext = "".join(
                        page.extract_text() or ""
                        for page in tqdm(pdf_reader.pages)
                    )

            loaded = [Document(page_content=pdftext, metadata={"source": filepath})]

        elif file_type == ".docx":
            from langchain.document_loaders import UnstructuredWordDocumentLoader

            logger.debug("Loading Word...")
            loader = UnstructuredWordDocumentLoader(filepath)
            # load() already yields a list of documents; it must not be
            # wrapped in another list before reaching the splitter.
            loaded = loader.load()
        elif file_type == ".pptx":
            from langchain.document_loaders import UnstructuredPowerPointLoader

            logger.debug("Loading PowerPoint...")
            loader = UnstructuredPowerPointLoader(filepath)
            loaded = loader.load()
        elif file_type == ".epub":
            from langchain.document_loaders import UnstructuredEPubLoader

            logger.debug("Loading EPUB...")
            loader = UnstructuredEPubLoader(filepath)
            loaded = loader.load()
        elif file_type == ".md":
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load()
        else:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load_and_split(text_splitter=text_splitter)
    except Exception:
        import traceback

        logger.error(f"Error loading file: {filepath}")
        traceback.print_exc()

    return text_splitter.split_documents(loaded)
|
|
|
|
def get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load and chunk one file or a collection of files.

    Args:
        filepath: Either a single path (``str``, ``bytes`` or ``os.PathLike``)
            or an iterable of such paths.
        chunk_size: Forwarded to ``_get_documents``.
        chunk_overlap: Forwarded to ``_get_documents``.
        two_column: Forwarded to ``_get_documents``.

    Returns:
        One flat list with the chunks of every file, in input order.
    """
    logger.debug("Loading documents...")

    # Decide "one path or many" from the argument's type, not from whether
    # the file exists: a mistyped or missing path given as a string must be
    # passed to the loader as a whole, never iterated character by character.
    if isinstance(filepath, (str, bytes, os.PathLike)):
        paths = [filepath]
    else:
        paths = filepath

    documents = []
    for path in paths:
        documents.extend(
            _get_documents(
                path,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                two_column=two_column,
            )
        )

    logger.debug("Documents loaded.")
    return documents
|
|
|
|
def generate_prompt(related_docs, query: str, prompt_template=PROMPT_TEMPLATE) -> str:
    """Fill a prompt template with retrieved context and the user's question.

    Args:
        related_docs: Iterable of indexable items whose first element has a
            ``page_content`` attribute (``__call__`` of ``DocQAPromptAdapter``
            passes the result of ``similarity_search_with_score``).
        query: The user's question, inserted verbatim.
        prompt_template: Template containing the literal placeholders
            ``{context}`` and ``{question}``.

    Returns:
        The template with ``{context}`` replaced by the page contents joined
        with newlines and ``{question}`` replaced by ``query``.
    """
    context = "\n".join(doc[0].page_content for doc in related_docs)

    # Fill both placeholders in one pass over the *template* only. Chaining
    # two .replace() calls would scan the already inserted question again, so
    # a question containing the text "{context}" would have the retrieved
    # documents spliced into it.
    template_segments = prompt_template.split("{question}")
    return query.join(
        segment.replace("{context}", context) for segment in template_segments
    )
|
|
|
|
class DocQAPromptAdapter:
    """Turn a user question into a document-grounded prompt.

    The adapter keeps a FAISS vector store built from chunked documents and,
    when called, looks up the chunks most similar to the question and renders
    them into a prompt with ``generate_prompt``.
    """

    def __init__(
        self,
        chunk_size: Optional[int] = 500,
        chunk_overlap: Optional[int] = 0,
        api_key: Optional[str] = "xxx",
    ):
        """Create the adapter.

        Args:
            chunk_size: Chunk size used when documents are loaded by
                ``create_vector_store``.
            chunk_overlap: Chunk overlap used when documents are loaded.
            api_key: Passed to ``OpenAIEmbeddings`` as ``openai_api_key``.
        """
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Stays None until create_vector_store() or reset_vector_store() runs.
        self.vector_store = None

    def create_vector_store(self, file_path, vs_path, embeddings=None):
        """Index the given file(s) with FAISS and save the index.

        Args:
            file_path: A path or collection of paths, as accepted by
                ``get_documents``.
            vs_path: Location handed to ``save_local`` for the new index.
            embeddings: Optional embeddings object used instead of the
                adapter's own ``self.embeddings``.
        """
        documents = get_documents(
            file_path,
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        self.vector_store = FAISS.from_documents(documents, embeddings or self.embeddings)
        self.vector_store.save_local(vs_path)

    def reset_vector_store(self, vs_path, embeddings=None):
        """Replace the current vector store with the one saved at ``vs_path``.

        Args:
            vs_path: Location handed to ``FAISS.load_local``.
            embeddings: Optional embeddings object used instead of the
                adapter's own ``self.embeddings``.
        """
        self.vector_store = FAISS.load_local(vs_path, embeddings or self.embeddings)

    @staticmethod
    def delete_files(files):
        """Delete every existing path in ``files``.

        Real directories are removed recursively. Anything else that exists
        (regular files, symbolic links - including dangling ones and links
        that point at directories - and other special files) is unlinked.
        Paths that do not exist are skipped.

        Args:
            files: Iterable of paths to remove.
        """
        for path in files:
            # rmtree refuses symbolic links, so only real directories go
            # through it; a link to a directory is unlinked like a file.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            elif os.path.lexists(path):
                os.remove(path)

    def __call__(self, query, vs_path=None, topk=6):
        """Build the prompt for ``query`` from the most similar chunks.

        Args:
            query: The user's question.
            vs_path: If given and existing, the vector store saved there is
                loaded (with the adapter's own embeddings) before searching.
            topk: Number of results requested from the similarity search.

        Returns:
            The prompt string produced by ``generate_prompt``.

        Raises:
            RuntimeError: If no vector store is available, i.e. none was
                created or loaded before and ``vs_path`` is missing or does
                not exist.
        """
        if vs_path is not None and os.path.exists(vs_path):
            self.reset_vector_store(vs_path)

        if self.vector_store is None:
            # Fail with an actionable message instead of an AttributeError on
            # None a few lines further down.
            raise RuntimeError(
                "Vector store is not initialized; call create_vector_store() "
                "or pass the vs_path of an existing store."
            )

        # The query is always embedded with the adapter's own embeddings.
        self.vector_store.embedding_function = self.embeddings.embed_query
        related_docs_with_score = self.vector_store.similarity_search_with_score(query, k=topk)
        return generate_prompt(related_docs_with_score, query)
|
|