database.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import sqlite3
  2. import os
  3. import logging
  4. import secrets
  5. from datetime import datetime, timedelta
  6. from typing import Optional, Dict, Any
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. # Enable higher timeout and disable thread checks for FastAPI async compatibility
  12. conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
  13. conn.row_factory = sqlite3.Row
  14. # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
  15. conn.execute('PRAGMA journal_mode=WAL')
  16. conn.execute('PRAGMA synchronous=NORMAL')
  17. return conn
  18. def create_tables():
  19. """Initialize the SQLite database with required tables"""
  20. conn = None
  21. try:
  22. conn = get_db_connection()
  23. cursor = conn.cursor()
  24. # Create users table securely locally
  25. cursor.execute('''
  26. CREATE TABLE IF NOT EXISTS users (
  27. id INTEGER PRIMARY KEY AUTOINCREMENT,
  28. username TEXT UNIQUE NOT NULL,
  29. password_hash TEXT NOT NULL,
  30. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  31. )
  32. ''')
  33. # Create sessions table for database-backed tokens
  34. cursor.execute('''
  35. CREATE TABLE IF NOT EXISTS sessions (
  36. token TEXT PRIMARY KEY,
  37. user_id INTEGER NOT NULL,
  38. expires_at TIMESTAMP NOT NULL,
  39. FOREIGN KEY (user_id) REFERENCES users (id)
  40. )
  41. ''')
  42. # Create localized foods table based on Sprint 5 architecture
  43. cursor.execute('''
  44. CREATE TABLE IF NOT EXISTS foods (
  45. id INTEGER PRIMARY KEY AUTOINCREMENT,
  46. name TEXT NOT NULL,
  47. category TEXT,
  48. base_weight_g REAL DEFAULT 100.0,
  49. calories REAL DEFAULT 0.0,
  50. protein_g REAL DEFAULT 0.0,
  51. fat_g REAL DEFAULT 0.0,
  52. carbs_g REAL DEFAULT 0.0,
  53. fiber_g REAL DEFAULT 0.0,
  54. sugar_g REAL DEFAULT 0.0,
  55. sodium_mg REAL DEFAULT 0.0,
  56. vitamin_a_iu REAL DEFAULT 0.0,
  57. vitamin_c_mg REAL DEFAULT 0.0,
  58. calcium_mg REAL DEFAULT 0.0,
  59. iron_mg REAL DEFAULT 0.0,
  60. potassium_mg REAL DEFAULT 0.0,
  61. cholesterol_mg REAL DEFAULT 0.0,
  62. source TEXT DEFAULT 'System'
  63. )
  64. ''')
  65. # Create chat history table for Sprint 6 persistence
  66. cursor.execute('''
  67. CREATE TABLE IF NOT EXISTS chat_messages (
  68. id INTEGER PRIMARY KEY AUTOINCREMENT,
  69. user_id INTEGER NOT NULL,
  70. role TEXT NOT NULL,
  71. content TEXT NOT NULL,
  72. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  73. FOREIGN KEY (user_id) REFERENCES users (id)
  74. )
  75. ''')
  76. # Create index for rapid fuzzy search compatibility
  77. cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
  78. conn.commit()
  79. logger.info("Database and tables initialized successfully.")
  80. except Exception as e:
  81. logger.error(f"Error initializing database: {e}")
  82. raise
  83. finally:
  84. if conn:
  85. conn.close()
  86. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  87. """Retrieve user dictionary if they exist"""
  88. conn = None
  89. try:
  90. conn = get_db_connection()
  91. cursor = conn.cursor()
  92. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  93. row = cursor.fetchone()
  94. return dict(row) if row else None
  95. except Exception as e:
  96. logger.error(f"Database error fetching user: {e}")
  97. return None
  98. finally:
  99. if conn: conn.close()
  100. def create_user(username: str, password_hash: str) -> Optional[int]:
  101. """Creates a user securely. Returns user_id if successful, None if username exists."""
  102. conn = None
  103. try:
  104. conn = get_db_connection()
  105. cursor = conn.cursor()
  106. cursor.execute(
  107. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  108. (username, password_hash)
  109. )
  110. user_id = cursor.lastrowid
  111. conn.commit()
  112. return user_id
  113. except sqlite3.IntegrityError:
  114. return None
  115. except Exception as e:
  116. logger.error(f"Database error during user creation: {e}")
  117. raise
  118. finally:
  119. if conn: conn.close()
  120. def create_session(user_id: int) -> str:
  121. """Create a secure 32-character session token in the DB valid for 7 days"""
  122. token = secrets.token_urlsafe(32)
  123. expires_at = datetime.now() + timedelta(days=7)
  124. conn = None
  125. try:
  126. conn = get_db_connection()
  127. cursor = conn.cursor()
  128. cursor.execute(
  129. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  130. (token, user_id, expires_at)
  131. )
  132. conn.commit()
  133. return token
  134. except Exception as e:
  135. logger.error(f"Error creating session: {e}")
  136. raise
  137. finally:
  138. if conn: conn.close()
  139. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  140. """Verify a session token and return the associated user data if valid and not expired"""
  141. conn = None
  142. try:
  143. conn = get_db_connection()
  144. cursor = conn.cursor()
  145. # Find user if token exists and hasn't expired
  146. cursor.execute('''
  147. SELECT users.* FROM users
  148. JOIN sessions ON users.id = sessions.user_id
  149. WHERE sessions.token = ? AND sessions.expires_at > ?
  150. ''', (token, datetime.now()))
  151. row = cursor.fetchone()
  152. return dict(row) if row else None
  153. except Exception as e:
  154. logger.error(f"Database error verifying token: {e}")
  155. return None
  156. finally:
  157. if conn: conn.close()
  158. def delete_session(token: str):
  159. """Securely remove a session token when the user logs out"""
  160. conn = None
  161. try:
  162. conn = get_db_connection()
  163. cursor = conn.cursor()
  164. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  165. conn.commit()
  166. except Exception as e:
  167. logger.error(f"Error deleting session: {e}")
  168. finally:
  169. if conn: conn.close()
  170. def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
  171. """Securely search for foods matching a string query with relevance-based ordering"""
  172. conn = None
  173. try:
  174. conn = get_db_connection()
  175. cursor = conn.cursor()
  176. # SQL Injection safe query utilizing LIKE parameterization
  177. # We prioritize:
  178. # 1. Items NOT in 'Baby Foods'
  179. # 2. Shorter names (usually more fundamental ingredients)
  180. # 3. Alphabetical order as a tie-breaker
  181. q = f"%{query}%"
  182. prefix_match = f"{query}%"
  183. cursor.execute('''
  184. SELECT * FROM foods
  185. WHERE name LIKE ?
  186. ORDER BY
  187. CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
  188. CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
  189. LENGTH(name) ASC,
  190. name ASC
  191. LIMIT ?
  192. ''', (q, prefix_match, limit))
  193. rows = cursor.fetchall()
  194. return [dict(row) for row in rows]
  195. except Exception as e:
  196. logger.error(f"Error searching foods: {e}")
  197. return []
  198. finally:
  199. if conn: conn.close()
  200. def save_chat_message(user_id: int, role: str, content: str):
  201. """Persist a chat message to the database"""
  202. conn = None
  203. try:
  204. conn = get_db_connection()
  205. cursor = conn.cursor()
  206. cursor.execute(
  207. "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
  208. (user_id, role, content)
  209. )
  210. conn.commit()
  211. except Exception as e:
  212. logger.error(f"Error saving chat message: {e}")
  213. finally:
  214. if conn: conn.close()
  215. def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
  216. """Retrieve the most recent chat messages for a user"""
  217. conn = None
  218. try:
  219. conn = get_db_connection()
  220. cursor = conn.cursor()
  221. # Order by created_at DESC to get recent ones, then reverse for display
  222. cursor.execute('''
  223. SELECT role, content FROM chat_messages
  224. WHERE user_id = ?
  225. ORDER BY created_at ASC
  226. LIMIT ?
  227. ''', (user_id, limit))
  228. rows = cursor.fetchall()
  229. return [dict(row) for row in rows]
  230. except Exception as e:
  231. logger.error(f"Error fetching chat history: {e}")
  232. return []
  233. finally:
  234. if conn: conn.close()