|
|
@@ -4,7 +4,7 @@ import httpx
|
|
|
import bcrypt
|
|
|
from contextlib import asynccontextmanager
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Header
|
|
|
-from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history, get_user_profile, get_food_by_id
|
|
|
+from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history, get_user_profile, get_food_by_id, get_foods_by_ids
|
|
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from pydantic import BaseModel
|
|
|
@@ -134,6 +134,13 @@ class ChatMessage(BaseModel):
|
|
|
class ChatRequest(BaseModel):
|
|
|
messages: List[ChatMessage]
|
|
|
|
|
|
+class MealItemInput(BaseModel):
|
|
|
+ food_id: int
|
|
|
+ amount_g: float
|
|
|
+
|
|
|
+class MealCalculateRequest(BaseModel):
|
|
|
+ items: List[MealItemInput]
|
|
|
+
|
|
|
@app.get("/", response_class=HTMLResponse)
|
|
|
async def read_root():
|
|
|
"""Serve the chat interface HTML"""
|
|
|
@@ -317,6 +324,79 @@ async def get_food_detail(food_id: int, current_user: dict = Depends(get_current
|
|
|
|
|
|
return structured_data
|
|
|
|
|
|
+@app.post("/api/meal/calculate")
|
|
|
+async def calculate_meal(request: MealCalculateRequest, current_user: dict = Depends(get_current_user)):
|
|
|
+ """Calculate the total nutritional value for a combined list of foods and their custom weights."""
|
|
|
+ if not request.items:
|
|
|
+ return {"error": "Meal is empty"}
|
|
|
+
|
|
|
+ # Validation: Cast to floats and ensure > 0
|
|
|
+ items = []
|
|
|
+ for item in request.items:
|
|
|
+ try:
|
|
|
+ amount = float(item.amount_g)
|
|
|
+ if amount <= 0:
|
|
|
+ raise HTTPException(status_code=400, detail="Quantity must be greater than 0g for all items.")
|
|
|
+ items.append({"food_id": item.food_id, "amount_g": amount})
|
|
|
+ except ValueError:
|
|
|
+ raise HTTPException(status_code=400, detail=f"Invalid amount for food ID {item.food_id}")
|
|
|
+
|
|
|
+ # Bulk fetch from DB
|
|
|
+ requested_ids = list(set(item["food_id"] for item in items))
|
|
|
+ foods_data = get_foods_by_ids(requested_ids)
|
|
|
+
|
|
|
+ # Map for easy lookup
|
|
|
+ foods_map = {food["id"]: food for food in foods_data}
|
|
|
+
|
|
|
+ # Fail-fast: Check if all requested IDs exist
|
|
|
+ found_ids = set(foods_map.keys())
|
|
|
+ missing_ids = [fid for fid in requested_ids if fid not in found_ids]
|
|
|
+ if missing_ids:
|
|
|
+ raise HTTPException(status_code=400, detail=f"Invalid food IDs provided: {missing_ids}")
|
|
|
+
|
|
|
+ # Initialize aggregator
|
|
|
+ totals = {
|
|
|
+ "total_weight_g": 0.0,
|
|
|
+ "macros": {"calories": 0.0, "protein_g": 0.0, "fat_g": 0.0, "carbs_g": 0.0},
|
|
|
+ "extended": {"fiber_g": 0.0, "sugar_g": 0.0, "cholesterol_mg": 0.0},
|
|
|
+ "vitamins": {"vitamin_a_iu": 0.0, "vitamin_c_mg": 0.0},
|
|
|
+ "minerals": {"calcium_mg": 0.0, "iron_mg": 0.0, "potassium_mg": 0.0, "sodium_mg": 0.0}
|
|
|
+ }
|
|
|
+
|
|
|
+ def safe_val(val):
|
|
|
+ return float(val) if val is not None else 0.0
|
|
|
+
|
|
|
+ for item in items:
|
|
|
+ food = foods_map[item["food_id"]]
|
|
|
+
|
|
|
+ ratio = item["amount_g"] / 100.0
|
|
|
+ totals["total_weight_g"] += item["amount_g"]
|
|
|
+
|
|
|
+ totals["macros"]["calories"] += safe_val(food.get("calories")) * ratio
|
|
|
+ totals["macros"]["protein_g"] += safe_val(food.get("protein_g")) * ratio
|
|
|
+ totals["macros"]["fat_g"] += safe_val(food.get("fat_g")) * ratio
|
|
|
+ totals["macros"]["carbs_g"] += safe_val(food.get("carbs_g")) * ratio
|
|
|
+
|
|
|
+ totals["extended"]["fiber_g"] += safe_val(food.get("fiber_g")) * ratio
|
|
|
+ totals["extended"]["sugar_g"] += safe_val(food.get("sugar_g")) * ratio
|
|
|
+ totals["extended"]["cholesterol_mg"] += safe_val(food.get("cholesterol_mg")) * ratio
|
|
|
+
|
|
|
+ totals["vitamins"]["vitamin_a_iu"] += safe_val(food.get("vitamin_a_iu")) * ratio
|
|
|
+ totals["vitamins"]["vitamin_c_mg"] += safe_val(food.get("vitamin_c_mg")) * ratio
|
|
|
+
|
|
|
+ totals["minerals"]["calcium_mg"] += safe_val(food.get("calcium_mg")) * ratio
|
|
|
+ totals["minerals"]["iron_mg"] += safe_val(food.get("iron_mg")) * ratio
|
|
|
+ totals["minerals"]["potassium_mg"] += safe_val(food.get("potassium_mg")) * ratio
|
|
|
+ totals["minerals"]["sodium_mg"] += safe_val(food.get("sodium_mg")) * ratio
|
|
|
+
|
|
|
+ # Rounding to 2 decimal places
|
|
|
+ totals["total_weight_g"] = round(totals["total_weight_g"], 2)
|
|
|
+ for category in ["macros", "extended", "vitamins", "minerals"]:
|
|
|
+ for key in totals[category]:
|
|
|
+ totals[category][key] = round(totals[category][key], 2)
|
|
|
+
|
|
|
+ return totals
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
|
uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
|