マルチモーダルRAGの実装方法:Step by Stepチュートリアル

はじめに

このブログ記事では、Jupyterノートブックを使用してマルチモーダルRAG(Retrieval-Augmented Generation)の実装方法を紹介します。このガイドでは、必要なインストール、セットアップ、およびコードスニペットをカバーし、マルチモーダルRAGをアプリケーションに統合する方法を説明します。

前提条件

まず、以下のライブラリがインストールされていることを確認してください:

!pip install -U --quiet langchain langchain_community chromadb langchain-experimental langchain-google-vertexai
!pip install --quiet "unstructured[all-docs]" pypdf pillow==10.0.0 pydantic lxml pillow matplotlib chromadb tiktoken

Google Cloud Vertex AIのセットアップ

最初に、Google Cloud Vertex AIを認証し、初期化します:

PROJECT_ID = " " # 自分のプロジェクトIDを入力してください。
REGION = "us-central1" # デフォルトのリージョン利用します。

from google.colab import auth
auth.authenticate_user()

import vertexai
vertexai.init(project=PROJECT_ID, location=REGION)

データの準備

次に、必要なデータをダウンロードして解凍します:

import logging
import zipfile
import requests

logging.basicConfig(level=logging.INFO)

data_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip"
result = requests.get(data_url)
filename = "cj.zip"
with open(filename, "wb") as file:
    file.write(result.content)

with zipfile.ZipFile(filename, "r") as zip_ref:
    zip_ref.extractall()

ドキュメントの読み込み

PyPDFLoaderを使用してPDFドキュメントを読み込みます:

from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader("./cj/cj.pdf")
docs = loader.load()
texts = [d.page_content for d in docs]

テキストの要約生成

テキスト要約を生成する関数を定義し要約を生成してみます:

from langchain_google_vertexai import VertexAI , ChatVertexAI , VertexAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from IPython.display import Markdown as md

def generate_text_summaries(texts, tables, summarize_texts=False):
   prompt_text = """あなたは、表やテキストを要約して検索用にまとめるアシスタントです。 \
   これらの要約は埋め込まれ、生のテキストや表の要素を検索するために使用されます。 \
   検索に最適化された簡潔な言葉で要約して綺麗なmarkdown形式の文章を提供してください。表やテキスト:{element}  \
   日本語で出力してください。  """
   prompt = PromptTemplate.from_template(prompt_text)
   empty_response = RunnableLambda(
       lambda x: AIMessage(content="エラーサポートしないファイルです。")
   )

   model = VertexAI(
       temperature=0, model_name="gemini-1.0-pro", max_output_tokens=1024
   ).with_fallbacks([empty_response])
   summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

   text_summaries = []
   table_summaries = []

   if texts and summarize_texts:
       text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
   elif texts:
       text_summaries = texts

   if tables:
       table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})

   return text_summaries, table_summaries

text_summaries, table_summaries = generate_text_summaries(
   texts, tables, summarize_texts=True
)

md(text_summaries[0])

実行結果:

要約:

  • OpenAIは開発者向けイベントを開催し、コンテキストウィンドウの拡大とコスト削減の2つの重要な発表を行いました。
  • Jamin Ballは、AIの収益化は、以下の順序で進むと予測しています。
    1. Raw silicon (Nvidiaが大量に購入)
    2. モデルプロバイダー (OpenAI、Anthropicなど)
    3. アプリケーション開発者
    4. エンドユーザー

キーワード:

  • OpenAI
  • コンテキストウィンドウ
  • コスト削減
  • AIの収益化
  • Raw silicon
  • モデルプロバイダー
  • アプリケーション開発者
  • エンドユーザー

マルチモーダルRAGチェーンの定義

マルチモーダルRAGのための関数とチェーンを定義します:

import uuid
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

