Переглянути джерело

TG-53: Verify mathematical accuracy with unit tests; fix missing List import in database.py

FerRo988 1 тиждень тому
батько
коміт
477dd3e77b
2 змінених файлів з 138 додано та 1 видалено
  1. 24 1
      database.py
  2. 114 0
      test_math_logic.py

+ 24 - 1
database.py

@@ -3,7 +3,7 @@ import os
 import logging
 import secrets
 from datetime import datetime, timedelta
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 
 logger = logging.getLogger(__name__)
 
@@ -305,3 +305,26 @@ def get_food_by_id(food_id: int) -> Optional[Dict[str, Any]]:
         return None
     finally:
         if conn: conn.close()
+
+def get_foods_by_ids(food_ids: List[int]) -> List[Dict[str, Any]]:
+    """Retrieve multiple food items by their unique IDs in bulk"""
+    if not food_ids:
+        return []
+        
+    conn = None
+    try:
+        conn = get_db_connection()
+        cursor = conn.cursor()
+        
+        # Create placeholders for the IN clause
+        placeholders = ', '.join(['?'] * len(food_ids))
+        query = f"SELECT * FROM foods WHERE id IN ({placeholders})"
+        
+        cursor.execute(query, food_ids)
+        rows = cursor.fetchall()
+        return [dict(row) for row in rows]
+    except Exception as e:
+        logger.error(f"Error fetching foods by IDs {food_ids}: {e}")
+        return []
+    finally:
+        if conn: conn.close()

+ 114 - 0
test_math_logic.py

@@ -0,0 +1,114 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from fastapi.testclient import TestClient
+from main import app
+
+class TestMealMathLogic(unittest.TestCase):
+    def setUp(self):
+        self.client = TestClient(app)
+        # Mock user authentication for the calculation endpoint
+        self.headers = {"Authorization": "Bearer mock-token"}
+
+    @patch("main.get_user_from_token")
+    @patch("main.get_foods_by_ids")
+    def test_math_scaling_and_aggregation(self, mock_get_foods, mock_auth):
+        # 1. Setup mocks
+        mock_auth.return_value = {"id": 1, "username": "testuser"}
+        
+        # Mock database return for two foods
+        mock_get_foods.return_value = [
+            {
+                "id": 1, "name": "Food A", 
+                "calories": 100.0, "protein_g": 10.0, "fat_g": 5.0, "carbs_g": 2.0,
+                "fiber_g": 1.0, "sugar_g": 0.5, "cholesterol_mg": 0.0,
+                "vitamin_a_iu": 0.0, "vitamin_c_mg": 0.0,
+                "calcium_mg": 0.0, "iron_mg": 0.0, "potassium_mg": 0.0, "sodium_mg": 0.0
+            },
+            {
+                "id": 2, "name": "Food B", 
+                "calories": 200.0, "protein_g": 20.0, "fat_g": 10.0, "carbs_g": 4.0,
+                "fiber_g": 2.0, "sugar_g": 1.0, "cholesterol_mg": 5.0,
+                "vitamin_a_iu": 100.0, "vitamin_c_mg": 10.0,
+                "calcium_mg": 50.0, "iron_mg": 2.0, "potassium_mg": 300.0, "sodium_mg": 100.0
+            }
+        ]
+
+        # 2. Test Case: 50g of Food A and 150g of Food B
+        # Expected Food A (50g): Cal=50, Pro=5
+        # Expected Food B (150g): Cal=300, Pro=30
+        # Totals: Cal=350, Pro=35, Weight=200
+        
+        payload = {
+            "items": [
+                {"food_id": 1, "amount_g": 50},
+                {"food_id": 2, "amount_g": 150}
+            ]
+        }
+        
+        response = self.client.post("/api/meal/calculate", json=payload, headers=self.headers)
+        
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        
+        self.assertEqual(data["total_weight_g"], 200.0)
+        self.assertEqual(data["macros"]["calories"], 350.0)
+        self.assertEqual(data["macros"]["protein_g"], 35.0)
+        self.assertEqual(data["macros"]["fat_g"], 17.5) # (5 * 0.5) + (10 * 1.5) = 2.5 + 15 = 17.5
+        self.assertEqual(data["macros"]["carbs_g"], 7.0) # (2 * 0.5) + (4 * 1.5) = 1.0 + 6.0 = 7.0
+
+    @patch("main.get_user_from_token")
+    @patch("main.get_foods_by_ids")
+    def test_precision_rounding(self, mock_get_foods, mock_auth):
+        mock_auth.return_value = {"id": 1, "username": "testuser"}
+        
+        # Mock food with irregular values to trigger rounding
+        mock_get_foods.return_value = [
+            {
+                "id": 1, "name": "Irregular Food", 
+                "calories": 123.456, "protein_g": 10.123, "fat_g": 5.555, "carbs_g": 0.0,
+                "fiber_g": 0.0, "sugar_g": 0.0, "cholesterol_mg": 0.0,
+                "vitamin_a_iu": 0.0, "vitamin_c_mg": 0.0,
+                "calcium_mg": 0.0, "iron_mg": 0.0, "potassium_mg": 0.0, "sodium_mg": 0.0
+            }
+        ]
+
+        # 123g of Food A -> Ratio 1.23
+        # Calories: 123.456 * 1.23 = 151.85088 -> Rounded to 151.85
+        # Protein: 10.123 * 1.23 = 12.45129 -> Rounded to 12.45
+        # Fat: 5.555 * 1.23 = 6.83265 -> Rounded to 6.83
+        
+        payload = {"items": [{"food_id": 1, "amount_g": 123}]}
+        response = self.client.post("/api/meal/calculate", json=payload, headers=self.headers)
+        
+        data = response.json()
+        self.assertEqual(data["macros"]["calories"], 151.85)
+        self.assertEqual(data["macros"]["protein_g"], 12.45)
+        self.assertEqual(data["macros"]["fat_g"], 6.83)
+
+    @patch("main.get_user_from_token")
+    @patch("main.get_foods_by_ids")
+    def test_null_value_handling(self, mock_get_foods, mock_auth):
+        mock_auth.return_value = {"id": 1, "username": "testuser"}
+        
+        # Mock food with NULLs
+        mock_get_foods.return_value = [
+            {
+                "id": 1, "name": "Null Food", 
+                "calories": None, "protein_g": 10.0, "fat_g": None, "carbs_g": 2.0,
+                "fiber_g": None, "sugar_g": None, "cholesterol_mg": None,
+                "vitamin_a_iu": None, "vitamin_c_mg": None,
+                "calcium_mg": None, "iron_mg": None, "potassium_mg": None, "sodium_mg": None
+            }
+        ]
+        
+        payload = {"items": [{"food_id": 1, "amount_g": 100}]}
+        response = self.client.post("/api/meal/calculate", json=payload, headers=self.headers)
+        
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(data["macros"]["calories"], 0.0)
+        self.assertEqual(data["macros"]["fat_g"], 0.0)
+        self.assertEqual(data["extended"]["fiber_g"], 0.0)
+
+if __name__ == "__main__":
+    unittest.main()