|
| 1 | +import os |
| 2 | +from PIL import Image |
| 3 | +from dotenv import load_dotenv |
| 4 | +from PyPDF2 import PdfReader |
| 5 | +import streamlit as st |
| 6 | +import openai |
| 7 | +from langchain.text_splitter import CharacterTextSplitter |
| 8 | +from langchain.llms import OpenAI |
| 9 | +from langchain.memory import ConversationBufferMemory |
| 10 | +from langchain.chains import LLMChain |
| 11 | +from langchain.prompts import PromptTemplate |
| 12 | +from langchain.embeddings import OpenAIEmbeddings |
| 13 | +from langchain.vectorstores import FAISS |
| 14 | +from langchain.chains import RetrievalQA |
| 15 | +from langchain.chains import ConversationalRetrievalChain |
| 16 | +from htmlTemplates import css, bot_template, user_template |
| 17 | +from prompts import CHAT_TEMPLATE, INITIAL_TEMPLATE |
| 18 | +from prompts import CORRECTION_CONTEXT, COMPLETION_CONTEXT, OPTIMIZATION_CONTEXT, GENERAL_ASSISTANT_CONTEXT, GENERATION_CONTEXT, COMMENTING_CONTEXT, EXPLANATION_CONTEXT |
| 19 | + |
| 20 | + |
def page_config():
    """Configure the Streamlit page (title/icon) and inject the shared CSS.

    Must run before any other st.* call in the script, per Streamlit's
    set_page_config contract.
    """
    # Fixed typo: "Asssistant" -> "Assistant", matching the on-page title
    # rendered by page_title_header().
    st.set_page_config(page_title="GPT-4 Coding Assistant", page_icon="computer")
    # css comes from htmlTemplates; unsafe_allow_html is required to inject it.
    st.write(css, unsafe_allow_html=True)
| 24 | + |
| 25 | + |
def init_ses_states():
    """Seed the Streamlit session state with the keys this app relies on.

    setdefault keeps existing values across reruns, so this is safe to call
    on every script execution.
    """
    defaults = {
        'chat_history': [],
        'initial_input': "",
    }
    for state_key, default_value in defaults.items():
        st.session_state.setdefault(state_key, default_value)
| 29 | + |
| 30 | + |
def page_title_header():
    """Render the page banner image followed by the app title and caption."""
    # NOTE(review): assumes trippyPattern.png sits in the working directory —
    # confirm against the deployment layout.
    banner = Image.open('trippyPattern.png')
    st.image(banner)
    st.title("GPT-4 Coding Assistant")
    st.caption("Powered by OpenAI, LangChain, Streamlit")
| 36 | + |
| 37 | + |
def sidebar():
    """Render the settings sidebar and publish the chosen values.

    Publishes ``language``, ``scenario``, ``temperature`` and
    ``scenario_context`` as module-level globals, which the chain-building
    and submit handlers read directly.
    """
    global language, scenario, temperature, scenario_context

    language_choices = ['Python', 'GoLang', 'JavaScript', 'Java', 'C', 'C++', 'C#']
    scenario_choices = ['General Assistant', 'Code Correction', 'Code Completion', 'Code Commenting', 'Code Optimization', 'Code Generation', 'Code Explanation']
    # Maps each scenario label to the prompt context imported from prompts.py.
    context_by_scenario = {
        "Code Correction": CORRECTION_CONTEXT,
        "Code Completion": COMPLETION_CONTEXT,
        "Code Optimization": OPTIMIZATION_CONTEXT,
        "General Assistant": GENERAL_ASSISTANT_CONTEXT,
        "Code Generation": GENERATION_CONTEXT,
        "Code Commenting": COMMENTING_CONTEXT,
        "Code Explanation": EXPLANATION_CONTEXT,
    }

    with st.sidebar, st.expander(label="Settings", expanded=True):
        language = st.selectbox(label="Language", options=language_choices, index=0)
        scenario = st.selectbox(label="Scenario", options=scenario_choices, index=0)
        # Fall back to "" for any scenario without a mapped context.
        scenario_context = context_by_scenario.get(scenario, "")
        temperature = st.slider(label="Temperature", min_value=0.0, max_value=1.0, value=0.5)
