| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424 |
- import json
- import logging
- import httpx
- import bcrypt
- from contextlib import asynccontextmanager
- from fastapi import FastAPI, HTTPException, Depends, Header
- from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history, get_user_profile, get_food_by_id, get_foods_by_ids
- from fastapi.responses import HTMLResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- from pydantic import BaseModel
- from typing import List, Generator, Optional
# Configure the root logger once at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the FastAPI app.

    Calls ``create_tables()`` before yielding control. Nothing runs after
    the ``yield``, so there is no shutdown work.
    """
    create_tables()
    yield
# The application instance; `lifespan` above is registered as its lifespan handler.
app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
# Use direct bcrypt for better environment compatibility
def get_password_hash(password: str):
    """Return a bcrypt hash of ``password`` as a UTF-8 decoded string.

    A fresh salt is generated with ``bcrypt.gensalt()`` on every call.
    """
    # bcrypt works on bytes: encode on the way in, decode on the way out.
    digest = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    return digest.decode('utf-8')
def verify_password(plain_password: str, hashed_password: str):
    """Check ``plain_password`` against a stored bcrypt hash.

    Args:
        plain_password: The password supplied by the caller.
        hashed_password: The stored bcrypt hash, as a string.

    Returns:
        True when the password matches the hash, False otherwise. Inputs
        that bcrypt rejects with ``ValueError`` (for example a malformed
        stored hash) are reported as a failed match instead of propagating
        as an unhandled error.
    """
    try:
        # bcrypt.checkpw handles verification; both arguments must be bytes.
        return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
    except ValueError:
        # Do not log the inputs themselves: they are credentials.
        logger.warning("Password verification rejected by bcrypt (malformed hash or unusable input)")
        return False
class UserCreate(BaseModel):
    """Request body accepted by the ``/api/register`` endpoint."""
    # Length rules (>= 3 / >= 6 characters after stripping) are enforced in
    # register_user, not on the model.
    username: str
    password: str
class UserLogin(BaseModel):
    """Request body accepted by the ``/api/login`` endpoint."""
    username: str
    password: str
async def get_current_user(authorization: Optional[str] = Header(None)):
    """Resolve the authenticated user from an ``Authorization: Bearer <token>`` header.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        Whatever ``get_user_from_token`` returns for the token (the endpoints
        in this file read ``['id']`` and ``['username']`` from it).

    Raises:
        HTTPException: 401 when the header is missing, is not a Bearer
            header, carries no token, or the token does not map to a user.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authentication required")

    # split() with no argument collapses runs of whitespace, so
    # "Bearer   <token>" still yields the token rather than an empty string.
    parts = authorization.split()
    if len(parts) < 2:
        # Header was only "Bearer " with no credential: skip the lookup.
        raise HTTPException(status_code=401, detail="Authentication required")

    user = get_user_from_token(parts[1])
    if not user:
        raise HTTPException(status_code=401, detail="Invalid or expired session")
    return user
