소스 검색

TG-49: Implement POST /api/meal/calculate endpoint with strict validation

FerRo988 1 주 전
부모
커밋
ddc8149ee8
1개의 변경된 파일81개의 추가작업 그리고 1개의 파일을 삭제
  1. 81 1
      main.py

+ 81 - 1
main.py

@@ -4,7 +4,7 @@ import httpx
 import bcrypt
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException, Depends, Header
-from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history, get_user_profile, get_food_by_id
+from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history, get_user_profile, get_food_by_id, get_foods_by_ids
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
@@ -134,6 +134,13 @@ class ChatMessage(BaseModel):
 class ChatRequest(BaseModel):
     messages: List[ChatMessage]
 
+class MealItemInput(BaseModel):
+    food_id: int
+    amount_g: float
+
+class MealCalculateRequest(BaseModel):
+    items: List[MealItemInput]
+
 @app.get("/", response_class=HTMLResponse)
 async def read_root():
     """Serve the chat interface HTML"""
@@ -317,6 +324,79 @@ async def get_food_detail(food_id: int, current_user: dict = Depends(get_current
     
     return structured_data
 
+@app.post("/api/meal/calculate")
+async def calculate_meal(request: MealCalculateRequest, current_user: dict = Depends(get_current_user)):
+    """Calculate the total nutritional value for a combined list of foods and their custom weights."""
+    if not request.items:
+        return {"error": "Meal is empty"}
+        
+    # Validation: Cast to floats and ensure > 0
+    items = []
+    for item in request.items:
+        try:
+            amount = float(item.amount_g)
+            if amount <= 0:
+                raise HTTPException(status_code=400, detail="Quantity must be greater than 0g for all items.")
+            items.append({"food_id": item.food_id, "amount_g": amount})
+        except ValueError:
+            raise HTTPException(status_code=400, detail=f"Invalid amount for food ID {item.food_id}")
+            
+    # Bulk fetch from DB
+    requested_ids = list(set(item["food_id"] for item in items))
+    foods_data = get_foods_by_ids(requested_ids)
+    
+    # Map for easy lookup
+    foods_map = {food["id"]: food for food in foods_data}
+    
+    # Fail-fast: Check if all requested IDs exist
+    found_ids = set(foods_map.keys())
+    missing_ids = [fid for fid in requested_ids if fid not in found_ids]
+    if missing_ids:
+        raise HTTPException(status_code=400, detail=f"Invalid food IDs provided: {missing_ids}")
+    
+    # Initialize aggregator
+    totals = {
+        "total_weight_g": 0.0,
+        "macros": {"calories": 0.0, "protein_g": 0.0, "fat_g": 0.0, "carbs_g": 0.0},
+        "extended": {"fiber_g": 0.0, "sugar_g": 0.0, "cholesterol_mg": 0.0},
+        "vitamins": {"vitamin_a_iu": 0.0, "vitamin_c_mg": 0.0},
+        "minerals": {"calcium_mg": 0.0, "iron_mg": 0.0, "potassium_mg": 0.0, "sodium_mg": 0.0}
+    }
+    
+    def safe_val(val):
+        return float(val) if val is not None else 0.0
+        
+    for item in items:
+        food = foods_map[item["food_id"]]
+        
+        ratio = item["amount_g"] / 100.0
+        totals["total_weight_g"] += item["amount_g"]
+        
+        totals["macros"]["calories"] += safe_val(food.get("calories")) * ratio
+        totals["macros"]["protein_g"] += safe_val(food.get("protein_g")) * ratio
+        totals["macros"]["fat_g"] += safe_val(food.get("fat_g")) * ratio
+        totals["macros"]["carbs_g"] += safe_val(food.get("carbs_g")) * ratio
+        
+        totals["extended"]["fiber_g"] += safe_val(food.get("fiber_g")) * ratio
+        totals["extended"]["sugar_g"] += safe_val(food.get("sugar_g")) * ratio
+        totals["extended"]["cholesterol_mg"] += safe_val(food.get("cholesterol_mg")) * ratio
+        
+        totals["vitamins"]["vitamin_a_iu"] += safe_val(food.get("vitamin_a_iu")) * ratio
+        totals["vitamins"]["vitamin_c_mg"] += safe_val(food.get("vitamin_c_mg")) * ratio
+        
+        totals["minerals"]["calcium_mg"] += safe_val(food.get("calcium_mg")) * ratio
+        totals["minerals"]["iron_mg"] += safe_val(food.get("iron_mg")) * ratio
+        totals["minerals"]["potassium_mg"] += safe_val(food.get("potassium_mg")) * ratio
+        totals["minerals"]["sodium_mg"] += safe_val(food.get("sodium_mg")) * ratio
+        
+    # Rounding to 2 decimal places
+    totals["total_weight_g"] = round(totals["total_weight_g"], 2)
+    for category in ["macros", "extended", "vitamins", "minerals"]:
+        for key in totals[category]:
+            totals[category][key] = round(totals[category][key], 2)
+            
+    return totals
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)