Code walkthrough: Movie Chatbot (CSV file chat app)
Build a chatbot that interactively answers user questions based on data in a CSV file
IMDBot documentation
This code creates a chatbot called IMDBot, which answers questions about movies based on data from a CSV file. The chatbot is powered by OctoAI and runs as a Streamlit application.
https://github.com/AI-Bassem/moviebot
Indexing and Query flows:
Setting up Environment Variables
The `setup_env_variables` function sets the `OCTOAI_API_TOKEN` and `ENDPOINT_URL` environment variables from Streamlit secrets.
import os
import streamlit as st

def setup_env_variables():
    os.environ["OCTOAI_API_TOKEN"] = st.secrets['OCTOAI_API_TOKEN']
    os.environ["ENDPOINT_URL"] = st.secrets['ENDPOINT_URL']
Initializing Session State
The `handle_session_state` function sets up the initial session state variables.
def handle_session_state():
    st.session_state.setdefault('generated', [])
    st.session_state.setdefault('past', [])
    st.session_state.setdefault('q_count', 0)
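These three keys drive the chat: `past` stores the user's questions, `generated` stores the bot's answers, and `q_count` counts processed queries. Because `st.session_state` survives Streamlit's script reruns, the history accumulates across interactions. For illustration (the values here are made up, not real app output):

handle_session_state()
# After one question/answer round the state might look like:
st.session_state['past'].append("Who directed Jaws?")      # user's question
st.session_state['generated'].append("Steven Spielberg")   # illustrative answer
st.session_state['q_count'] += 1                            # one query processed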
Loading Movie Data
The `load_data` function loads the movie data from a CSV file using a `PagedCSVReader` loader.
def load_data(file_path):
    # Fetch the PagedCSVReader loader class from LlamaHub
    PagedCSVReader = download_loader("PagedCSVReader")
    loader = PagedCSVReader()
    return loader.load_data(file_path)
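`PagedCSVReader` emits one llama_index `Document` per CSV row, rendered as `column: value` lines, so each movie becomes its own retrievable page. A quick sanity check might look like this (illustrative, assuming the CSV from the repo):

from pathlib import Path

documents = load_data(Path('rotten_tomatoes_top_movies.csv'))
print(len(documents))           # roughly one Document per CSV row
print(documents[0].text[:200])  # "column: value" lines for the first movie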
Initializing OctoAIEndpoint and LangChain LLMPredictor
The `initialize_llm` function initializes the `OctoAIEndpoint` and the `LLMPredictor`.
def initialize_llm(endpoint_url):
    llm = OctoAIEndpoint(endpoint_url=endpoint_url, model_kwargs={
        "max_new_tokens": 200, "temperature": 0.75, "top_p": 0.95,
        "repetition_penalty": 1, "seed": None, "stop": []})
    return LLMPredictor(llm=llm)
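The `model_kwargs` control generation: `max_new_tokens` caps the answer length, while `temperature` and `top_p` govern sampling randomness. Lowering both tends to make answers more deterministic, which can suit factual movie questions; a sketch with illustrative values (not the settings used by the app):

factual_llm = OctoAIEndpoint(endpoint_url=endpoint_url, model_kwargs={
    "max_new_tokens": 200, "temperature": 0.1, "top_p": 0.9,
    "repetition_penalty": 1, "seed": 42, "stop": []})
factual_predictor = LLMPredictor(llm=factual_llm)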
Creating LangchainEmbedding
The `create_embeddings` function creates an instance of `LangchainEmbedding` using the `OctoAIEmbeddings` wrapper for a hosted Instructor-Large model endpoint.
def create_embeddings():
    if 'embeddings' not in st.session_state:
        embeddings = LangchainEmbedding(OctoAIEmbeddings(
            endpoint_url="https://instruct-f1kzsig6xes9.octoai.run/predict"))
        st.session_state['embeddings'] = embeddings
    return st.session_state['embeddings']
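Note the pattern: the embedding object is built once and cached in `st.session_state` so that Streamlit's reruns do not recreate it on every interaction. The same idea, factored into a reusable helper (the name `cached_resource` is hypothetical):

def cached_resource(key, factory):
    # Build the resource on first use, then reuse it across reruns
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]

embeddings = cached_resource('embeddings', lambda: LangchainEmbedding(
    OctoAIEmbeddings(endpoint_url="https://instruct-f1kzsig6xes9.octoai.run/predict")))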
Creating ServiceContext
The `create_service_context` function creates an instance of the llama_index `ServiceContext`.
def create_service_context(llm_predictor, embeddings):
    if 'service_context' not in st.session_state:
        service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor, chunk_size_limit=400, embed_model=embeddings)
        st.session_state['service_context'] = service_context
    return st.session_state['service_context']
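Here `chunk_size_limit=400` caps the token length of each indexed chunk. Since `PagedCSVReader` already yields one small document per row, most chunks stay well under that cap; a tighter configuration would simply pass a smaller value (illustrative, not the app's setting):

compact_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor, chunk_size_limit=256, embed_model=embeddings)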
Creating GPTVectorStoreIndex
The `create_index` function creates an instance of the llama_index `GPTVectorStoreIndex`. It first checks whether the `index.pkl` file exists and loads it if available. If not, it creates a new `GPTVectorStoreIndex` from the given documents using the provided `service_context`.
def create_index(documents, service_context):
    if 'index' not in st.session_state:
        path = Path("index.pkl")
        if path.exists():
            index = dill.load(open(path, "rb"))
        else:
            index = GPTVectorStoreIndex.from_documents(
                documents, service_context=service_context)
            # dill.dump(index, open(path, "wb"))  # https://github.com/jerryjliu/llama_index/issues/886
        st.session_state['index'] = index
    return st.session_state['index']
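As the linked issue suggests, pickling the index with `dill` is unreliable, which is why the `dill.dump` call is commented out. llama_index also offers native persistence; the following is a sketch, assuming a release that provides `StorageContext` (e.g. the 0.6 series), with a hypothetical `index_storage` directory:

from llama_index import StorageContext, load_index_from_storage

def create_index_with_native_persistence(documents, service_context):
    persist_dir = "index_storage"  # hypothetical location
    if Path(persist_dir).exists():
        # Reload a previously persisted index from disk
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        return load_index_from_storage(storage_context, service_context=service_context)
    index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
    index.storage_context.persist(persist_dir=persist_dir)
    return index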
Creating Query Engine
The `create_query_engine` function creates a llama_index query engine using the given `index` and `llm_predictor`.
def create_query_engine(index, llm_predictor):
    if 'query_engine' not in st.session_state:
        query_engine = index.as_query_engine(
            verbose=True, llm_predictor=llm_predictor)
        st.session_state['query_engine'] = query_engine
    return st.session_state['query_engine']
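`as_query_engine` also accepts retrieval parameters; for example, `similarity_top_k` controls how many indexed rows are retrieved as context for each question. A variant that pulls in more rows per query might look like this (the value 3 is illustrative):

query_engine = index.as_query_engine(
    verbose=True, llm_predictor=llm_predictor, similarity_top_k=3)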
Processing Query
The `query` function processes a query and returns a response. It first gets a response from the llama_index `query_engine` and then converts it to a string.
def query(payload, query_engine):
    response = query_engine.query(payload["inputs"]["text"])
    # Convert the response to a string and remove a leading newline if present
    return str(response).lstrip("\n")
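The payload mirrors the shape used in `main()` below. Called directly, it looks like this (the question is illustrative):

answer = query({"inputs": {"text": "What year did Jaws come out?"}}, query_engine)
print(answer)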
Main Function
The `main` function initializes the environment, sets the endpoint URL, loads the data, initializes the LLM predictor, and creates the embeddings, service context, index, and query engine. It then displays the user interface and processes the user's input.
def main():
    # Set up the environment variables
    setup_env_variables()
    # Set the endpoint URL
    endpoint_url = os.getenv("ENDPOINT_URL")
    # Initialize the session state
    handle_session_state()
    # Load the data
    documents = load_data(Path('rotten_tomatoes_top_movies.csv'))
    # Initialize the LLM predictor
    llm_predictor = initialize_llm(endpoint_url)
    # Create the embeddings
    embeddings = create_embeddings()
    # Create the service context
    service_context = create_service_context(llm_predictor, embeddings)
    # Create the index
    index = create_index(documents, service_context)
    # Create the query engine
    query_engine = create_query_engine(index, llm_predictor)
    # Display the header
    st.subheader("🎬 IMDBot - Powered by OctoAI")
The `main` function continues by setting up the user interface and getting the user's input. It then processes the input and displays the generated response in the user interface.
def main():
    # ... continued from before ...
    st.subheader("🎬 IMDBot - Powered by OctoAI")
    st.markdown('* :movie_camera: Tip #1: IMDBot is great at answering factual questions like: "Who starred in the Harry Potter movies?" or "What year did Jaws come out?"')
    st.markdown('* :black_nib: Tip #2: IMDBot loves the word "synopsis" -- we suggest using it if you are looking for plot summaries. Otherwise, expect some hallucinations.')
    st.markdown("* :blush: Tip #3: IMDbot has information about 500 popular movies, but is not comprehensive. It probably won't know some more obscure films.")
    st.markdown("### Welcome to the IMDBot demo")
    st.sidebar.image("octoml-octo-ai-logo-color.png", caption="Try OctoML's new compute service for free by signing up for early access: https://octoml.ai/")
    try:
        # Get the user input
        user_input = get_text(q_count=st.session_state['q_count'])
        # If the user input is not empty, process it
        if user_input and user_input.strip() != '':
            output = query({"inputs": {"text": user_input}}, query_engine)
            # Increment q_count and record the question and answer in session state
            st.session_state['q_count'] += 1
            st.session_state['past'].append(user_input)
            if output:
                st.session_state['generated'].append(output)
        # If there are generated messages, display them newest-first
        if st.session_state['generated']:
            for i in range(len(st.session_state['generated']) - 1, -1, -1):
                message(st.session_state['past'][i], is_user=True, key=f'{i}_user')
                message(st.session_state["generated"][i], key=str(i))
    except Exception:
        st.error("Something went wrong. Please try again.")

if __name__ == "__main__":
    main()
In the exception-handling block, any error raised during execution results in the message "Something went wrong. Please try again." being displayed. The `main` function is called when this script is run as the main module.
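To try the app locally, install the dependencies from the repository and launch it with `streamlit run <script>.py`, substituting the script's filename in the repo.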