# Chat endpoint of the Ollama server that /chat proxies to.
OLLAMA_URL = "http://localhost:11434/api/chat"
# Model name sent in the "model" field of every Ollama request.
MODEL_NAME = "qwen3.5:4b"
# Common stopwords to strip before searching the food database.
# Words in this set are never used as search keywords by extract_food_context.
_STOPWORDS = {
    'how', 'many', 'much', 'calories', 'does', 'have', 'has', 'is', 'are',
    'in', 'the', 'a', 'an', 'of', 'for', 'with', 'what', 'tell', 'me',
    'about', 'nutritional', 'value', 'nutrition', 'macro', 'macros',
    'protein', 'fat', 'carbs', 'fiber', 'can', 'you', 'i', 'want', 'need',
    'eat', 'eating', 'food', 'meal', 'diet', 'healthy', 'make', 'cook',
    'recipe', 'per', '100g', 'gram', 'grams', 'serving'
}
- def extract_food_context(messages: list) -> str | None:
- """Scan the last user message for food keywords and enrich with local DB data."""
- # Find the last user message
- last_user_msg = None
- for msg in reversed(messages):
- role = msg.get('role', '') if isinstance(msg, dict) else msg.role
- content = msg.get('content', '') if isinstance(msg, dict) else msg.content
- if role == 'user':
- last_user_msg = content
- break
-
- if not last_user_msg:
- return None
-
- # Extract meaningful keywords by removing stopwords
- words = last_user_msg.lower().replace('?', '').replace(',', '').split()
- keywords = [w for w in words if w not in _STOPWORDS and len(w) > 2]
-
- if not keywords:
- return None
-
- # Try each keyword against the local food database, collect unique results
- found_items = {}
- # Optimization: Only use the first 2 most relevant keywords to keep context small on CPU
- for kw in keywords[:2]:
- results = search_foods_by_name(kw, limit=2)
- for item in results:
- # Truncate extremely long USDA names for performance
- short_name = item['name'][:100] + ("..." if len(item['name']) > 100 else "")
- if short_name not in found_items:
- found_items[short_name] = item
- if len(found_items) >= 3:
- break
-
- if not found_items:
- return None
-
- # Build a structured context block for the system prompt
- lines = [
- "[SYSTEM: NUTRITIONAL ANALYST MODE]",
- "You are the LocalFoodAI Analyst. Use ONLY verified local data for values.",
- "CRITICAL: Provide direct, concise answers. Skip all internal monologues, <thought> tags, or reasoning steps.",
- "For each food discussed, you MUST follow this structure:",
- "1. Header: ### 🥗 [Name] (per 100g)",
- "2. Macros: A markdown table for Cal, P, F, C, Fib, Sug, Chol.",
- "3. Micros: A bulleted list for Na, Ca, Fe, K, VitA, VitC.",
- "4. Insight: A 1-sentence analysis of the food's nutritional profile.",
- "Always prioritize local data over training memory. If a nutrient is missing, say 'Data not available'.",
- ""
- ]
- for name, item in found_items.items():
- # Compact, token-efficient format for the LLM
- line = (
- f"- {name}: {item['calories']}kcal | P:{item['protein_g']}g | F:{item['fat_g']}g | C:{item['carbs_g']}g | "
- f"Fib:{item['fiber_g']}g | Sug:{item['sugar_g']}g | Chol:{item['cholesterol_mg']}mg | "
- f"Na:{item['sodium_mg']}mg | Ca:{item['calcium_mg']}mg | Fe:{item['iron_mg']}mg | "
- f"K:{item['potassium_mg']}mg | VitA:{item['vitamin_a_iu']}IU | VitC:{item['vitamin_c_mg']}mg"
- )
- lines.append(line)
-
- return "\n".join(lines)
# Mount static files to serve the frontend: everything under ./static is
# exposed at the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
class ChatMessage(BaseModel):
    """A single chat turn as sent by the client."""
    # The code in this file checks for the value 'user'; other role values
    # are passed through to the LLM unchanged.
    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``: the conversation so far."""
    messages: List[ChatMessage]
class MealItemInput(BaseModel):
    """One food entry in a meal: which food and how many grams of it."""
    food_id: int
    # Weight in grams; calculate_meal rejects values that are not > 0.
    amount_g: float
class MealCalculateRequest(BaseModel):
    """Request body for ``POST /api/meal/calculate``."""
    items: List[MealItemInput]
class MealSaveRequest(BaseModel):
    """Request body for ``POST /api/meals``: a named list of meal items."""
    name: str
    items: List[MealItemInput]
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Return static/index.html, or a short placeholder page if the file is missing."""
    try:
        with open("static/index.html", "r", encoding="utf-8") as index_file:
            page = index_file.read()
    except FileNotFoundError:
        # Frontend has not been created yet; show a hint instead of failing.
        page = "<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>"
    return HTMLResponse(content=page)
@app.post("/api/register")
async def register_user(user: UserCreate):
    """Create a new account and return a session token for it.

    Responds with HTTP 400 when the stripped username is shorter than 3
    characters, the stripped password is shorter than 6 characters, or
    ``create_user`` returns a falsy id (reported as a duplicate username).
    """
    username = user.username.strip()
    if len(username) < 3:
        raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
    if len(user.password.strip()) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")

    # The password is hashed as submitted (unstripped); only the length
    # check above looks at the stripped value.
    user_id = create_user(username, get_password_hash(user.password))
    if not user_id:
        raise HTTPException(status_code=400, detail="Username already exists")

    # Auto-login after registration
    session_token = create_session(user_id)
    return {"message": "User registered successfully", "token": session_token, "username": username}
@app.post("/api/login")
async def login_user(user: UserLogin):
    """Verify credentials and open a new session.

    Responds with the same HTTP 401 message for an unknown username and
    for a wrong password.
    """
    account = get_user_by_username(user.username.strip())
    # Short-circuit: the password is only checked when the account exists.
    credentials_ok = bool(account) and verify_password(user.password, account["password_hash"])
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="Invalid username or password")

    session_token = create_session(account["id"])
    return {"status": "success", "username": account["username"], "token": session_token}