def create_multi_vector_retriever(
   vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
   """
   Create retriever that indexes summaries, but returns raw images or texts
   """

   store = InMemoryStore()
   id_key = "doc_id"

   retriever = MultiVectorRetriever(
       vectorstore=vectorstore,
       docstore=store,
       id_key=id_key,
   )

   def add_documents(retriever, doc_summaries, doc_contents):
       doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
       summary_docs = [
           Document(page_content=s, metadata={id_key: doc_ids[i]})
           for i, s in enumerate(doc_summaries)
       ]
       retriever.vectorstore.add_documents(summary_docs)
       retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

   if text_summaries:
       add_documents(retriever, text_summaries, texts)

   if table_summaries:
       add_documents(retriever, table_summaries, tables)

   if image_summaries:
       add_documents(retriever, image_summaries, images)

   return retriever

vectorstore = Chroma(
   collection_name="mm_rag_cj_blog",
   embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
)

retriever_multi_vector_img = create_multi_vector_retriever(
   vectorstore,
   text_summaries,
   texts,
   table_summaries,
   tables,
   image_summaries,
   img_base64_list,
)
    import io
    import re
    
    from IPython.display import HTML, display
    from langchain_core.runnables import RunnableLambda, RunnablePassthrough
    from PIL import Image
    
    
    def plt_img_base64(img_base64):
       image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
       display(HTML(image_html))
    
    
    def looks_like_base64(sb):
       return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
    
    
    def is_image_data(b64data):
       image_signatures = {
           b"\xFF\xD8\xFF": "jpg",
           b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
           b"\x47\x49\x46\x38": "gif",
           b"\x52\x49\x46\x46": "webp",
       }
       try:
           header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
           for sig, format in image_signatures.items():
               if header.startswith(sig):
                   return True
           return False
       except Exception:
           return False
    
    
    def resize_base64_image(base64_string, size=(128, 128)):
       img_data = base64.b64decode(base64_string)
       img = Image.open(io.BytesIO(img_data))
    
       resized_img = img.resize(size, Image.LANCZOS)
    
       buffered = io.BytesIO()
       resized_img.save(buffered, format=img.format)
    
       return base64.b64encode(buffered.getvalue()).decode("utf-8")
    
    
    def split_image_text_types(docs):
       b64_images = []
       texts = []
       for doc in docs:
           if isinstance(doc, Document):
               doc = doc.page_content
           if looks_like_base64(doc) and is_image_data(doc):
               doc = resize_base64_image(doc, size=(1300, 600))
               b64_images.append(doc)
           else:
               texts.append(doc)
       if len(b64_images) > 0:
           return {"images": b64_images[:1], "texts": []}
       return {"images": b64_images, "texts": texts}
    
    def img_prompt_func(data_dict):
       formatted_texts = "\n".join(data_dict["context"]["texts"])
       messages = []
    
       text_message = {
           "type": "text",
           "text": (
               "あなたは投資アドバイスを提供する役割を担う金融アナリストです。\n"
               "テキスト、表、そして通常はチャートやグラフの画像が混在して提供されます。\n"
               "この情報を使用して、ユーザーの質問に関連する投資アドバイスを提供してください。\n"
               f"ユーザーからの質問: {data_dict['question']}\n\n"
               "テキストおよび/または表:\n"
               f"{formatted_texts}\n"
               "出力はMarkdownの表形式で提供してください。"
           ),
       }
       messages.append(text_message)
       if data_dict["context"]["images"]:
           for image in data_dict["context"]["images"]:
               image_message = {
                   "type": "image_url",
                   "image_url": {"url": f"data:image/jpeg;base64,{image}"},
               }
               messages.append(image_message)
       return [HumanMessage(content=messages)]
    
    def multi_modal_rag_chain(retriever):
    
       model = ChatVertexAI(
           temperature=0, model_name="gemini-1.0-pro-vision", max_output_tokens=2048
       )
    
       chain = (
           {
               "context": retriever | RunnableLambda(split_image_text_types),
               "question": RunnablePassthrough(),
           }
           | RunnableLambda(img_prompt_func)
           | model
           | StrOutputParser()
       )
    
       return chain
    
    chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

    クエリの実行

    最後に、クエリを実行します:

    from IPython.display import Markdown as md
    
    query = "Snowflake、MongoDB、Cloudflare、Datadogについて、EV/NTMとNTM Rev Growthを教えてください。"
    docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)
    result = chain_multimodal_rag.invoke(query)
    md(result)

    クエリ実行結果:

    コメントを残す