# database.py
# SQLite persistence layer: schema setup, users, sessions, food search,
# chat history and macro-target profiles.
  1. import sqlite3
  2. import os
  3. import logging
  4. import secrets
  5. from datetime import datetime, timedelta
  6. from typing import Optional, Dict, Any, List
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. # Enable higher timeout and disable thread checks for FastAPI async compatibility
  12. conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
  13. conn.row_factory = sqlite3.Row
  14. # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
  15. conn.execute('PRAGMA journal_mode=WAL')
  16. conn.execute('PRAGMA synchronous=NORMAL')
  17. return conn
  18. def create_tables():
  19. """Initialize the SQLite database with required tables"""
  20. conn = None
  21. try:
  22. conn = get_db_connection()
  23. cursor = conn.cursor()
  24. # Create users table securely locally
  25. cursor.execute('''
  26. CREATE TABLE IF NOT EXISTS users (
  27. id INTEGER PRIMARY KEY AUTOINCREMENT,
  28. username TEXT UNIQUE NOT NULL,
  29. password_hash TEXT NOT NULL,
  30. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  31. )
  32. ''')
  33. # Create sessions table for database-backed tokens
  34. cursor.execute('''
  35. CREATE TABLE IF NOT EXISTS sessions (
  36. token TEXT PRIMARY KEY,
  37. user_id INTEGER NOT NULL,
  38. expires_at TIMESTAMP NOT NULL,
  39. FOREIGN KEY (user_id) REFERENCES users (id)
  40. )
  41. ''')
  42. # Create localized foods table based on Sprint 5 architecture
  43. cursor.execute('''
  44. CREATE TABLE IF NOT EXISTS foods (
  45. id INTEGER PRIMARY KEY AUTOINCREMENT,
  46. name TEXT NOT NULL,
  47. category TEXT,
  48. base_weight_g REAL DEFAULT 100.0,
  49. calories REAL DEFAULT 0.0,
  50. protein_g REAL DEFAULT 0.0,
  51. fat_g REAL DEFAULT 0.0,
  52. carbs_g REAL DEFAULT 0.0,
  53. fiber_g REAL DEFAULT 0.0,
  54. sugar_g REAL DEFAULT 0.0,
  55. sodium_mg REAL DEFAULT 0.0,
  56. vitamin_a_iu REAL DEFAULT 0.0,
  57. vitamin_c_mg REAL DEFAULT 0.0,
  58. calcium_mg REAL DEFAULT 0.0,
  59. iron_mg REAL DEFAULT 0.0,
  60. potassium_mg REAL DEFAULT 0.0,
  61. cholesterol_mg REAL DEFAULT 0.0,
  62. source TEXT DEFAULT 'System'
  63. )
  64. ''')
  65. # Create chat history table for Sprint 6 persistence
  66. cursor.execute('''
  67. CREATE TABLE IF NOT EXISTS chat_messages (
  68. id INTEGER PRIMARY KEY AUTOINCREMENT,
  69. user_id INTEGER NOT NULL,
  70. role TEXT NOT NULL,
  71. content TEXT NOT NULL,
  72. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  73. FOREIGN KEY (user_id) REFERENCES users (id)
  74. )
  75. ''')
  76. # Create minimal user_profiles table for macro targets (US-07)
  77. cursor.execute('''
  78. CREATE TABLE IF NOT EXISTS user_profiles (
  79. user_id INTEGER PRIMARY KEY,
  80. target_calories INTEGER DEFAULT 2000,
  81. target_protein_g INTEGER DEFAULT 150,
  82. target_carbs_g INTEGER DEFAULT 200,
  83. target_fat_g INTEGER DEFAULT 65,
  84. FOREIGN KEY (user_id) REFERENCES users (id)
  85. )
  86. ''')
  87. # Create index for rapid fuzzy search compatibility
  88. cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
  89. conn.commit()
  90. logger.info("Database and tables initialized successfully.")
  91. except Exception as e:
  92. logger.error(f"Error initializing database: {e}")
  93. raise
  94. finally:
  95. if conn:
  96. conn.close()
  97. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  98. """Retrieve user dictionary if they exist"""
  99. conn = None
  100. try:
  101. conn = get_db_connection()
  102. cursor = conn.cursor()
  103. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  104. row = cursor.fetchone()
  105. return dict(row) if row else None
  106. except Exception as e:
  107. logger.error(f"Database error fetching user: {e}")
  108. return None
  109. finally:
  110. if conn: conn.close()
  111. def create_user(username: str, password_hash: str) -> Optional[int]:
  112. """Creates a user securely. Returns user_id if successful, None if username exists."""
  113. conn = None
  114. try:
  115. conn = get_db_connection()
  116. cursor = conn.cursor()
  117. cursor.execute(
  118. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  119. (username, password_hash)
  120. )
  121. user_id = cursor.lastrowid
  122. conn.commit()
  123. return user_id
  124. except sqlite3.IntegrityError:
  125. return None
  126. except Exception as e:
  127. logger.error(f"Database error during user creation: {e}")
  128. raise
  129. finally:
  130. if conn: conn.close()
  131. def create_session(user_id: int) -> str:
  132. """Create a secure 32-character session token in the DB valid for 7 days"""
  133. token = secrets.token_urlsafe(32)
  134. expires_at = datetime.now() + timedelta(days=7)
  135. conn = None
  136. try:
  137. conn = get_db_connection()
  138. cursor = conn.cursor()
  139. cursor.execute(
  140. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  141. (token, user_id, expires_at)
  142. )
  143. conn.commit()
  144. return token
  145. except Exception as e:
  146. logger.error(f"Error creating session: {e}")
  147. raise
  148. finally:
  149. if conn: conn.close()
  150. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  151. """Verify a session token and return the associated user data if valid and not expired"""
  152. conn = None
  153. try:
  154. conn = get_db_connection()
  155. cursor = conn.cursor()
  156. # Find user if token exists and hasn't expired
  157. cursor.execute('''
  158. SELECT users.* FROM users
  159. JOIN sessions ON users.id = sessions.user_id
  160. WHERE sessions.token = ? AND sessions.expires_at > ?
  161. ''', (token, datetime.now()))
  162. row = cursor.fetchone()
  163. return dict(row) if row else None
  164. except Exception as e:
  165. logger.error(f"Database error verifying token: {e}")
  166. return None
  167. finally:
  168. if conn: conn.close()
  169. def delete_session(token: str):
  170. """Securely remove a session token when the user logs out"""
  171. conn = None
  172. try:
  173. conn = get_db_connection()
  174. cursor = conn.cursor()
  175. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  176. conn.commit()
  177. except Exception as e:
  178. logger.error(f"Error deleting session: {e}")
  179. finally:
  180. if conn: conn.close()
  181. def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
  182. """Securely search for foods matching a string query with relevance-based ordering"""
  183. conn = None
  184. try:
  185. conn = get_db_connection()
  186. cursor = conn.cursor()
  187. # SQL Injection safe query utilizing LIKE parameterization
  188. # We prioritize:
  189. # 1. Items NOT in 'Baby Foods'
  190. # 2. Shorter names (usually more fundamental ingredients)
  191. # 3. Alphabetical order as a tie-breaker
  192. q = f"%{query}%"
  193. prefix_match = f"{query}%"
  194. cursor.execute('''
  195. SELECT * FROM foods
  196. WHERE name LIKE ?
  197. ORDER BY
  198. CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
  199. CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
  200. LENGTH(name) ASC,
  201. name ASC
  202. LIMIT ?
  203. ''', (q, prefix_match, limit))
  204. rows = cursor.fetchall()
  205. return [dict(row) for row in rows]
  206. except Exception as e:
  207. logger.error(f"Error searching foods: {e}")
  208. return []
  209. finally:
  210. if conn: conn.close()
  211. def save_chat_message(user_id: int, role: str, content: str):
  212. """Persist a chat message to the database"""
  213. conn = None
  214. try:
  215. conn = get_db_connection()
  216. cursor = conn.cursor()
  217. cursor.execute(
  218. "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
  219. (user_id, role, content)
  220. )
  221. conn.commit()
  222. except Exception as e:
  223. logger.error(f"Error saving chat message: {e}")
  224. finally:
  225. if conn: conn.close()
  226. def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
  227. """Retrieve the most recent chat messages for a user"""
  228. conn = None
  229. try:
  230. conn = get_db_connection()
  231. cursor = conn.cursor()
  232. # Order by created_at DESC to get recent ones, then reverse for display
  233. cursor.execute('''
  234. SELECT role, content FROM chat_messages
  235. WHERE user_id = ?
  236. ORDER BY created_at ASC
  237. LIMIT ?
  238. ''', (user_id, limit))
  239. rows = cursor.fetchall()
  240. return [dict(row) for row in rows]
  241. except Exception as e:
  242. logger.error(f"Error fetching chat history: {e}")
  243. return []
  244. finally:
  245. if conn: conn.close()
  246. def get_user_profile(user_id: int) -> Optional[Dict[str, Any]]:
  247. """Fetch the user's profile containing macro targets. Inserts defaults if none exists."""
  248. conn = None
  249. try:
  250. conn = get_db_connection()
  251. cursor = conn.cursor()
  252. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  253. row = cursor.fetchone()
  254. if not row:
  255. # Create a default profile row if one does not exist
  256. cursor.execute('''
  257. INSERT INTO user_profiles (user_id) VALUES (?)
  258. ''', (user_id,))
  259. conn.commit()
  260. cursor.execute("SELECT * FROM user_profiles WHERE user_id = ?", (user_id,))
  261. row = cursor.fetchone()
  262. return dict(row) if row else None
  263. except Exception as e:
  264. logger.error(f"Error fetching user profile: {e}")
  265. return None
  266. finally:
  267. if conn: conn.close()
  268. def get_food_by_id(food_id: int) -> Optional[Dict[str, Any]]:
  269. """Retrieve a single food item by its unique ID"""
  270. conn = None
  271. try:
  272. conn = get_db_connection()
  273. cursor = conn.cursor()
  274. cursor.execute("SELECT * FROM foods WHERE id = ?", (food_id,))
  275. row = cursor.fetchone()
  276. return dict(row) if row else None
  277. except Exception as e:
  278. logger.error(f"Error fetching food by ID {food_id}: {e}")
  279. return None
  280. finally:
  281. if conn: conn.close()
  282. def get_foods_by_ids(food_ids: List[int]) -> List[Dict[str, Any]]:
  283. """Retrieve multiple food items by their unique IDs in bulk"""
  284. if not food_ids:
  285. return []
  286. conn = None
  287. try:
  288. conn = get_db_connection()
  289. cursor = conn.cursor()
  290. # Create placeholders for the IN clause
  291. placeholders = ', '.join(['?'] * len(food_ids))
  292. query = f"SELECT * FROM foods WHERE id IN ({placeholders})"
  293. cursor.execute(query, food_ids)
  294. rows = cursor.fetchall()
  295. return [dict(row) for row in rows]
  296. except Exception as e:
  297. logger.error(f"Error fetching foods by IDs {food_ids}: {e}")
  298. return []
  299. finally:
  300. if conn: conn.close()