database.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. import sqlite3
  2. import os
  3. import logging
  4. import secrets
  5. from datetime import datetime, timedelta
  6. from typing import Optional, Dict, Any, List
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. # Enable higher timeout and disable thread checks for FastAPI async compatibility
  12. conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
  13. conn.row_factory = sqlite3.Row
  14. # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
  15. conn.execute('PRAGMA journal_mode=WAL')
  16. conn.execute('PRAGMA synchronous=NORMAL')
  17. return conn
  18. def create_tables():
  19. """Initialize the SQLite database with required tables"""
  20. conn = None
  21. try:
  22. conn = get_db_connection()
  23. cursor = conn.cursor()
  24. # Create users table securely locally
  25. cursor.execute('''
  26. CREATE TABLE IF NOT EXISTS users (
  27. id INTEGER PRIMARY KEY AUTOINCREMENT,
  28. username TEXT UNIQUE NOT NULL,
  29. password_hash TEXT NOT NULL,
  30. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  31. )
  32. ''')
  33. # Create sessions table for database-backed tokens
  34. cursor.execute('''
  35. CREATE TABLE IF NOT EXISTS sessions (
  36. token TEXT PRIMARY KEY,
  37. user_id INTEGER NOT NULL,
  38. expires_at TIMESTAMP NOT NULL,
  39. FOREIGN KEY (user_id) REFERENCES users (id)
  40. )
  41. ''')
  42. # Create localized foods table based on Sprint 5 architecture
  43. cursor.execute('''
  44. CREATE TABLE IF NOT EXISTS foods (
  45. id INTEGER PRIMARY KEY AUTOINCREMENT,
  46. name TEXT NOT NULL,
  47. category TEXT,
  48. base_weight_g REAL DEFAULT 100.0,
  49. calories REAL DEFAULT 0.0,
  50. protein_g REAL DEFAULT 0.0,
  51. fat_g REAL DEFAULT 0.0,
  52. carbs_g REAL DEFAULT 0.0,
  53. fiber_g REAL DEFAULT 0.0,
  54. sugar_g REAL DEFAULT 0.0,
  55. sodium_mg REAL DEFAULT 0.0,
  56. vitamin_a_iu REAL DEFAULT 0.0,
  57. vitamin_c_mg REAL DEFAULT 0.0,
  58. calcium_mg REAL DEFAULT 0.0,
  59. iron_mg REAL DEFAULT 0.0,
  60. potassium_mg REAL DEFAULT 0.0,
  61. cholesterol_mg REAL DEFAULT 0.0,
  62. source TEXT DEFAULT 'System'
  63. )
  64. ''')
  65. # Create chat history table for Sprint 6 persistence
  66. cursor.execute('''
  67. CREATE TABLE IF NOT EXISTS chat_messages (
  68. id INTEGER PRIMARY KEY AUTOINCREMENT,
  69. user_id INTEGER NOT NULL,
  70. role TEXT NOT NULL,
  71. content TEXT NOT NULL,
  72. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  73. FOREIGN KEY (user_id) REFERENCES users (id)
  74. )
  75. ''')
  76. # Create minimal user_profiles table for macro targets (US-07)
  77. cursor.execute('''
  78. CREATE TABLE IF NOT EXISTS user_profiles (
  79. user_id INTEGER PRIMARY KEY,
  80. target_calories INTEGER DEFAULT 2000,
  81. target_protein_g INTEGER DEFAULT 150,
  82. target_carbs_g INTEGER DEFAULT 200,
  83. target_fat_g INTEGER DEFAULT 65,
  84. FOREIGN KEY (user_id) REFERENCES users (id)
  85. )
  86. ''')
  87. # Create user-named meals table for Sprint 8
  88. cursor.execute('''
  89. CREATE TABLE IF NOT EXISTS saved_meals (
  90. id INTEGER PRIMARY KEY AUTOINCREMENT,
  91. user_id INTEGER NOT NULL,
  92. name TEXT NOT NULL,
  93. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  94. FOREIGN KEY (user_id) REFERENCES users (id)
  95. )
  96. ''')
  97. # Create meal items table to link multiple foods to a single saved meal
  98. cursor.execute('''
  99. CREATE TABLE IF NOT EXISTS meal_items (
  100. id INTEGER PRIMARY KEY AUTOINCREMENT,
  101. meal_id INTEGER NOT NULL,
  102. food_id INTEGER NOT NULL,
  103. amount_g REAL NOT NULL,
  104. FOREIGN KEY (meal_id) REFERENCES saved_meals (id) ON DELETE CASCADE,
  105. FOREIGN KEY (food_id) REFERENCES foods (id)
  106. )
  107. ''')
  108. # Create index for rapid fuzzy search compatibility
  109. cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
  110. cursor.execute('CREATE INDEX IF NOT EXISTS idx_saved_meals_user ON saved_meals(user_id)')
  111. conn.commit()
  112. logger.info("Database and tables initialized successfully.")
  113. except Exception as e:
  114. logger.error(f"Error initializing database: {e}")
  115. raise
  116. finally:
  117. if conn:
  118. conn.close()
  119. def save_user_meal(user_id: int, name: str, items: List[Dict[str, Any]]) -> Optional[int]:
  120. """Persist a collection of food items as a named meal list for a user"""
  121. conn = None
  122. try:
  123. conn = get_db_connection()
  124. cursor = conn.cursor()
  125. # 1. Create the meal header
  126. cursor.execute(
  127. "INSERT INTO saved_meals (user_id, name) VALUES (?, ?)",
  128. (user_id, name)
  129. )
  130. meal_id = cursor.lastrowid
  131. # 2. Add each item linked to this meal
  132. for item in items:
  133. cursor.execute(
  134. "INSERT INTO meal_items (meal_id, food_id, amount_g) VALUES (?, ?, ?)",
  135. (meal_id, item['food_id'], item['amount_g'])
  136. )
  137. conn.commit()
  138. return meal_id
  139. except Exception as e:
  140. logger.error(f"Error saving user meal: {e}")
  141. if conn: conn.rollback()
  142. return None
  143. finally:
  144. if conn: conn.close()
  145. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  146. """Retrieve user dictionary if they exist"""
  147. conn = None
  148. try:
  149. conn = get_db_connection()
  150. cursor = conn.cursor()
  151. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  152. row = cursor.fetchone()
  153. return dict(row) if row else None
  154. except Exception as e:
  155. logger.error(f"Database error fetching user: {e}")
  156. return None
  157. finally:
  158. if conn: conn.close()
  159. def create_user(username: str, password_hash: str) -> Optional[int]:
  160. """Creates a user securely. Returns user_id if successful, None if username exists."""
  161. conn = None
  162. try:
  163. conn = get_db_connection()
  164. cursor = conn.cursor()
  165. cursor.execute(
  166. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  167. (username, password_hash)
  168. )
  169. user_id = cursor.lastrowid
  170. conn.commit()
  171. return user_id
  172. except sqlite3.IntegrityError:
  173. return None
  174. except Exception as e:
  175. logger.error(f"Database error during user creation: {e}")
  176. raise
  177. finally:
  178. if conn: conn.close()
  179. def create_session(user_id: int) -> str:
  180. """Create a secure 32-character session token in the DB valid for 7 days"""
  181. token = secrets.token_urlsafe(32)
  182. expires_at = datetime.now() + timedelta(days=7)
  183. conn = None
  184. try:
  185. conn = get_db_connection()
  186. cursor = conn.cursor()
  187. cursor.execute(
  188. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  189. (token, user_id, expires_at)
  190. )
  191. conn.commit()
  192. return token
  193. except Exception as e:
  194. logger.error(f"Error creating session: {e}")
  195. raise
  196. finally:
  197. if conn: conn.close()
  198. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  199. """Verify a session token and return the associated user data if valid and not expired"""
  200. conn = None
  201. try:
  202. conn = get_db_connection()
  203. cursor = conn.cursor()
  204. # Find user if token exists and hasn't expired
  205. cursor.execute('''
  206. SELECT users.* FROM users
  207. JOIN sessions ON users.id = sessions.user_id
  208. WHERE sessions.token = ? AND sessions.expires_at > ?
  209. ''', (token, datetime.now()))
  210. row = cursor.fetchone()
  211. return dict(row) if row else None
  212. except Exception as e:
  213. logger.error(f"Database error verifying token: {e}")
  214. return None
  215. finally:
  216. if conn: conn.close()
  217. def delete_session(token: str):
  218. """Securely remove a session token when the user logs out"""
  219. conn = None
  220. try:
  221. conn = get_db_connection()
  222. cursor = conn.cursor()
  223. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  224. conn.commit()
  225. except Exception as e:
  226. logger.error(f"Error deleting session: {e}")
  227. finally:
  228. if conn: conn.close()
  229. def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
  230. """Securely search for foods matching a string query with relevance-based ordering"""
  231. conn = None
  232. try:
  233. conn = get_db_connection()
  234. cursor = conn.cursor()
  235. # SQL Injection safe query utilizing LIKE parameterization
  236. # We prioritize:
  237. # 1. Items NOT in 'Baby Foods'
  238. # 2. Shorter names (usually more fundamental ingredients)
  239. # 3. Alphabetical order as a tie-breaker
  240. q = f"%{query}%"
  241. prefix_match = f"{query}%"
  242. cursor.execute('''
  243. SELECT * FROM foods
  244. WHERE name LIKE ?
  245. ORDER BY
  246. CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
  247. CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
  248. LENGTH(name) ASC,
  249. name ASC
  250. LIMIT ?
  251. ''', (q, prefix_match, limit))
  252. rows = cursor.fetchall()
  253. return [dict(row) for row in rows]
  254. except Exception as e:
  255. logger.error(f"Error searching foods: {e}")
  256. return []
  257. finally:
  258. if conn: conn.close()
  259. def save_chat_message(user_id: int, role: str, content: str):
  260. """Persist a chat message to the database"""
  261. conn = None
  262. try:
  263. conn = get_db_connection()
  264. cursor = conn.cursor()
  265. cursor.execute(
  266. "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
  267. (user_id, role, content)
  268. )
  269. conn.commit()
  270. except Exception as e:
  271. logger.error(f"Error saving chat message: {e}")
  272. finally:
  273. if conn: conn.close()
  274. def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
  275. """Retrieve the most recent chat messages for a user"""
  276. conn = None
  277. try:
  278. conn = get_db_connection()
  279. cursor = conn.cursor()
  280. # Order by created_at DESC to get recent ones, then reverse for display
  281. cursor.execute('''
  282. SELECT role, content FROM chat_messages
  283. WHERE user_id = ?
  284. ORDER BY created_at ASC
  285. LIMIT ?
  286. ''', (user_id, limit))
  287. rows = cursor.fetchall()
  288. return [dict(row) for row in rows]
  289. except Exception as e:
  290. logger.error(f"Error fetching chat history: {e}")
  291. return []
  292. finally:
  293. if conn: conn.close()
  294. def get_user_profile(user_id: int) -> Optional[Dict[str, Any]]:
  295. """Fetch the user's profile containing macro targets. Inserts defaults if none exists."""
  296. conn = None
  297. try:
  298. conn = get_db_connection()
  299. cursor = conn.cursor()
  300. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  301. row = cursor.fetchone()
  302. if not row:
  303. # Create a default profile row if one does not exist
  304. cursor.execute('''
  305. INSERT INTO user_profiles (user_id) VALUES (?)
  306. ''', (user_id,))
  307. conn.commit()
  308. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  309. row = cursor.fetchone()
  310. return dict(row) if row else None
  311. except Exception as e:
  312. logger.error(f"Error fetching user profile: {e}")
  313. return None
  314. finally:
  315. if conn: conn.close()
  316. def get_food_by_id(food_id: int) -> Optional[Dict[str, Any]]:
  317. """Retrieve a single food item by its unique ID"""
  318. conn = None
  319. try:
  320. conn = get_db_connection()
  321. cursor = conn.cursor()
  322. cursor.execute("SELECT * FROM foods WHERE id = ?", (food_id,))
  323. row = cursor.fetchone()
  324. return dict(row) if row else None
  325. except Exception as e:
  326. logger.error(f"Error fetching food by ID {food_id}: {e}")
  327. return None
  328. finally:
  329. if conn: conn.close()
  330. def get_foods_by_ids(food_ids: List[int]) -> List[Dict[str, Any]]:
  331. """Retrieve multiple food items by their unique IDs in bulk"""
  332. if not food_ids:
  333. return []
  334. conn = None
  335. try:
  336. conn = get_db_connection()
  337. cursor = conn.cursor()
  338. # Create placeholders for the IN clause
  339. placeholders = ', '.join(['?'] * len(food_ids))
  340. query = f"SELECT * FROM foods WHERE id IN ({placeholders})"
  341. cursor.execute(query, food_ids)
  342. rows = cursor.fetchall()
  343. return [dict(row) for row in rows]
  344. except Exception as e:
  345. logger.error(f"Error fetching foods by IDs {food_ids}: {e}")
  346. return []
  347. finally:
  348. if conn: conn.close()