|
@@ -12,14 +12,16 @@ DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
|
|
|
|
|
|
|
|
def get_db_connection():
|
|
def get_db_connection():
|
|
|
# Enable higher timeout and disable thread checks for FastAPI async compatibility
|
|
# Enable higher timeout and disable thread checks for FastAPI async compatibility
|
|
|
- conn = sqlite3.connect(DB_PATH, timeout=20.0, check_same_thread=False)
|
|
|
|
|
|
|
+ conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
|
|
|
conn.row_factory = sqlite3.Row
|
|
conn.row_factory = sqlite3.Row
|
|
|
# Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
|
|
# Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
|
|
|
- conn.execute('pragma journal_mode=wal')
|
|
|
|
|
|
|
+ conn.execute('PRAGMA journal_mode=WAL')
|
|
|
|
|
+ conn.execute('PRAGMA synchronous=NORMAL')
|
|
|
return conn
|
|
return conn
|
|
|
|
|
|
|
|
def create_tables():
|
|
def create_tables():
|
|
|
"""Initialize the SQLite database with required tables"""
|
|
"""Initialize the SQLite database with required tables"""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
@@ -67,32 +69,49 @@ def create_tables():
|
|
|
source TEXT DEFAULT 'System'
|
|
source TEXT DEFAULT 'System'
|
|
|
)
|
|
)
|
|
|
''')
|
|
''')
|
|
|
|
|
+
|
|
|
|
|
+ # Create chat history table for Sprint 6 persistence
|
|
|
|
|
+ cursor.execute('''
|
|
|
|
|
+ CREATE TABLE IF NOT EXISTS chat_messages (
|
|
|
|
|
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
|
|
+ user_id INTEGER NOT NULL,
|
|
|
|
|
+ role TEXT NOT NULL,
|
|
|
|
|
+ content TEXT NOT NULL,
|
|
|
|
|
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
|
|
+ FOREIGN KEY (user_id) REFERENCES users (id)
|
|
|
|
|
+ )
|
|
|
|
|
+ ''')
|
|
|
|
|
|
|
|
# Create index for rapid fuzzy search compatibility
|
|
# Create index for rapid fuzzy search compatibility
|
|
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
|
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
|
|
|
|
|
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
- conn.close()
|
|
|
|
|
logger.info("Database and tables initialized successfully.")
|
|
logger.info("Database and tables initialized successfully.")
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Error initializing database: {e}")
|
|
logger.error(f"Error initializing database: {e}")
|
|
|
raise
|
|
raise
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
|
|
|
def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
|
|
def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
|
|
|
"""Retrieve user dictionary if they exist"""
|
|
"""Retrieve user dictionary if they exist"""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
|
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
|
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
|
|
row = cursor.fetchone()
|
|
row = cursor.fetchone()
|
|
|
- conn.close()
|
|
|
|
|
return dict(row) if row else None
|
|
return dict(row) if row else None
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Database error fetching user: {e}")
|
|
logger.error(f"Database error fetching user: {e}")
|
|
|
return None
|
|
return None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
|
|
|
def create_user(username: str, password_hash: str) -> Optional[int]:
|
|
def create_user(username: str, password_hash: str) -> Optional[int]:
|
|
|
"""Creates a user securely. Returns user_id if successful, None if username exists."""
|
|
"""Creates a user securely. Returns user_id if successful, None if username exists."""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
@@ -102,19 +121,21 @@ def create_user(username: str, password_hash: str) -> Optional[int]:
|
|
|
)
|
|
)
|
|
|
user_id = cursor.lastrowid
|
|
user_id = cursor.lastrowid
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
- conn.close()
|
|
|
|
|
return user_id
|
|
return user_id
|
|
|
except sqlite3.IntegrityError:
|
|
except sqlite3.IntegrityError:
|
|
|
return None
|
|
return None
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Database error during user creation: {e}")
|
|
logger.error(f"Database error during user creation: {e}")
|
|
|
raise
|
|
raise
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
|
|
|
def create_session(user_id: int) -> str:
|
|
def create_session(user_id: int) -> str:
|
|
|
- """Create a secure 32-character session token in the DB valid for 24h"""
|
|
|
|
|
|
|
+ """Create a secure 32-character session token in the DB valid for 7 days"""
|
|
|
token = secrets.token_urlsafe(32)
|
|
token = secrets.token_urlsafe(32)
|
|
|
- expires_at = datetime.now() + timedelta(hours=24)
|
|
|
|
|
|
|
+ expires_at = datetime.now() + timedelta(days=7)
|
|
|
|
|
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
@@ -123,14 +144,16 @@ def create_session(user_id: int) -> str:
|
|
|
(token, user_id, expires_at)
|
|
(token, user_id, expires_at)
|
|
|
)
|
|
)
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
- conn.close()
|
|
|
|
|
return token
|
|
return token
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Error creating session: {e}")
|
|
logger.error(f"Error creating session: {e}")
|
|
|
raise
|
|
raise
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
|
|
|
def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
|
|
def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
|
|
|
"""Verify a session token and return the associated user data if valid and not expired"""
|
|
"""Verify a session token and return the associated user data if valid and not expired"""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
@@ -141,37 +164,93 @@ def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
|
|
|
WHERE sessions.token = ? AND sessions.expires_at > ?
|
|
WHERE sessions.token = ? AND sessions.expires_at > ?
|
|
|
''', (token, datetime.now()))
|
|
''', (token, datetime.now()))
|
|
|
row = cursor.fetchone()
|
|
row = cursor.fetchone()
|
|
|
- conn.close()
|
|
|
|
|
return dict(row) if row else None
|
|
return dict(row) if row else None
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Database error verifying token: {e}")
|
|
logger.error(f"Database error verifying token: {e}")
|
|
|
return None
|
|
return None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
|
|
|
def delete_session(token: str):
|
|
def delete_session(token: str):
|
|
|
"""Securely remove a session token when the user logs out"""
|
|
"""Securely remove a session token when the user logs out"""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
|
cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
|
|
cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
- conn.close()
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Error deleting session: {e}")
|
|
logger.error(f"Error deleting session: {e}")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
|
|
|
def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
|
|
def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
|
|
|
- """Securely search for foods matching a string query using fuzzy matching"""
|
|
|
|
|
|
|
+ """Securely search for foods matching a string query with relevance-based ordering"""
|
|
|
|
|
+ conn = None
|
|
|
try:
|
|
try:
|
|
|
conn = get_db_connection()
|
|
conn = get_db_connection()
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
|
|
|
|
|
|
# SQL Injection safe query utilizing LIKE parameterization
|
|
# SQL Injection safe query utilizing LIKE parameterization
|
|
|
- # COLLATE NOCASE search inherently supported by index on table creation
|
|
|
|
|
|
|
+ # We prioritize:
|
|
|
|
|
+ # 1. Items NOT in 'Baby Foods'
|
|
|
|
|
+ # 2. Shorter names (usually more fundamental ingredients)
|
|
|
|
|
+ # 3. Alphabetical order as a tie-breaker
|
|
|
q = f"%{query}%"
|
|
q = f"%{query}%"
|
|
|
- cursor.execute("SELECT * FROM foods WHERE name LIKE ? LIMIT ?", (q, limit))
|
|
|
|
|
|
|
+ prefix_match = f"{query}%"
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute('''
|
|
|
|
|
+ SELECT * FROM foods
|
|
|
|
|
+ WHERE name LIKE ?
|
|
|
|
|
+ ORDER BY
|
|
|
|
|
+ CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
|
|
|
|
|
+ CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
|
|
|
|
|
+ LENGTH(name) ASC,
|
|
|
|
|
+ name ASC
|
|
|
|
|
+ LIMIT ?
|
|
|
|
|
+ ''', (q, prefix_match, limit))
|
|
|
|
|
|
|
|
rows = cursor.fetchall()
|
|
rows = cursor.fetchall()
|
|
|
- conn.close()
|
|
|
|
|
return [dict(row) for row in rows]
|
|
return [dict(row) for row in rows]
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Error searching foods: {e}")
|
|
logger.error(f"Error searching foods: {e}")
|
|
|
return []
|
|
return []
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
+
|
|
|
|
|
+def save_chat_message(user_id: int, role: str, content: str):
|
|
|
|
|
+ """Persist a chat message to the database"""
|
|
|
|
|
+ conn = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
|
|
|
|
|
+ (user_id, role, content)
|
|
|
|
|
+ )
|
|
|
|
|
+ conn.commit()
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Error saving chat message: {e}")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|
|
|
|
|
+
|
|
|
|
|
+def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
|
|
|
|
|
+ """Retrieve the most recent chat messages for a user"""
|
|
|
|
|
+ conn = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ conn = get_db_connection()
|
|
|
|
|
+ cursor = conn.cursor()
|
|
|
|
|
+ # Order by created_at DESC to get recent ones, then reverse for display
|
|
|
|
|
+ cursor.execute('''
|
|
|
|
|
+ SELECT role, content FROM chat_messages
|
|
|
|
|
+ WHERE user_id = ?
|
|
|
|
|
+ ORDER BY created_at ASC
|
|
|
|
|
+ LIMIT ?
|
|
|
|
|
+ ''', (user_id, limit))
|
|
|
|
|
+ rows = cursor.fetchall()
|
|
|
|
|
+ return [dict(row) for row in rows]
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"Error fetching chat history: {e}")
|
|
|
|
|
+ return []
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if conn: conn.close()
|