database.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import sqlite3
  2. import os
  3. import logging
  4. import secrets
  5. from datetime import datetime, timedelta
  6. from typing import Optional, Dict, Any, List
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. # Enable higher timeout and disable thread checks for FastAPI async compatibility
  12. conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
  13. conn.row_factory = sqlite3.Row
  14. # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
  15. conn.execute('PRAGMA journal_mode=WAL')
  16. conn.execute('PRAGMA synchronous=NORMAL')
  17. return conn
  18. def create_tables():
  19. """Initialize the SQLite database with required tables"""
  20. conn = None
  21. try:
  22. conn = get_db_connection()
  23. cursor = conn.cursor()
  24. # Create users table securely locally
  25. cursor.execute('''
  26. CREATE TABLE IF NOT EXISTS users (
  27. id INTEGER PRIMARY KEY AUTOINCREMENT,
  28. username TEXT UNIQUE NOT NULL,
  29. password_hash TEXT NOT NULL,
  30. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  31. )
  32. ''')
  33. # Create sessions table for database-backed tokens
  34. cursor.execute('''
  35. CREATE TABLE IF NOT EXISTS sessions (
  36. token TEXT PRIMARY KEY,
  37. user_id INTEGER NOT NULL,
  38. expires_at TIMESTAMP NOT NULL,
  39. FOREIGN KEY (user_id) REFERENCES users (id)
  40. )
  41. ''')
  42. # Create localized foods table based on Sprint 5 architecture
  43. cursor.execute('''
  44. CREATE TABLE IF NOT EXISTS foods (
  45. id INTEGER PRIMARY KEY AUTOINCREMENT,
  46. name TEXT NOT NULL,
  47. category TEXT,
  48. base_weight_g REAL DEFAULT 100.0,
  49. calories REAL DEFAULT 0.0,
  50. protein_g REAL DEFAULT 0.0,
  51. fat_g REAL DEFAULT 0.0,
  52. carbs_g REAL DEFAULT 0.0,
  53. fiber_g REAL DEFAULT 0.0,
  54. sugar_g REAL DEFAULT 0.0,
  55. sodium_mg REAL DEFAULT 0.0,
  56. vitamin_a_iu REAL DEFAULT 0.0,
  57. vitamin_c_mg REAL DEFAULT 0.0,
  58. calcium_mg REAL DEFAULT 0.0,
  59. iron_mg REAL DEFAULT 0.0,
  60. potassium_mg REAL DEFAULT 0.0,
  61. cholesterol_mg REAL DEFAULT 0.0,
  62. source TEXT DEFAULT 'System'
  63. )
  64. ''')
  65. # Create chat history table for Sprint 6 persistence
  66. cursor.execute('''
  67. CREATE TABLE IF NOT EXISTS chat_messages (
  68. id INTEGER PRIMARY KEY AUTOINCREMENT,
  69. user_id INTEGER NOT NULL,
  70. role TEXT NOT NULL,
  71. content TEXT NOT NULL,
  72. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  73. FOREIGN KEY (user_id) REFERENCES users (id)
  74. )
  75. ''')
  76. # Create minimal user_profiles table for macro targets (US-07)
  77. cursor.execute('''
  78. CREATE TABLE IF NOT EXISTS user_profiles (
  79. user_id INTEGER PRIMARY KEY,
  80. target_calories INTEGER DEFAULT 2000,
  81. target_protein_g INTEGER DEFAULT 150,
  82. target_carbs_g INTEGER DEFAULT 200,
  83. target_fat_g INTEGER DEFAULT 65,
  84. FOREIGN KEY (user_id) REFERENCES users (id)
  85. )
  86. ''')
  87. # Create user-named meals table for Sprint 8
  88. cursor.execute('''
  89. CREATE TABLE IF NOT EXISTS saved_meals (
  90. id INTEGER PRIMARY KEY AUTOINCREMENT,
  91. user_id INTEGER NOT NULL,
  92. name TEXT NOT NULL,
  93. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  94. FOREIGN KEY (user_id) REFERENCES users (id)
  95. )
  96. ''')
  97. # Create meal items table to link multiple foods to a single saved meal
  98. cursor.execute('''
  99. CREATE TABLE IF NOT EXISTS meal_items (
  100. id INTEGER PRIMARY KEY AUTOINCREMENT,
  101. meal_id INTEGER NOT NULL,
  102. food_id INTEGER NOT NULL,
  103. amount_g REAL NOT NULL,
  104. FOREIGN KEY (meal_id) REFERENCES saved_meals (id) ON DELETE CASCADE,
  105. FOREIGN KEY (food_id) REFERENCES foods (id)
  106. )
  107. ''')
  108. # Create index for rapid fuzzy search compatibility
  109. cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
  110. cursor.execute('CREATE INDEX IF NOT EXISTS idx_saved_meals_user ON saved_meals(user_id)')
  111. conn.commit()
  112. logger.info("Database and tables initialized successfully.")
  113. except Exception as e:
  114. logger.error(f"Error initializing database: {e}")
  115. raise
  116. finally:
  117. if conn:
  118. conn.close()
  119. def save_user_meal(user_id: int, name: str, items: List[Dict[str, Any]]) -> Optional[int]:
  120. """Persist a collection of food items as a named meal list for a user"""
  121. conn = None
  122. try:
  123. conn = get_db_connection()
  124. cursor = conn.cursor()
  125. # 1. Create the meal header
  126. cursor.execute(
  127. "INSERT INTO saved_meals (user_id, name) VALUES (?, ?)",
  128. (user_id, name)
  129. )
  130. meal_id = cursor.lastrowid
  131. # 2. Add each item linked to this meal
  132. for item in items:
  133. cursor.execute(
  134. "INSERT INTO meal_items (meal_id, food_id, amount_g) VALUES (?, ?, ?)",
  135. (meal_id, item['food_id'], item['amount_g'])
  136. )
  137. conn.commit()
  138. return meal_id
  139. except Exception as e:
  140. logger.error(f"Error saving user meal: {e}")
  141. if conn: conn.rollback()
  142. return None
  143. finally:
  144. if conn: conn.close()
  145. def get_user_meals(user_id: int) -> List[Dict[str, Any]]:
  146. """Retrieve all saved meals for a user, including total macro calculations"""
  147. conn = None
  148. try:
  149. conn = get_db_connection()
  150. cursor = conn.cursor()
  151. # Fetch meal headers
  152. cursor.execute(
  153. "SELECT * FROM saved_meals WHERE user_id = ? ORDER BY created_at DESC",
  154. (user_id,)
  155. )
  156. meals = [dict(row) for row in cursor.fetchall()]
  157. # For each meal, fetch items and calculate totals
  158. for meal in meals:
  159. cursor.execute('''
  160. SELECT mi.amount_g, f.*
  161. FROM meal_items mi
  162. JOIN foods f ON mi.food_id = f.id
  163. WHERE mi.meal_id = ?
  164. ''', (meal['id'],))
  165. items = [dict(row) for row in cursor.fetchall()]
  166. # Calculate totals for the meal card summary
  167. meal['items'] = items
  168. meal['total_calories'] = sum((item['calories'] * item['amount_g'] / 100.0) for item in items)
  169. meal['total_protein'] = sum((item['protein_g'] * item['amount_g'] / 100.0) for item in items)
  170. meal['total_carbs'] = sum((item['carbs_g'] * item['amount_g'] / 100.0) for item in items)
  171. meal['total_fat'] = sum((item['fat_g'] * item['amount_g'] / 100.0) for item in items)
  172. return meals
  173. except Exception as e:
  174. logger.error(f"Error fetching user meals: {e}")
  175. return []
  176. finally:
  177. if conn: conn.close()
  178. def delete_user_meal(user_id: int, meal_id: int) -> bool:
  179. """Securely delete a meal and its items, ensuring ownership"""
  180. conn = None
  181. try:
  182. conn = get_db_connection()
  183. cursor = conn.cursor()
  184. # Verify ownership first
  185. cursor.execute("SELECT id FROM saved_meals WHERE id = ? AND user_id = ?", (meal_id, user_id))
  186. if not cursor.fetchone():
  187. return False
  188. # Delete items first (even if cascading is enabled, we stay transactional)
  189. cursor.execute("DELETE FROM meal_items WHERE meal_id = ?", (meal_id,))
  190. # Delete header
  191. cursor.execute("DELETE FROM saved_meals WHERE id = ?", (meal_id,))
  192. conn.commit()
  193. return True
  194. except Exception as e:
  195. logger.error(f"Error deleting meal: {e}")
  196. if conn: conn.rollback()
  197. return False
  198. finally:
  199. if conn: conn.close()
  200. def update_user_meal(user_id: int, meal_id: int, new_name: str) -> bool:
  201. """Updates the name of a meal, ensuring ownership"""
  202. conn = None
  203. try:
  204. conn = get_db_connection()
  205. cursor = conn.cursor()
  206. cursor.execute(
  207. "UPDATE saved_meals SET name = ? WHERE id = ? AND user_id = ?",
  208. (new_name, meal_id, user_id)
  209. )
  210. conn.commit()
  211. # rowcount will be 1 if updated, 0 if ID/ownership failed
  212. return cursor.rowcount > 0
  213. except Exception as e:
  214. logger.error(f"Error updating meal: {e}")
  215. if conn: conn.rollback()
  216. return False
  217. finally:
  218. if conn: conn.close()
  219. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  220. """Retrieve user dictionary if they exist"""
  221. conn = None
  222. try:
  223. conn = get_db_connection()
  224. cursor = conn.cursor()
  225. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  226. row = cursor.fetchone()
  227. return dict(row) if row else None
  228. except Exception as e:
  229. logger.error(f"Database error fetching user: {e}")
  230. return None
  231. finally:
  232. if conn: conn.close()
  233. def create_user(username: str, password_hash: str) -> Optional[int]:
  234. """Creates a user securely. Returns user_id if successful, None if username exists."""
  235. conn = None
  236. try:
  237. conn = get_db_connection()
  238. cursor = conn.cursor()
  239. cursor.execute(
  240. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  241. (username, password_hash)
  242. )
  243. user_id = cursor.lastrowid
  244. conn.commit()
  245. return user_id
  246. except sqlite3.IntegrityError:
  247. return None
  248. except Exception as e:
  249. logger.error(f"Database error during user creation: {e}")
  250. raise
  251. finally:
  252. if conn: conn.close()
  253. def create_session(user_id: int) -> str:
  254. """Create a secure 32-character session token in the DB valid for 7 days"""
  255. token = secrets.token_urlsafe(32)
  256. expires_at = datetime.now() + timedelta(days=7)
  257. conn = None
  258. try:
  259. conn = get_db_connection()
  260. cursor = conn.cursor()
  261. cursor.execute(
  262. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  263. (token, user_id, expires_at)
  264. )
  265. conn.commit()
  266. return token
  267. except Exception as e:
  268. logger.error(f"Error creating session: {e}")
  269. raise
  270. finally:
  271. if conn: conn.close()
  272. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  273. """Verify a session token and return the associated user data if valid and not expired"""
  274. conn = None
  275. try:
  276. conn = get_db_connection()
  277. cursor = conn.cursor()
  278. # Find user if token exists and hasn't expired
  279. cursor.execute('''
  280. SELECT users.* FROM users
  281. JOIN sessions ON users.id = sessions.user_id
  282. WHERE sessions.token = ? AND sessions.expires_at > ?
  283. ''', (token, datetime.now()))
  284. row = cursor.fetchone()
  285. return dict(row) if row else None
  286. except Exception as e:
  287. logger.error(f"Database error verifying token: {e}")
  288. return None
  289. finally:
  290. if conn: conn.close()
  291. def delete_session(token: str):
  292. """Securely remove a session token when the user logs out"""
  293. conn = None
  294. try:
  295. conn = get_db_connection()
  296. cursor = conn.cursor()
  297. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  298. conn.commit()
  299. except Exception as e:
  300. logger.error(f"Error deleting session: {e}")
  301. finally:
  302. if conn: conn.close()
  303. def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
  304. """Securely search for foods matching a string query with relevance-based ordering"""
  305. conn = None
  306. try:
  307. conn = get_db_connection()
  308. cursor = conn.cursor()
  309. # SQL Injection safe query utilizing LIKE parameterization
  310. # We prioritize:
  311. # 1. Items NOT in 'Baby Foods'
  312. # 2. Shorter names (usually more fundamental ingredients)
  313. # 3. Alphabetical order as a tie-breaker
  314. q = f"%{query}%"
  315. prefix_match = f"{query}%"
  316. cursor.execute('''
  317. SELECT * FROM foods
  318. WHERE name LIKE ?
  319. ORDER BY
  320. CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
  321. CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
  322. LENGTH(name) ASC,
  323. name ASC
  324. LIMIT ?
  325. ''', (q, prefix_match, limit))
  326. rows = cursor.fetchall()
  327. return [dict(row) for row in rows]
  328. except Exception as e:
  329. logger.error(f"Error searching foods: {e}")
  330. return []
  331. finally:
  332. if conn: conn.close()
  333. def save_chat_message(user_id: int, role: str, content: str):
  334. """Persist a chat message to the database"""
  335. conn = None
  336. try:
  337. conn = get_db_connection()
  338. cursor = conn.cursor()
  339. cursor.execute(
  340. "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
  341. (user_id, role, content)
  342. )
  343. conn.commit()
  344. except Exception as e:
  345. logger.error(f"Error saving chat message: {e}")
  346. finally:
  347. if conn: conn.close()
  348. def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
  349. """Retrieve the most recent chat messages for a user"""
  350. conn = None
  351. try:
  352. conn = get_db_connection()
  353. cursor = conn.cursor()
  354. # Order by created_at DESC to get recent ones, then reverse for display
  355. cursor.execute('''
  356. SELECT role, content FROM chat_messages
  357. WHERE user_id = ?
  358. ORDER BY created_at ASC
  359. LIMIT ?
  360. ''', (user_id, limit))
  361. rows = cursor.fetchall()
  362. return [dict(row) for row in rows]
  363. except Exception as e:
  364. logger.error(f"Error fetching chat history: {e}")
  365. return []
  366. finally:
  367. if conn: conn.close()
  368. def get_user_profile(user_id: int) -> Optional[Dict[str, Any]]:
  369. """Fetch the user's profile containing macro targets. Inserts defaults if none exists."""
  370. conn = None
  371. try:
  372. conn = get_db_connection()
  373. cursor = conn.cursor()
  374. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  375. row = cursor.fetchone()
  376. if not row:
  377. # Create a default profile row if one does not exist
  378. cursor.execute('''
  379. INSERT INTO user_profiles (user_id) VALUES (?)
  380. ''', (user_id,))
  381. conn.commit()
  382. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  383. row = cursor.fetchone()
  384. return dict(row) if row else None
  385. except Exception as e:
  386. logger.error(f"Error fetching user profile: {e}")
  387. return None
  388. finally:
  389. if conn: conn.close()
  390. def get_food_by_id(food_id: int) -> Optional[Dict[str, Any]]:
  391. """Retrieve a single food item by its unique ID"""
  392. conn = None
  393. try:
  394. conn = get_db_connection()
  395. cursor = conn.cursor()
  396. cursor.execute("SELECT * FROM foods WHERE id = ?", (food_id,))
  397. row = cursor.fetchone()
  398. return dict(row) if row else None
  399. except Exception as e:
  400. logger.error(f"Error fetching food by ID {food_id}: {e}")
  401. return None
  402. finally:
  403. if conn: conn.close()
  404. def get_foods_by_ids(food_ids: List[int]) -> List[Dict[str, Any]]:
  405. """Retrieve multiple food items by their unique IDs in bulk"""
  406. if not food_ids:
  407. return []
  408. conn = None
  409. try:
  410. conn = get_db_connection()
  411. cursor = conn.cursor()
  412. # Create placeholders for the IN clause
  413. placeholders = ', '.join(['?'] * len(food_ids))
  414. query = f"SELECT * FROM foods WHERE id IN ({placeholders})"
  415. cursor.execute(query, food_ids)
  416. rows = cursor.fetchall()
  417. return [dict(row) for row in rows]
  418. except Exception as e:
  419. logger.error(f"Error fetching foods by IDs {food_ids}: {e}")
  420. return []
  421. finally:
  422. if conn: conn.close()