| 58 | + |
| 59 | + |
def display_convo():
    """Render the chat history, newest message first.

    History always starts with a bot response and alternates bot/user, so in
    reversed order even indexes are bot messages and odd indexes are the
    user's.
    """
    with st.container():
        for position, text in enumerate(reversed(st.session_state.chat_history)):
            template = bot_template if position % 2 == 0 else user_template
            st.markdown(template.replace("{{MSG}}", text), unsafe_allow_html=True)
| 67 | + |
| 68 | + |
def handle_initial_submit():
    """Collect the user's initial code/question and run the first LLM pass.

    On submit: stores the raw input in session state, resets the chat
    history, and appends the model's first response so the follow-up chat
    UI (handle_user_message) becomes active.
    """
    initial_template = PromptTemplate(
        input_variables=['input', 'language', 'scenario', 'scenario_context'],
        template=INITIAL_TEMPLATE,
    )
    initial_llm_chain = create_llm_chain(prompt_template=initial_template)
    # Dropped pointless f-string prefixes (no placeholders) — ruff F541;
    # rendered labels are unchanged.
    initial_input = st.text_area(label="User Input", height=300)
    if st.button('Submit Initial Input') and initial_input:
        st.session_state.initial_input = initial_input
        # A fresh submission starts a fresh conversation.
        st.session_state['chat_history'] = []
        with st.spinner('Generating Response...'):
            initial_response = initial_llm_chain.run(
                input=initial_input,
                language=language,
                scenario=scenario,
                scenario_context=scenario_context,
            )
        st.session_state.chat_history.append(initial_response)
| 85 | + |
| 86 | + |
def handle_user_message():
    """Offer a follow-up chat box once an initial response exists.

    Appends the user's message and then the model's reply to the shared
    chat history, which display_convo renders.
    """
    followup_prompt = PromptTemplate(
        input_variables=['input', 'user_message', 'language', 'scenario', 'chat_history'],
        template=CHAT_TEMPLATE,
    )
    followup_chain = create_llm_chain(prompt_template=followup_prompt)

    # No follow-up UI until the initial submission produced a response.
    if not st.session_state.chat_history:
        return

    user_message = st.text_input("Further Questions for Coding AI?", key="user_input")
    if st.button("Submit Message") and user_message:
        st.session_state['chat_history'].append(user_message)
        with st.spinner('Generating Response...'):
            reply = followup_chain.run(
                input=st.session_state['initial_input'],
                user_message=user_message,
                language=language,
                scenario=scenario,
                chat_history=st.session_state.chat_history,
            )
        st.session_state['chat_history'].append(reply)
| 104 | + |
| 105 | + |
def create_llm_chain(prompt_template):
    """Build an LLMChain over GPT-4 with conversation-buffer memory.

    Reads the sidebar-selected ``temperature`` global; memory keys match the
    'input' / 'chat_history' variables used by both prompt templates.
    """
    conversation_memory = ConversationBufferMemory(input_key="input", memory_key="chat_history")
    gpt4_llm = OpenAI(temperature=temperature, model_name="gpt-4")
    return LLMChain(llm=gpt4_llm, prompt=prompt_template, memory=conversation_memory)
| 110 | + |
| 111 | + |
def main():
    """Assemble the Streamlit app: page config, state, header, sidebar,
    then the initial-submit and follow-up chat handlers, and finally the
    rendered conversation (newest first)."""
    page_config()        # must run first: st.set_page_config precedes other st.* calls
    init_ses_states()    # ensure chat_history / initial_input exist
    page_title_header()
    sidebar()            # sets language/scenario/temperature/scenario_context globals
    handle_initial_submit()
    handle_user_message()
    display_convo()
| 120 | + |
| 121 | + |
if __name__ == '__main__':
    # Load environment variables (e.g. the OpenAI API key) from a local
    # .env file before the app touches the OpenAI client.
    load_dotenv()
    main()
| 125 | + |
0 commit comments