@app.post("/api/logout")
async def logout(authorization: Optional[str] = Header(None)):
    """Delete the session named by the Bearer token, if one was sent.

    The success message is returned whether or not a token was present.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if has_bearer:
        delete_session(authorization.split(" ")[1])
    return {"message": "Logged out successfully"}
@app.get("/api/macros/targets")
async def get_macro_targets(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's macronutrient targets.

    Each value is read from the user's profile; a built-in default is used
    for any key the profile lacks, and for every key when no profile exists.
    """
    # (response key, profile key, default value)
    target_fields = (
        ("calories", "target_calories", 2000),
        ("protein_g", "target_protein_g", 150),
        ("carbs_g", "target_carbs_g", 200),
        ("fat_g", "target_fat_g", 65),
    )

    profile = get_user_profile(current_user['id'])
    if not profile:
        # Fallback to defaults in case database insertion failed
        return {out_key: default for out_key, _, default in target_fields}

    return {
        out_key: profile.get(profile_key, default)
        for out_key, profile_key, default in target_fields
    }
@app.post("/chat")
async def chat_endpoint(request: ChatRequest, current_user: dict = Depends(get_current_user)):
    """Proxy a chat request to the local Ollama instance and stream the reply.

    The last user message is stored, the prompt is optionally enriched with
    nutrition data from ``extract_food_context``, and the model output is
    relayed as server-sent events (``data: {...}`` lines). The complete
    assistant reply is stored once streaming finishes normally.
    """
    # Keep only the last 6 messages for context window performance on CPU
    messages = [entry.model_dump() for entry in request.messages][-6:]

    # Save the latest user message to DB
    latest = messages[-1] if messages else None
    if latest is not None and latest['role'] == 'user':
        save_chat_message(current_user['id'], 'user', latest['content'])

    # --- TG-35: Local SQL RAG Enrichment ---
    db_context = extract_food_context(messages)
    if db_context:
        # Prepend as a system message so it acts as grounded knowledge
        messages = [{"role": "system", "content": db_context}, *messages]
    logger.info(f"[Chat] User '{current_user['username']}' is chatting. Context items: {'Yes' if db_context else 'No'}. Message count: {len(messages)}")

    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "stream": True,
        "think": False  # Disable reasoning/thinking mode for faster responses
    }

    def sse_event(body: dict) -> str:
        # One server-sent event carrying a JSON body.
        return f"data: {json.dumps(body)}\n\n"

    async def generate_response():
        try:
            bot_full_response = ""
            # The same 300 s timeout is set on both the client and the request.
            async with httpx.AsyncClient(timeout=300.0) as client, \
                    client.stream("POST", OLLAMA_URL, json=payload, timeout=300.0) as response:
                if response.status_code != 200:
                    error_detail = await response.aread()
                    logger.error(f"Ollama returned error {response.status_code}: {error_detail}")
                    yield sse_event({'error': f'LLM Error ({response.status_code})'})
                    return
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        # Skip lines that are not valid JSON.
                        continue
                    if "message" in data and "content" in data["message"]:
                        content = data["message"]["content"]
                        bot_full_response += content
                        yield sse_event({'content': content})
                    if data.get("done"):
                        break

            # Save final bot response to DB
            if bot_full_response.strip():
                save_chat_message(current_user['id'], 'assistant', bot_full_response)

        except Exception as e:
            logger.exception(f"Unexpected error during chat stream: {e}")
            yield sse_event({'error': 'A technical error occurred while generating the response.'})

    return StreamingResponse(generate_response(), media_type="text/event-stream")
