main.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
import asyncio
import json
import logging
import math
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Generator, List, Optional

import bcrypt
import httpx
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from database import (
    create_session,
    create_tables,
    create_user,
    delete_session,
    get_food_by_id,
    get_foods_by_ids,
    get_user_by_username,
    get_user_chat_history,
    get_user_from_token,
    get_user_profile,
    save_chat_message,
    search_foods_by_name,
)
  12. logging.basicConfig(level=logging.INFO)
  13. logger = logging.getLogger(__name__)
  14. @asynccontextmanager
  15. async def lifespan(app: FastAPI):
  16. create_tables()
  17. yield
  18. app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
  19. # Use direct bcrypt for better environment compatibility
  20. def get_password_hash(password: str):
  21. # Hash requires bytes
  22. pwd_bytes = password.encode('utf-8')
  23. salt = bcrypt.gensalt()
  24. hashed = bcrypt.hashpw(pwd_bytes, salt)
  25. return hashed.decode('utf-8')
  26. def verify_password(plain_password: str, hashed_password: str):
  27. # bcrypt.checkpw handles verification
  28. return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
  29. class UserCreate(BaseModel):
  30. username: str
  31. password: str
  32. class UserLogin(BaseModel):
  33. username: str
  34. password: str
  35. async def get_current_user(authorization: Optional[str] = Header(None)):
  36. if not authorization or not authorization.startswith("Bearer "):
  37. raise HTTPException(status_code=401, detail="Authentication required")
  38. token = authorization.split(" ")[1]
  39. user = get_user_from_token(token)
  40. if not user:
  41. raise HTTPException(status_code=401, detail="Invalid or expired session")
  42. return user
  43. OLLAMA_URL = "http://localhost:11434/api/chat"
  44. MODEL_NAME = "qwen3.5:4b"
  45. # Common stopwords to strip before searching the food database
  46. _STOPWORDS = {
  47. 'how', 'many', 'much', 'calories', 'does', 'have', 'has', 'is', 'are',
  48. 'in', 'the', 'a', 'an', 'of', 'for', 'with', 'what', 'tell', 'me',
  49. 'about', 'nutritional', 'value', 'nutrition', 'macro', 'macros',
  50. 'protein', 'fat', 'carbs', 'fiber', 'can', 'you', 'i', 'want', 'need',
  51. 'eat', 'eating', 'food', 'meal', 'diet', 'healthy', 'make', 'cook',
  52. 'recipe', 'per', '100g', 'gram', 'grams', 'serving'
  53. }
  54. def extract_food_context(messages: list) -> str | None:
  55. """Scan the last user message for food keywords and enrich with local DB data."""
  56. # Find the last user message
  57. last_user_msg = None
  58. for msg in reversed(messages):
  59. role = msg.get('role', '') if isinstance(msg, dict) else msg.role
  60. content = msg.get('content', '') if isinstance(msg, dict) else msg.content
  61. if role == 'user':
  62. last_user_msg = content
  63. break
  64. if not last_user_msg:
  65. return None
  66. # Extract meaningful keywords by removing stopwords
  67. words = last_user_msg.lower().replace('?', '').replace(',', '').split()
  68. keywords = [w for w in words if w not in _STOPWORDS and len(w) > 2]
  69. if not keywords:
  70. return None
  71. # Try each keyword against the local food database, collect unique results
  72. found_items = {}
  73. # Optimization: Only use the first 2 most relevant keywords to keep context small on CPU
  74. for kw in keywords[:2]:
  75. results = search_foods_by_name(kw, limit=2)
  76. for item in results:
  77. # Truncate extremely long USDA names for performance
  78. short_name = item['name'][:100] + ("..." if len(item['name']) > 100 else "")
  79. if short_name not in found_items:
  80. found_items[short_name] = item
  81. if len(found_items) >= 3:
  82. break
  83. if not found_items:
  84. return None
  85. # Build a structured context block for the system prompt
  86. lines = [
  87. "[SYSTEM: NUTRITIONAL ANALYST MODE]",
  88. "You are the LocalFoodAI Analyst. Use ONLY verified local data for values.",
  89. "CRITICAL: Provide direct, concise answers. Skip all internal monologues, <thought> tags, or reasoning steps.",
  90. "For each food discussed, you MUST follow this structure:",
  91. "1. Header: ### 🥗 [Name] (per 100g)",
  92. "2. Macros: A markdown table for Cal, P, F, C, Fib, Sug, Chol.",
  93. "3. Micros: A bulleted list for Na, Ca, Fe, K, VitA, VitC.",
  94. "4. Insight: A 1-sentence analysis of the food's nutritional profile.",
  95. "Always prioritize local data over training memory. If a nutrient is missing, say 'Data not available'.",
  96. ""
  97. ]
  98. for name, item in found_items.items():
  99. # Compact, token-efficient format for the LLM
  100. line = (
  101. f"- {name}: {item['calories']}kcal | P:{item['protein_g']}g | F:{item['fat_g']}g | C:{item['carbs_g']}g | "
  102. f"Fib:{item['fiber_g']}g | Sug:{item['sugar_g']}g | Chol:{item['cholesterol_mg']}mg | "
  103. f"Na:{item['sodium_mg']}mg | Ca:{item['calcium_mg']}mg | Fe:{item['iron_mg']}mg | "
  104. f"K:{item['potassium_mg']}mg | VitA:{item['vitamin_a_iu']}IU | VitC:{item['vitamin_c_mg']}mg"
  105. )
  106. lines.append(line)
  107. return "\n".join(lines)
  108. # Mount static files to serve the frontend
  109. app.mount("/static", StaticFiles(directory="static"), name="static")
  110. class ChatMessage(BaseModel):
  111. role: str
  112. content: str
  113. class ChatRequest(BaseModel):
  114. messages: List[ChatMessage]
  115. class MealItemInput(BaseModel):
  116. food_id: int
  117. amount_g: float
  118. class MealCalculateRequest(BaseModel):
  119. items: List[MealItemInput]
  120. class MealSaveRequest(BaseModel):
  121. name: str
  122. items: List[MealItemInput]
  123. @app.get("/", response_class=HTMLResponse)
  124. async def read_root():
  125. """Serve the chat interface HTML"""
  126. try:
  127. with open("static/index.html", "r", encoding="utf-8") as f:
  128. return HTMLResponse(content=f.read())
  129. except FileNotFoundError:
  130. return HTMLResponse(content="<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>")
  131. @app.post("/api/register")
  132. async def register_user(user: UserCreate):
  133. if len(user.username.strip()) < 3:
  134. raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
  135. if len(user.password.strip()) < 6:
  136. raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
  137. hashed_password = get_password_hash(user.password)
  138. user_id = create_user(user.username.strip(), hashed_password)
  139. if not user_id:
  140. raise HTTPException(status_code=400, detail="Username already exists")
  141. # Auto-login after registration
  142. token = create_session(user_id)
  143. return {"message": "User registered successfully", "token": token, "username": user.username.strip()}
  144. @app.post("/api/login")
  145. async def login_user(user: UserLogin):
  146. db_user = get_user_by_username(user.username.strip())
  147. if not db_user:
  148. raise HTTPException(status_code=401, detail="Invalid username or password")
  149. if not verify_password(user.password, db_user["password_hash"]):
  150. raise HTTPException(status_code=401, detail="Invalid username or password")
  151. token = create_session(db_user["id"])
  152. return {"status": "success", "username": db_user["username"], "token": token}
  153. @app.post("/api/logout")
  154. async def logout(authorization: Optional[str] = Header(None)):
  155. if authorization and authorization.startswith("Bearer "):
  156. token = authorization.split(" ")[1]
  157. delete_session(token)
  158. return {"message": "Logged out successfully"}
  159. @app.get("/api/macros/targets")
  160. async def get_macro_targets(current_user: dict = Depends(get_current_user)):
  161. """API endpoint to securely fetch the user's current macronutrient targets"""
  162. profile = get_user_profile(current_user['id'])
  163. if not profile:
  164. # Fallback to defaults in case database insertion failed
  165. return {
  166. "calories": 2000,
  167. "protein_g": 150,
  168. "carbs_g": 200,
  169. "fat_g": 65
  170. }
  171. return {
  172. "calories": profile.get("target_calories", 2000),
  173. "protein_g": profile.get("target_protein_g", 150),
  174. "carbs_g": profile.get("target_carbs_g", 200),
  175. "fat_g": profile.get("target_fat_g", 65)
  176. }
  177. @app.post("/chat")
  178. async def chat_endpoint(request: ChatRequest, current_user: dict = Depends(get_current_user)):
  179. """Proxy chat requests to the local Ollama instance with streaming support.
  180. Automatically enriches prompts with verified local SQLite nutritional data.
  181. """
  182. # Keep only the last 6 messages for context window performance on CPU
  183. all_messages = [msg.model_dump() for msg in request.messages]
  184. messages = all_messages[-6:]
  185. # Save the latest user message to DB
  186. if messages and messages[-1]['role'] == 'user':
  187. save_chat_message(current_user['id'], 'user', messages[-1]['content'])
  188. # --- TG-35: Local SQL RAG Enrichment ---
  189. db_context = extract_food_context(messages)
  190. if db_context:
  191. # Prepend as a system message so it acts as grounded knowledge
  192. # We ensure it's a short, concise instruction to prevent context bloat
  193. messages = [{"role": "system", "content": db_context}] + messages
  194. logger.info(f"[Chat] User '{current_user['username']}' is chatting. Context items: {'Yes' if db_context else 'No'}. Message count: {len(messages)}")
  195. payload = {
  196. "model": MODEL_NAME,
  197. "messages": messages,
  198. "stream": True,
  199. "think": False # Disable reasoning/thinking mode for faster responses
  200. }
  201. async def generate_response():
  202. try:
  203. bot_full_response = ""
  204. async with httpx.AsyncClient(timeout=300.0) as client:
  205. # Use a combined timeout for the entire request
  206. async with client.stream("POST", OLLAMA_URL, json=payload, timeout=300.0) as response:
  207. if response.status_code != 200:
  208. error_detail = await response.aread()
  209. logger.error(f"Ollama returned error {response.status_code}: {error_detail}")
  210. yield f"data: {json.dumps({'error': f'LLM Error ({response.status_code})'})}\n\n"
  211. return
  212. async for line in response.aiter_lines():
  213. if line:
  214. try:
  215. data = json.loads(line)
  216. if "message" in data and "content" in data["message"]:
  217. content = data["message"]["content"]
  218. bot_full_response += content
  219. yield f"data: {json.dumps({'content': content})}\n\n"
  220. if data.get("done"):
  221. break
  222. except json.JSONDecodeError:
  223. continue
  224. # Save final bot response to DB
  225. if bot_full_response.strip():
  226. save_chat_message(current_user['id'], 'assistant', bot_full_response)
  227. except Exception as e:
  228. logger.exception(f"Unexpected error during chat stream: {e}")
  229. yield f"data: {json.dumps({'error': 'A technical error occurred while generating the response.'})}\n\n"
  230. return StreamingResponse(generate_response(), media_type="text/event-stream")
  231. @app.get("/api/chat/history")
  232. async def get_history(current_user: dict = Depends(get_current_user)):
  233. """Fetch the chat history for the authenticated user"""
  234. history = get_user_chat_history(current_user['id'])
  235. return {"history": history}
  236. @app.get("/api/food/search")
  237. async def search_food(q: str, current_user: dict = Depends(get_current_user)):
  238. """API endpoint to search for food items securely using token authentication"""
  239. if not q or len(q.strip()) < 1:
  240. return {"results": []}
  241. logger.info(f"User {current_user['username']} searched for [{q}]")
  242. results = search_foods_by_name(q.strip(), limit=15)
  243. return {"results": results}
  244. @app.get("/api/food/{food_id}")
  245. async def get_food_detail(food_id: int, current_user: dict = Depends(get_current_user)):
  246. """API endpoint to fetch structured nutritional details for a specific food item"""
  247. food = get_food_by_id(food_id)
  248. if not food:
  249. raise HTTPException(status_code=404, detail="Food item not found")
  250. # Structure the data as proposed in the implementation plan
  251. structured_data = {
  252. "id": food["id"],
  253. "name": food["name"],
  254. "category": food["category"],
  255. "base_weight_g": food["base_weight_g"],
  256. "macros": {
  257. "calories": food["calories"],
  258. "protein_g": food["protein_g"],
  259. "fat_g": food["fat_g"],
  260. "carbs_g": food["carbs_g"]
  261. },
  262. "extended": {
  263. "fiber_g": food["fiber_g"],
  264. "sugar_g": food["sugar_g"],
  265. "cholesterol_mg": food["cholesterol_mg"]
  266. },
  267. "vitamins": {
  268. "vitamin_a_iu": food["vitamin_a_iu"],
  269. "vitamin_c_mg": food["vitamin_c_mg"]
  270. },
  271. "minerals": {
  272. "calcium_mg": food["calcium_mg"],
  273. "iron_mg": food["iron_mg"],
  274. "potassium_mg": food["potassium_mg"],
  275. "sodium_mg": food["sodium_mg"]
  276. },
  277. "source": food["source"]
  278. }
  279. return structured_data
  280. @app.post("/api/meal/calculate")
  281. async def calculate_meal(request: MealCalculateRequest, current_user: dict = Depends(get_current_user)):
  282. """Calculate the total nutritional value for a combined list of foods and their custom weights."""
  283. if not request.items:
  284. return {"error": "Meal is empty"}
  285. # Validation: Cast to floats and ensure > 0
  286. items = []
  287. for item in request.items:
  288. try:
  289. amount = float(item.amount_g)
  290. if amount <= 0:
  291. raise HTTPException(status_code=400, detail="Quantity must be greater than 0g for all items.")
  292. items.append({"food_id": item.food_id, "amount_g": amount})
  293. except ValueError:
  294. raise HTTPException(status_code=400, detail=f"Invalid amount for food ID {item.food_id}")
  295. # Bulk fetch from DB
  296. requested_ids = list(set(item["food_id"] for item in items))
  297. foods_data = get_foods_by_ids(requested_ids)
  298. # Map for easy lookup
  299. foods_map = {food["id"]: food for food in foods_data}
  300. # Fail-fast: Check if all requested IDs exist
  301. found_ids = set(foods_map.keys())
  302. missing_ids = [fid for fid in requested_ids if fid not in found_ids]
  303. if missing_ids:
  304. raise HTTPException(status_code=400, detail=f"Invalid food IDs provided: {missing_ids}")
  305. # Initialize aggregator
  306. totals = {
  307. "total_weight_g": 0.0,
  308. "macros": {"calories": 0.0, "protein_g": 0.0, "fat_g": 0.0, "carbs_g": 0.0},
  309. "extended": {"fiber_g": 0.0, "sugar_g": 0.0, "cholesterol_mg": 0.0},
  310. "vitamins": {"vitamin_a_iu": 0.0, "vitamin_c_mg": 0.0},
  311. "minerals": {"calcium_mg": 0.0, "iron_mg": 0.0, "potassium_mg": 0.0, "sodium_mg": 0.0}
  312. }
  313. def safe_val(val):
  314. return float(val) if val is not None else 0.0
  315. for item in items:
  316. food = foods_map[item["food_id"]]
  317. ratio = item["amount_g"] / 100.0
  318. totals["total_weight_g"] += item["amount_g"]
  319. totals["macros"]["calories"] += safe_val(food.get("calories")) * ratio
  320. totals["macros"]["protein_g"] += safe_val(food.get("protein_g")) * ratio
  321. totals["macros"]["fat_g"] += safe_val(food.get("fat_g")) * ratio
  322. totals["macros"]["carbs_g"] += safe_val(food.get("carbs_g")) * ratio
  323. totals["extended"]["fiber_g"] += safe_val(food.get("fiber_g")) * ratio
  324. totals["extended"]["sugar_g"] += safe_val(food.get("sugar_g")) * ratio
  325. totals["extended"]["cholesterol_mg"] += safe_val(food.get("cholesterol_mg")) * ratio
  326. totals["vitamins"]["vitamin_a_iu"] += safe_val(food.get("vitamin_a_iu")) * ratio
  327. totals["vitamins"]["vitamin_c_mg"] += safe_val(food.get("vitamin_c_mg")) * ratio
  328. totals["minerals"]["calcium_mg"] += safe_val(food.get("calcium_mg")) * ratio
  329. totals["minerals"]["iron_mg"] += safe_val(food.get("iron_mg")) * ratio
  330. totals["minerals"]["potassium_mg"] += safe_val(food.get("potassium_mg")) * ratio
  331. totals["minerals"]["sodium_mg"] += safe_val(food.get("sodium_mg")) * ratio
  332. # Rounding to 2 decimal places
  333. totals["total_weight_g"] = round(totals["total_weight_g"], 2)
  334. for category in ["macros", "extended", "vitamins", "minerals"]:
  335. for key in totals[category]:
  336. totals[category][key] = round(totals[category][key], 2)
  337. return totals
  338. @app.post("/api/meals")
  339. async def save_meal(request: MealSaveRequest, current_user: dict = Depends(get_current_user)):
  340. """Securely save a named meal list for the authenticated user"""
  341. if not request.name.strip():
  342. raise HTTPException(status_code=400, detail="Meal name cannot be empty")
  343. if not request.items:
  344. raise HTTPException(status_code=400, detail="Meal items cannot be empty")
  345. items_list = [item.model_dump() for item in request.items]
  346. meal_id = save_user_meal(current_user['id'], request.name.strip(), items_list)
  347. if meal_id is None:
  348. raise HTTPException(status_code=500, detail="Failed to save meal to database")
  349. return {"status": "success", "meal_id": meal_id, "name": request.name.strip()}
  350. if __name__ == "__main__":
  351. import uvicorn
  352. uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)