본문 바로가기
Bedrock

Amazon Bedrock, LangChain과 사용하기 - 6. 간단한 검색증강(RAG : Retrieval Augmented Generation) 구현

by Hyeon Cloud 2023. 10. 18.

Amazon Bedrock, LangChain과 사용하기

1. Amazon Bedrock, 개발환경 설정

2. Amazon Bedrock API, LangChain 사용해보기

3. 추론 매개변수

4. 스트리밍 API, 벡터 임베딩

5. Streamlit, 텍스트/이미지 생성

6. 간단한 검색증강(RAG : Retrieval Augmented Generation) 구현

7. 간단한 챗봇 구현 (Conversation Memory)

8. 챗봇 구현 (RAG + Conversation Memory)

기초모델 (FM : Foundation Model)은 일반적으로 오프라인으로 학습되므로 모델 학습 이후 모든 데이터에 구애받지 않고 모델을 학습시킬 수 있습니다. 기초모델의 경우 일반적인 도메인 코퍼스를 대상으로 학습되므로 도메인별 작업에는 효율성이 떨어집니다. 질문에 정확하고 일관되게 대답하려면 모델의 응답을 뒷받침할수 있는 실제 정보가 있는지 확인할 필요가 있습니다. 이를 구현하기위해 RAG (Retrieval Augmented Generation) 패턴을 사용합니다.

 

일반적인 시나리오에서는 Kendra, OpenSearch 등 백터 임베딩을 저장하는 벡터데이터베이스를 사용하지만, 이번 포스트에서는 간단한 구조를 위해 간단하게 벡터를 사용할 수 있는 데이터베이스인 🔗FAISS를 사용하도록 하겠습니다. FAISS는 페이스북, 인스타그램으로 유명한 "메타"라는 회사에서 개발한 인메모리 벡터데이터베이스 입니다.

RAG 구현

아키텍쳐는 다음과 같으며 각 파트의 Flow를 간단히 소개하도록 하겠습니다.

  1. KB문서의 경우 텍스트 청크 단위로 나뉘어져 Titan Embeddings로 전달, 벡터로 변환됩니다. 그 후 FAISS 벡터디비에 저장됩니다.
  2. 유저 인터페이스에서 프롬프트를 전달합니다.
  3. 프롬프트는 Titan Embedding을 사용하여 벡터로 변환되며, 저장된 벡터와 가장 가까운 벡터(일치) 를 찾게됩니다.
  4. 일치하는 벡터와 원래 질문의 결합된 컨텐츠를 Bedrock FM(LLM)에 전달하고, 최상의 답변을 생성합니다.

라이브러리 파일 작성 (lib.py)

이번엔 RAG모델을 일반적으로 구성할때 LangChain을 사용하므로 LangChain을 사용하도록 하겠습니다. 라이브러리들을 임포트합니다.

from langchain.embeddings import BedrockEmbeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.llms.bedrock import Bedrock

 

LangChain을 사용하여 Bedrock 클라이언트를 통해 LLM을 불러오도록 하겠습니다.

이전에 보았던 추론 매개변수를 통해 온도를 0으로 설정함으로써 정확한 값을 가져오도록 하겠습니다.

사용하는 모델은 "ai21.j2-ultra-v1"입니다.

def get_llm():
    model_kwargs = {
        "maxTokens": 1024, 
        "temperature": 0, 
        "topP": 0.5, 
        "stopSequences": [], 
        "countPenalty": {"scale": 0 }, 
        "presencePenalty": {"scale": 0 }, 
        "frequencyPenalty": {"scale": 0 } 
    }
    
    llm = Bedrock(
        region_name='us-east-1',
        endpoint_url="https://bedrock-runtime.us-east-1.amazonaws.com",
        model_id="ai21.j2-ultra-v1",
        model_kwargs=model_kwargs)
    
    return llm

 

인메모리 벡터디비인 FAISS를 사용하는 함수를 추가하도록 하겠습니다.

주의점으로는 사용되는 pdf_path의 경우 환경에따라 상대경로가 동작하지 않을 수 있으므로 절대경로도 사용해볼수 있도록 합니다.

2022-Shareholder-Letter.pdf
0.08MB