@app.get("/api/chat/history")
async def get_history(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's stored chat messages under the ``history`` key."""
    return {"history": get_user_chat_history(current_user['id'])}
@app.get("/api/food/search")
async def search_food(q: str, current_user: dict = Depends(get_current_user)):
    """Search foods by name for an authenticated user.

    A blank or whitespace-only query returns an empty result list without
    searching. Otherwise up to 15 matches are returned.
    """
    term = q.strip()
    if not term:
        return {"results": []}

    # The log line records the query exactly as received (unstripped).
    logger.info(f"User {current_user['username']} searched for [{q}]")
    return {"results": search_foods_by_name(term, limit=15)}
@app.get("/api/food/{food_id}")
async def get_food_detail(food_id: int, current_user: dict = Depends(get_current_user)):
    """Return one food item's nutrition data grouped into named sections.

    Responds with HTTP 404 when ``get_food_by_id`` returns nothing for
    ``food_id``.
    """
    food = get_food_by_id(food_id)
    if not food:
        raise HTTPException(status_code=404, detail="Food item not found")

    def section(*fields):
        # Copy the named columns of the food record into a sub-dict.
        return {field: food[field] for field in fields}

    # Structure the data as proposed in the implementation plan
    return {
        "id": food["id"],
        "name": food["name"],
        "category": food["category"],
        "base_weight_g": food["base_weight_g"],
        "macros": section("calories", "protein_g", "fat_g", "carbs_g"),
        "extended": section("fiber_g", "sugar_g", "cholesterol_mg"),
        "vitamins": section("vitamin_a_iu", "vitamin_c_mg"),
        "minerals": section("calcium_mg", "iron_mg", "potassium_mg", "sodium_mg"),
        "source": food["source"],
    }
@app.post("/api/meal/calculate")
async def calculate_meal(request: MealCalculateRequest, current_user: dict = Depends(get_current_user)):
    """Calculate the total nutritional value for a list of foods and custom weights.

    Each item's nutrient values are scaled by ``amount_g / 100`` and summed
    per nutrient. Nutrient values that are ``None`` count as 0. All totals
    are rounded to 2 decimal places.

    Returns:
        ``{"error": "Meal is empty"}`` for an empty item list; otherwise a
        dict with ``total_weight_g`` plus ``macros``, ``extended``,
        ``vitamins`` and ``minerals`` sub-dicts.

    Raises:
        HTTPException: 400 when an amount is not a finite number, is not
            greater than 0, or when any requested food ID does not exist.
    """
    import math  # function-scope import: only the finiteness check needs it

    if not request.items:
        return {"error": "Meal is empty"}

    # Validation: every amount must be a finite float > 0.
    items = []
    for item in request.items:
        try:
            amount = float(item.amount_g)
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail=f"Invalid amount for food ID {item.food_id}")
        if not math.isfinite(amount):
            # NaN slips past "amount <= 0" and Infinity is > 0; either would
            # produce totals that cannot be serialised as JSON.
            raise HTTPException(status_code=400, detail=f"Invalid amount for food ID {item.food_id}")
        if amount <= 0:
            raise HTTPException(status_code=400, detail="Quantity must be greater than 0g for all items.")
        items.append({"food_id": item.food_id, "amount_g": amount})

    # Bulk fetch from DB (sorted so the lookup and any error message are deterministic)
    requested_ids = sorted({entry["food_id"] for entry in items})
    foods_map = {food["id"]: food for food in get_foods_by_ids(requested_ids)}

    # Fail-fast: Check if all requested IDs exist
    missing_ids = [food_id for food_id in requested_ids if food_id not in foods_map]
    if missing_ids:
        raise HTTPException(status_code=400, detail=f"Invalid food IDs provided: {missing_ids}")

    # Response sections and the nutrient columns that belong to each.
    nutrient_groups = (
        ("macros", ("calories", "protein_g", "fat_g", "carbs_g")),
        ("extended", ("fiber_g", "sugar_g", "cholesterol_mg")),
        ("vitamins", ("vitamin_a_iu", "vitamin_c_mg")),
        ("minerals", ("calcium_mg", "iron_mg", "potassium_mg", "sodium_mg")),
    )

    # Initialize aggregator
    totals = {"total_weight_g": 0.0}
    for group, fields in nutrient_groups:
        totals[group] = dict.fromkeys(fields, 0.0)

    for entry in items:
        food = foods_map[entry["food_id"]]
        # NOTE(review): assumes stored nutrient values are per 100 g; confirm
        # against the food's base_weight_g column.
        ratio = entry["amount_g"] / 100.0
        totals["total_weight_g"] += entry["amount_g"]
        for group, fields in nutrient_groups:
            bucket = totals[group]
            for field in fields:
                value = food.get(field)
                bucket[field] += (float(value) if value is not None else 0.0) * ratio

    # Rounding to 2 decimal places
    totals["total_weight_g"] = round(totals["total_weight_g"], 2)
    for group, fields in nutrient_groups:
        for field in fields:
            totals[group][field] = round(totals[group][field], 2)

    return totals
@app.post("/api/meals")
async def save_meal(request: MealSaveRequest, current_user: dict = Depends(get_current_user)):
    """Save a named meal (list of food IDs and gram amounts) for the authenticated user.

    Returns:
        ``{"status": "success", "meal_id": ..., "name": ...}`` with the
        stripped meal name.

    Raises:
        HTTPException: 400 when the name is blank or the item list is
            empty; 500 when ``save_user_meal`` returns ``None``.
    """
    # save_user_meal is missing from the module-level `from database import ...`
    # list, so the original call raised NameError. Import it here so this
    # endpoint works without touching the file's import block.
    from database import save_user_meal

    meal_name = request.name.strip()
    if not meal_name:
        raise HTTPException(status_code=400, detail="Meal name cannot be empty")
    if not request.items:
        raise HTTPException(status_code=400, detail="Meal items cannot be empty")

    items_list = [item.model_dump() for item in request.items]
    meal_id = save_user_meal(current_user['id'], meal_name, items_list)

    if meal_id is None:
        raise HTTPException(status_code=500, detail="Failed to save meal to database")

    return {"status": "success", "meal_id": meal_id, "name": meal_name}
# Development entry point: run the app with uvicorn on 127.0.0.1:8000 with
# auto-reload enabled. "main:app" means this file is expected to be importable as `main`.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
|