Dual-Stream RAG: Text + Vision

APPLICATION Nov 2, 2024

Traditional retrieval-augmented generation (RAG) has transformed how we interact with documents, but it still lacks a crucial ingredient: visual context. What if RAG could not only read a document, but also see it?

By combining a vision-language model (VLM) with conventional text processing, we developed a dual-stream RAG architecture that handles both the textual and the visual content of PDF documents. Our approach uses Qdrant's multi-vector capability to store text and image embeddings side by side, enabling richer contextual retrieval. At query time the system doesn't just match text: it actually "sees" the document pages, producing more accurate and context-aware responses.

In this article we explore how this vision-augmented RAG system opens up new possibilities for document understanding and retrieval.

1. System Architecture

Let me explain the architecture of this vision-augmented RAG system with reference to the diagram.

The system takes a PDF document as its primary input and runs it through two processing streams to extract as much information as possible. In the first stream, each page is converted to an image; in the parallel stream, text is extracted from each page. This dual approach ensures that no information is lost during processing.

The extracted content is then vectorized and stored in Qdrant, a vector database that efficiently handles multiple vector types per document. Each entry in Qdrant holds both an image vector and a text vector, along with the necessary metadata: the page number, the base64-encoded page image, and the extracted text.

When a user submits a query, Qdrant's prefetch capability comes into play, retrieving the top three most relevant results by vector similarity (as configured in this implementation).
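
Conceptually, each stored point pairs the two named vectors with the page payload. The following is only an illustrative sketch of that shape, with field names matching the implementation later in this article; it is not code from the pipeline itself:

# Illustrative shape of one Qdrant point (one point per PDF page); field
# names follow the implementation shown later in this article.
point = {
    "id": 7,
    "vector": {
        "clip-ViT-B-32": [...],                          # 512-dim image embedding
        "paraphrase-multilingual-MiniLM-L12-v2": [...],  # 384-dim text embedding
    },
    "payload": {
        "page": 7,                # page number
        "base64str": "iVBOR...",  # base64-encoded page image
        "full_text": "...",       # extracted page text
    },
}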

This is where the architecture gets particularly interesting: the system does not stop at conventional text-based retrieval. The same user query, together with the retrieved base64-encoded page images, is passed to a vision-language model (VLM), in this case OpenAI's vision-capable model. This lets the system visually analyze the actual document layout and content, adding a further layer of understanding.

The final piece of the architecture is an aggregating large language model (LLM) that combines the results of text-based retrieval with the vision model's analysis. This aggregator synthesizes the information from both streams into a comprehensive response that draws on both the textual and the visual understanding of the document. The result is a more robust, context-aware system that can answer from both perspectives, with strong supporting evidence.

What makes this architecture shine is that it understands documents not merely as text, but as they were meant to be seen: layout, formatting, and the visual elements that often carry crucial contextual information. This dual-stream approach, combined with modern vector search and vision models, represents a significant step forward for RAG systems.

Sometimes text alone is not enough to answer a query; that is where vision comes in.

2. Implementation Details

Let's look at the ingestion part of the architecture, shown below:

Data ingestion into the vector store

Let's create a file named pdf_processor.py containing the class shown below:

Outline of the PDFProcessor class
from pdf2image import convert_from_path
from pypdf import PdfReader
import os


class PDFProcessor:
    """
    A class to handle PDF processing operations including text extraction and image conversion.
    """

    def __init__(self, pdf_path, output_dir):
        """
        Initialize the PDF processor.

        Parameters:
        - pdf_path: str, path to the PDF file
        - output_dir: str, directory to save the outputs
        """
        self.pdf_path = pdf_path
        self.output_dir = output_dir
        self.saved_images = []
        self.page_texts = []
        self.page_dicts = []

        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

    def extract_text(self):
        """
        Extract text from each page of the PDF.
        """
        print("Extracting text from PDF...")
        reader = PdfReader(self.pdf_path)

        # Extract text from each page
        for i, page in enumerate(reader.pages):
            text = page.extract_text()
            self.page_texts.append(text)

            # Save text to file
            text_file_path = os.path.join(self.output_dir, f'page_{i + 1}.txt')
            with open(text_file_path, 'w', encoding='utf-8') as f:
                f.write(text)
            print(f"Saved text from page {i + 1} to {text_file_path}")

    def convert_to_images(self, dpi=200, fmt='png'):
        """
        Convert each page of the PDF to images.

        Parameters:
        - dpi: int, resolution of output images
        - fmt: str, output image format
        """
        print("Converting PDF pages to images...")
        pages = convert_from_path(self.pdf_path, dpi=dpi)

        # Save each page as an image
        for i, page in enumerate(pages):
            image_path = os.path.join(self.output_dir, f'page_{i + 1}.{fmt}')
            page.save(image_path, fmt)
            self.saved_images.append(image_path)
            print(f"Saved image from page {i + 1} to {image_path}")

    def create_page_dicts(self, fmt='png'):
        """
        Create a list of dictionaries containing page information.

        Parameters:
        - fmt: str, image format used (needed for filenames)

        Returns:
        - list of dictionaries with page information
        """
        num_pages = max(len(self.saved_images) if self.saved_images else 0,
                        len(self.page_texts) if self.page_texts else 0)

        self.page_dicts = []
        for i in range(num_pages):
            page_dict = {
                "image": f"page_{i + 1}.{fmt}" if self.saved_images else None,
                "text": f"page_{i + 1}.txt" if self.page_texts else None
            }
            self.page_dicts.append(page_dict)

        return self.page_dicts

    def process(self, extract_images=True, extract_text=True, dpi=200, fmt='png'):
        """
        Process the PDF file with specified operations.

        Parameters:
        - extract_images: bool, whether to convert pages to images
        - extract_text: bool, whether to extract text
        - dpi: int, resolution of output images
        - fmt: str, output image format

        Returns:
        - tuple: (list of image paths, list of text content, list of page dictionaries)
        """
        try:
            if extract_text:
                self.extract_text()

            if extract_images:
                self.convert_to_images(dpi=dpi, fmt=fmt)

            self.create_page_dicts(fmt=fmt)

            return self.saved_images, self.page_texts, self.page_dicts

        except Exception as e:
            print(f"Error processing PDF: {str(e)}")
            return [], [], []

    def print_extracted_text(self):
        """
        Print the extracted text from each page with clear separation.
        """
        for i, text in enumerate(self.page_texts, 1):
            print(f"\n{'=' * 40}")
            print(f"Page {i}")
            print(f"{'=' * 40}")
            print(text.strip())


# Example driver usage
# if __name__ == "__main__":
#     # Example parameters
#     pdf_file = "data/rag.pdf"  # in fact, any PDF works as input here
#     output_folder = "pdf_output"
#
#     # Create processor instance
#     processor = PDFProcessor(pdf_file, output_folder)
#
#     # Process PDF - extract both images and text
#     image_paths, texts, page_dicts = processor.process(
#         extract_images=True,
#         extract_text=True,
#         dpi=200,
#         fmt='png'
#     )
#
#     print("\nProcessing complete.")
#     print("\nPage information:")
#     for i, page_info in enumerate(page_dicts, 1):
#         print(f"Page {i}:", page_info)
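
For orientation, here is what create_page_dicts returns for a three-page PDF; this example follows directly from the code above (the image and text files themselves land in the output directory):

# Example value of page_dicts for a three-page PDF processed with fmt='png':
page_dicts = [
    {"image": "page_1.png", "text": "page_1.txt"},
    {"image": "page_2.png", "text": "page_2.txt"},
    {"image": "page_3.png", "text": "page_3.txt"},
]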

The pdf_output folder now holds the page images and the full text for further processing. Next, let's create another class, DataIndexerAndRetriever.py, as shown below:

Outline of the DataIndexerAndRetriever class
from dotenv import load_dotenv, find_dotenv
from qdrant_client import QdrantClient, models
from fastembed import TextEmbedding
from sentence_transformers import SentenceTransformer
from PIL import Image
import openai
import base64
import io
import os

from pdf_processor import PDFProcessor


class DataIndexerAndRetriever:
    def __init__(self, data_dir='./pdf_output', qdrant_url="http://localhost:6333", qdrant_api_key='th3s3cr3tk3y'):
        """
        Initialize the Research Paper Processor.

        Parameters:
        - data_dir: str, directory containing PDF output files
        - qdrant_url: str, Qdrant server URL
        - qdrant_api_key: str, Qdrant API key
        """
        # Load environment variables
        _ = load_dotenv(find_dotenv())

        self.data_dir = data_dir
        self.collection_name = 'research_papers'

        # Initialize models
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self.image_embedding_model = SentenceTransformer("clip-ViT-B-32")
        self.text_embedding_model = TextEmbedding(
            model_name='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        )
        # Initialize the OpenAI client; the API key is read from the
        # environment (loaded above via dotenv) instead of being hardcoded
        self.openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # Initialize collection if it doesn't exist
        self._initialize_collection()

    def _initialize_collection(self):
        """Initialize Qdrant collection if it doesn't exist."""
        if not self.client.collection_exists(collection_name=self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    "clip-ViT-B-32": models.VectorParams(
                        size=512,
                        distance=models.Distance.COSINE
                    ),
                    "paraphrase-multilingual-MiniLM-L12-v2": models.VectorParams(
                        size=384,
                        distance=models.Distance.COSINE
                    ),
                }
            )

    def get_text_embeddings(self, text_file_path):
        """
        Get embeddings for text file content.

        Parameters:
        - text_file_path: str, path to text file

        Returns:
        - tuple: (text embeddings, full text content)
        """
        with open(file=text_file_path, mode='r') as data:
            full_text = data.read()
        # Convert the numpy embedding to a plain list for the Qdrant client
        return next(self.text_embedding_model.passage_embed(full_text)).tolist(), full_text

    def image_to_base64(self, image_path):
        """
        Convert image to base64 and get embeddings.

        Parameters:
        - image_path: str, path to image file

        Returns:
        - tuple: (image embeddings, base64 encoded string)
        """
        try:
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
            with Image.open(image_path) as img:
                image_embedding = self.image_embedding_model.encode(img).tolist()
            return image_embedding, encoded_string
        except Exception as e:
            print(f"Error converting image to base64: {str(e)}")
            return None, None  # keep the caller's tuple unpacking safe

    def base64_to_image(self, base64_string, output_path=None, fmt='png'):
        """
        Convert base64 string back to image.

        Parameters:
        - base64_string: str, base64 encoded image string
        - output_path: str, path to save decoded image (optional)
        - fmt: str, image format (default: 'png')

        Returns:
        - PIL.Image or str: Image object or path to saved image
        """
        try:
            image_data = base64.b64decode(base64_string)
            image = Image.open(io.BytesIO(image_data))

            if output_path:
                image.save(output_path, fmt)
                return output_path

            return image
        except Exception as e:
            print(f"Error converting base64 to image: {str(e)}")
            return None

    def index_pages(self, pages_data):
        """
        Process and index pages data.

        Parameters:
        - pages_data: list of dict, containing image and text file information
        """
        for index, obj in enumerate(pages_data):
            image_path = os.path.join(self.data_dir, obj["image"])
            text_file_path = os.path.join(self.data_dir, obj["text"])

            image_embedding, base64str = self.image_to_base64(image_path)
            text_embedding, full_text = self.get_text_embeddings(text_file_path=text_file_path)

            points = [
                models.PointStruct(
                    id=index + 1,
                    vector={
                        "clip-ViT-B-32": image_embedding,
                        "paraphrase-multilingual-MiniLM-L12-v2": text_embedding
                    },
                    payload={
                        "_id": index + 1,
                        "base64str": base64str,
                        "full_text": full_text,
                        "page": index + 1
                    }
                )
            ]

            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )

    def query_with_rrf(self, query_text: str = '', query_image_path: str = ''):
        """
        Query the collection using Reciprocal Rank Fusion.

        Parameters:
        - query_text: str, text query
        - query_image_path: str, path to query image

        Returns:
        - list: search results
        """
        text_embedding = None
        if query_text != '':
            text_embedding = next(self.text_embedding_model.embed(query_text)).tolist()

        image_embedding = None
        if query_image_path != '':
            with Image.open(query_image_path) as img:
                image_embedding = self.image_embedding_model.encode(img).tolist()

        # Build one prefetch per available modality so RRF can fuse both rankings
        prefetch = []
        if text_embedding:
            prefetch.append(
                models.Prefetch(
                    query=text_embedding,
                    using="paraphrase-multilingual-MiniLM-L12-v2",
                    limit=3,
                )
            )
        if image_embedding:
            prefetch.append(
                models.Prefetch(
                    query=image_embedding,
                    using="clip-ViT-B-32",
                    limit=3,
                )
            )

        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF
            ),
            with_payload=True,
            limit=3,
        )
        return results

    # Function to ask a question about the image using OpenAI API
    def ask_image_question(self, base64_image, question):
        try:
            # Send the image and question to the OpenAI API
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": question + ". Support your answer with evidence from given context. example: page number, section heading etc",
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                },
                            },
                        ],
                    }
                ],
            )

            # Extract and return the response
            answer = response.choices[0].message.content
            return answer

        except Exception as e:
            print(f"Error during API call: {e}")
            return None


# Example usage
if __name__ == "__main__":
    # Sample pages data
    # Example parameters
    pdf_file = "data/rag.pdf"
    output_folder = "pdf_output"

    # Create a PDFProcessor instance (uncomment the lines below the first
    # time you run this, to process the PDF file)
    # processor = PDFProcessor(pdf_file, output_folder)

    # Process PDF - extract both images and text
    # image_paths, texts, page_dicts = processor.process(
    #     extract_images=True,
    #     extract_text=True,
    #     dpi=200,
    #     fmt='png'
    # )

    # Initialize processor
    processor = DataIndexerAndRetriever()

    # Process pages (uncomment to run indexing into qdrant)
    # processor.index_pages(page_dicts)

    # Query example
    question = 'What is the OpenAI assistants workflow?'
    result = processor.query_with_rrf(query_text=question)
    for point in result.points:
        response = processor.ask_image_question(base64_image=point.payload['base64str'],
                                                question=question)
        print("-" * 50)
        print(response)
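
Note that the driver above prints each per-page VLM answer on its own; the aggregating LLM from the architecture in section 1 is not shown. Below is a minimal sketch of what that final aggregation step could look like. The aggregate_answers helper and its prompt wording are illustrative assumptions, not part of the original code:

def aggregate_answers(openai_client, question, per_page_answers):
    """Synthesize a single response from the per-page VLM answers."""
    context = "\n\n".join(
        f"Candidate answer {i + 1}:\n{answer}"
        for i, answer in enumerate(per_page_answers) if answer
    )
    response = openai_client.chat.completions.create(
        model="gpt-4o-mini",  # any capable LLM can serve as the aggregator
        messages=[
            {
                "role": "user",
                "content": f"Question: {question}\n\n"
                           f"Candidate answers derived from individual document pages:\n"
                           f"{context}\n\n"
                           "Combine these into one well-supported answer, keeping "
                           "any page-number or section evidence they cite.",
            }
        ],
    )
    return response.choices[0].message.content


# Usage with the loop above: collect the answers instead of printing them,
# then synthesize one final response.
# answers = [processor.ask_image_question(p.payload['base64str'], question)
#            for p in result.points]
# print(aggregate_answers(processor.openai_client, question, answers))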

Here is a brief overview of the code's main functionality:

The DataIndexerAndRetriever class handles dual-stream document processing with both text and image capabilities:

  • Initializes the connection to the Qdrant vector database and loads the required embedding models (CLIP for images, MiniLM for text)
  • Sets up the OpenAI client for vision model integration

Core processing:

  • Converts PDF pages into images and text
  • Generates embeddings for both image and text content
  • Stores the data in Qdrant with dual vectors for each document page

Retrieval system:

  • Searches across both text and image vectors using Reciprocal Rank Fusion (RRF); see the sketch after this list
  • Returns the top 3 most relevant results by default
  • Includes the original base64 image and full text in the results
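
For intuition: RRF scores each candidate page by summing 1/(k + rank) over every ranked list it appears in, where k is a smoothing constant (commonly 60). Qdrant performs this fusion server-side; the sketch below merely illustrates the idea and is not the library's implementation:

def rrf_scores(ranked_lists, k=60):
    """Fuse several ranked lists of page ids into one ranking, best first."""
    scores = {}
    for ranking in ranked_lists:
        for rank, page_id in enumerate(ranking, start=1):
            # Pages ranked highly in any list accumulate a larger score
            scores[page_id] = scores.get(page_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)


# Text search ranked pages [3, 1, 2]; image search ranked pages [3, 5, 1]:
# rrf_scores([[3, 1, 2], [3, 5, 1]]) -> [3, 1, 5, 2]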

Vision integration:

  • Processes queries with OpenAI's GPT-4o-mini vision model
  • Accepts the user's question together with the relevant page image
  • Returns answers backed by evidence from the document context

Main workflow:

  • Process the PDF document
  • Index the content in Qdrant
  • Accept user queries
  • Return contextual answers using both text and vision

3. Results

Look at the question we asked and notice how OpenAI answers it, stating explicitly that the information is mentioned in "Figure 2" of the page attached below:

4. Closing Remarks

In conclusion, integrating vision models into retrieval-augmented generation (RAG) systems represents a significant advance in document processing. By leveraging both image and text data, we enhance indexing and retrieval, enabling richer, more contextually appropriate responses.

This approach not only improves retrieval accuracy but also surfaces compelling supporting evidence that reinforces the insights drawn from documents. As we continue to explore the synergy between vision and language models, the potential for more effective and nuanced document understanding will only grow.


Original article: Revolutionizing RAG by Integrating Vision Models for Enhanced Document Processing

Translated and edited by 汇智网; please credit the source when reposting.
