Browse Source

Implement RAG for AI Meal Planner

lanfr144 2 weeks ago
parent
commit
f6b9d4c6dc
1 changed file with 27 additions and 2 deletions
  1. 27 2
      app.py

+ 27 - 2
app.py

@@ -535,16 +535,41 @@ with tab_planner:
     extra_notes = st.text_input("Any additional allergies or goals?")
     
     if st.button("Generate Professional Menu"):
-        with st.spinner("AI is formulating..."):
+        with st.spinner("AI is formulating and interrogating the local database..."):
             sys_prompt = f"""You are a professional Dietitian planner. Target: {target_cal}kcal over {meal_count} meals. 
             Dietary constraint: {diet_pref}. Additional notes: {extra_notes}.
             CRITICAL INSTRUCTIONS:
+            - YOU MUST USE the `search_nutrition_db` tool to find real products and their exact macros before constructing the menu!
+            - If you cannot find appropriate products in the local DB, use `local_web_search`.
             - ALWAYS output exactly as a strict Markdown table including Columns: | Meal | Food | Calories | Salt | Fat | Iron |
             - DO NOT output | separated text outside of standard strict markdown block ` ```markdown ` or standard rendering.
             - Convert ALL cooking measurements to Grams (g). Use these equivalents STRICTLY:
               1 tbsp = 15g, 1 tsp = 5g, 1 cup = 200g, 1 mustard glass = 100g. 1 cl of liquid = 10g.
             """
-            response = ollama.chat(model='mistral', messages=[{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': 'Generate menu'}])
+            
+            temp_messages = [{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': 'Generate my meal plan. Find real foods from the DB.'}]
+            response = ollama.chat(model='mistral', messages=temp_messages, tools=[search_tool_schema, db_search_tool_schema])
+            
+            # Simple loop to handle multiple tool calls (up to 3 times to prevent infinite loops)
+            for _ in range(3):
+                if response.get('message', {}).get('tool_calls'):
+                    temp_messages.append(response['message'])
+                    for tool in response['message']['tool_calls']:
+                        if tool['function']['name'] == 'local_web_search':
+                            query_arg = tool['function']['arguments'].get('query')
+                            st.info(f"🔍 Planner Web Search triggered for: '{query_arg}'")
+                            search_data = local_web_search(query_arg)
+                            temp_messages.append({'role': 'tool', 'content': search_data, 'name': 'local_web_search'})
+                        elif tool['function']['name'] == 'search_nutrition_db':
+                            query_arg = tool['function']['arguments'].get('query')
+                            st.info(f"🗄️ Planner DB Search triggered for: '{query_arg}'")
+                            db_data = search_nutrition_db(query_arg)
+                            temp_messages.append({'role': 'tool', 'content': db_data, 'name': 'search_nutrition_db'})
+                    
+                    response = ollama.chat(model='mistral', messages=temp_messages, tools=[search_tool_schema, db_search_tool_schema])
+                else:
+                    break
+                    
             st.markdown(response['message']['content'])
 
 if conn_reader: conn_reader.close()