283 lines
7.8 KiB
Plaintext
283 lines
7.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "4128db56-1f32-41a1-a3aa-55525d48646d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.prompts import (\n",
|
|
" ChatPromptTemplate,\n",
|
|
" PromptTemplate,\n",
|
|
" SystemMessagePromptTemplate,\n",
|
|
" AIMessagePromptTemplate,\n",
|
|
" HumanMessagePromptTemplate,\n",
|
|
")\n",
|
|
"from datetime import datetime\n",
|
|
"from langchain.llms import OpenAI\n",
|
|
"from langchain.output_parsers import DatetimeOutputParser\n",
|
|
"from langchain.chat_models import ChatOpenAI"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "becc7263-d617-44b3-89b5-970e61e3b3aa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
class HistoryQuiz:
    """Small interactive history quiz driven by an OpenAI chat model.

    Typical flow:
        1. ``create_history_question`` asks the model for a question whose
           answer is a date.
        2. ``get_AI_answer`` asks the model to answer it and parses the reply
           into a ``datetime``.
        3. ``get_user_answer`` reads the user's guess from stdin.
        4. ``check_user_answer`` reports how far apart the two dates are.
    """

    def create_history_question(self, topic):
        """Ask the chat model for one quiz question about ``topic``.

        The question is requested so that its correct answer is a specific
        date, for example: "On what date did World War 2 end?"

        Args:
            topic: Subject of the question; substituted into the system prompt.

        Returns:
            The ``content`` of the model's reply (the question text).
        """
        system_template = "You write single quiz question about {topic}. You only return the quiz question"
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

        human_template = "{question}"
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

        chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        ques = "Give me a quiz question where the correct answer is a specific date."
        request = chat_prompt.format_prompt(topic=topic, question=ques).to_messages()

        chat = ChatOpenAI()
        result = chat(request)

        return result.content

    def get_AI_answer(self, question):
        """Ask the chat model to answer ``question`` and parse the reply as a date.

        Example: "September 2, 1945" --> datetime.datetime(1945, 9, 2, 0, 0)

        Args:
            question: Question text, normally from ``create_history_question``.

        Returns:
            The value produced by ``DatetimeOutputParser.parse`` on the model's
            reply.

        Note:
            The parser presumably raises if the reply does not follow the
            requested format; that error is not handled here -- confirm against
            the installed LangChain version.
        """
        output_parser = DatetimeOutputParser()
        system_template = "You answer quiz questions with just a date."
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

        human_template = """Answer the user's question:

        {question}

        {format_instructions}"""
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

        chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        # The parser supplies the exact date format the model must reply in,
        # so the same parser instance can decode the reply below.
        format_instructions = output_parser.get_format_instructions()
        request = chat_prompt.format_prompt(question=question,
                                            format_instructions=format_instructions).to_messages()

        chat = ChatOpenAI()
        result = chat(request)

        correct_datetime = output_parser.parse(result.content)

        return correct_datetime

    def get_user_answer(self, question):
        """Print ``question`` and read the user's answer as a ``datetime``.

        The user is prompted for year, month and day. If any entry is not an
        integer, or the three values do not form a valid calendar date
        (e.g. month 13 or February 30), the problem is reported and all three
        values are requested again instead of raising ``ValueError``.

        Args:
            question: Question text to show before prompting.

        Returns:
            ``datetime(year, month, day)`` built from the user's input.
        """
        print(question)
        print("\n")

        while True:
            try:
                year = int(input("Enter the year: "))
                month = int(input("Enter the month (1-12): "))
                day = int(input("Enter the day (1-31): "))
                # datetime() itself rejects out-of-range components with
                # ValueError, so it doubles as the calendar validity check.
                return datetime(year, month, day)
            except ValueError as err:
                print(f"Invalid date ({err}). Please try again.")

    def check_user_answer(self, user_answer, ai_answer):
        """Print and return the difference between the user's and the AI's answer.

        Args:
            user_answer: ``datetime`` entered by the user.
            ai_answer: ``datetime`` parsed from the model's reply.

        Returns:
            ``user_answer - ai_answer`` as a ``timedelta``; negative when the
            user's date is earlier than the AI's date.
        """
        difference = user_answer - ai_answer
        formatted_difference = str(difference)
        print("The difference between the user and ai answer is:", formatted_difference)
        # Return the value as the method's contract promises, so callers can
        # use it programmatically rather than only reading the printed text.
        return difference
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"id": "83eb3ad4-56da-4ffe-a7b0-f3d348dca6ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"quiz_bot = HistoryQuiz()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"id": "5cb1f8c0-2582-4444-92ab-83b45d3eb77a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = quiz_bot.create_history_question(topic=\"World War 2\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"id": "80efca07-45e6-471a-946e-95d97be972e9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'On what date did the Japanese attack Pearl Harbor, drawing the United States into World War II?'"
|
|
]
|
|
},
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"question"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"id": "fb72ea47-215c-4952-a631-d45b1e3119df",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ai_answer = quiz_bot.get_AI_answer(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"id": "b9d94135-45f6-4998-a8bd-29f0729ece34",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"datetime.datetime(1941, 12, 7, 0, 0)"
|
|
]
|
|
},
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ai_answer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "dd16b9cf-2b4c-4707-9665-20b9a9d74e57",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"On what date did the Japanese attack Pearl Harbor, drawing the United States into World War II?\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdin",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Enter the year: 1942\n",
|
|
"Enter the month (1-12): 12\n",
|
|
"Enter the day (1-31): 7\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"user_answer = quiz_bot.get_user_answer(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"id": "09f76155-40e2-4a07-a0ea-5bc011fef809",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"datetime.datetime(1942, 12, 7, 0, 0)"
|
|
]
|
|
},
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"user_answer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"id": "4a8226af-1c99-4ed3-8063-2d9c6e0fbfc9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The difference between the user and ai answer is: 365 days, 0:00:00\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"quiz_bot.check_user_answer(user_answer, ai_answer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "68c2d1fc-f0be-47ea-972a-7e978dd94ed2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|