def get_index(): 
    embeddings = BedrockEmbeddings(
        region_name='us-east-1',
        endpoint_url="https://bedrock-runtime.us-east-1.amazonaws.com",
    ) 
    
    pdf_path = "/Users/hyeonminkim/Desktop/Bedrock/basicrag/2022-Shareholder-Letter.pdf"
    loader = PyPDFLoader(file_path=pdf_path) 
    
    text_splitter = RecursiveCharacterTextSplitter( 
        separators=["\n\n", "\n", ".", " "], 
        chunk_size=1000, 
        chunk_overlap=100 
    )
    
    index_creator = VectorstoreIndexCreator( 
        vectorstore_cls=FAISS, 
        embedding=embeddings,
        text_splitter=text_splitter,
    )
    
    index_from_loader = index_creator.from_loaders([loader])
    
    return index_from_loader

 

Bedrock을 호출하고 리스폰스를 받아오기 위한 함수를 작성합니다. 인덱스로 벡터디비를 연결합니다. 쿼리결과는 LLM으로 전달됩니다.

def get_rag_response(index, question):
    llm = get_llm()
    response_text = index.query(question=question, llm=llm) 
    return response_text

 

전체 코드는 다음과 같습니다.

from langchain.embeddings import BedrockEmbeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.llms.bedrock import Bedrock

def get_llm():
    model_kwargs = {
        "maxTokens": 1024, 
        "temperature": 0, 
        "topP": 0.5, 
        "stopSequences": [], 
        "countPenalty": {"scale": 0 }, 
        "presencePenalty": {"scale": 0 }, 
        "frequencyPenalty": {"scale": 0 } 
    }
    
    llm = Bedrock(
        region_name='us-east-1',
        endpoint_url="https://bedrock-runtime.us-east-1.amazonaws.com",
        model_id="ai21.j2-ultra-v1",
        model_kwargs=model_kwargs)
    
    return llm

def get_index(): 
    embeddings = BedrockEmbeddings(
        region_name='us-east-1',
        endpoint_url="https://bedrock-runtime.us-east-1.amazonaws.com",
    ) 
    
    pdf_path = "/Users/hyeonminkim/Desktop/Bedrock/basicrag/2022-Shareholder-Letter.pdf"
    loader = PyPDFLoader(file_path=pdf_path) 
    
    text_splitter = RecursiveCharacterTextSplitter( 
        separators=["\n\n", "\n", ".", " "], 
        chunk_size=1000, 
        chunk_overlap=100 
    )
    
    index_creator = VectorstoreIndexCreator( 
        vectorstore_cls=FAISS, 
        embedding=embeddings,
        text_splitter=text_splitter,
    )
    
    index_from_loader = index_creator.from_loaders([loader])
    
    return index_from_loader

def get_rag_response(index, question):
    llm = get_llm()
    response_text = index.query(question=question, llm=llm) 
    return response_text

 

Streamlit Application 작성 (front.py)

라이브러리와 필요 패키지를 임포트합니다.

import streamlit as st 
import lib as glib

 

제목과 구성을 추가하고 세션캐시에 벡터 인덱스를 추가합니다. 이를통해 세션별로 메모리내 벡터디비를 유지합니다.

st.set_page_config(page_title="Retrieval-Augmented Generation") 
st.title("Retrieval-Augmented Generation")

if 'vector_index' not in st.session_state: 
    with st.spinner("Indexing document..."):
        st.session_state.vector_index = glib.get_index()

 

사용자에게 입력을 받기위한 입력요소와 출력하기위한 출력요소를 추가합니다.

input_text = st.text_area("Input text", label_visibility="collapsed")
go_button = st.button("Go", type="primary")

if go_button:  
    with st.spinner("Working..."):
        response_content = glib.get_rag_response(index=st.session_state.vector_index, question=input_text) 
        st.write(response_content)

 

전체 코드는 다음과 같습니다.

import streamlit as st 
import lib as glib

st.set_page_config(page_title="Retrieval-Augmented Generation") 
st.title("Retrieval-Augmented Generation")

if 'vector_index' not in st.session_state: 
    with st.spinner("Indexing document..."):
        st.session_state.vector_index = glib.get_index()

input_text = st.text_area("Input text", label_visibility="collapsed")
go_button = st.button("Go", type="primary")

if go_button:  
    with st.spinner("Working..."):
        response_content = glib.get_rag_response(index=st.session_state.vector_index, question=input_text) 
        st.write(response_content)

 

동작확인

잘 동작하는것을 확인할 수 있네요.