| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- import sqlite3
- import os
- import logging
- import secrets
- from datetime import datetime, timedelta
- from typing import Optional, Dict, Any, List
logger = logging.getLogger(__name__)

# Resolve the database file relative to this module instead of the process's
# current working directory, so it is found no matter where the app starts.
_MODULE_DIR = os.path.dirname(__file__)
DB_PATH = os.path.join(_MODULE_DIR, "localfood.db")
def get_db_connection(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a configured SQLite connection to the application database.

    Args:
        db_path: Database file to open. Defaults to the module-level
            ``DB_PATH``; an explicit path (e.g. ``":memory:"``) is mainly
            useful for tests.

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` (accessible by
        column name). WAL journaling, ``synchronous=NORMAL`` and foreign-key
        enforcement are enabled. The caller must close it.

    Raises:
        sqlite3.Error: if the database cannot be opened or configured.
    """
    path = DB_PATH if db_path is None else db_path
    # A 30 s busy timeout makes writers wait for a lock instead of failing at
    # once. check_same_thread=False lets the connection be used outside the
    # creating thread (needed for FastAPI compatibility).
    conn = sqlite3.connect(path, timeout=30.0, check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        # WAL lets readers proceed while a writer is active.
        conn.execute('PRAGMA journal_mode=WAL')
        conn.execute('PRAGMA synchronous=NORMAL')
        # SQLite ignores FOREIGN KEY clauses, including ON DELETE CASCADE,
        # unless enforcement is switched on for each connection.
        conn.execute('PRAGMA foreign_keys=ON')
    except Exception:
        # Don't leak the handle if configuration fails.
        conn.close()
        raise
    return conn
def create_tables():
    """Create every application table and index if it does not already exist.

    Idempotent: each statement uses ``IF NOT EXISTS``, so existing tables and
    data are left untouched. Column changes to an already-existing table are
    therefore NOT applied here.

    Tables: users, sessions, foods, chat_messages, user_profiles,
    saved_meals, meal_items.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    ddl_statements = (
        # Accounts: unique username plus a password_hash column.
        '''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            password_hash TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        ''',
        # Database-backed session tokens with an expiry timestamp.
        '''
        CREATE TABLE IF NOT EXISTS sessions (
            token TEXT PRIMARY KEY,
            user_id INTEGER NOT NULL,
            expires_at TIMESTAMP NOT NULL,
            FOREIGN KEY (user_id) REFERENCES users (id)
        )
        ''',
        # Food catalogue with nutrient columns (Sprint 5 architecture).
        '''
        CREATE TABLE IF NOT EXISTS foods (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            category TEXT,
            base_weight_g REAL DEFAULT 100.0,
            calories REAL DEFAULT 0.0,
            protein_g REAL DEFAULT 0.0,
            fat_g REAL DEFAULT 0.0,
            carbs_g REAL DEFAULT 0.0,
            fiber_g REAL DEFAULT 0.0,
            sugar_g REAL DEFAULT 0.0,
            sodium_mg REAL DEFAULT 0.0,
            vitamin_a_iu REAL DEFAULT 0.0,
            vitamin_c_mg REAL DEFAULT 0.0,
            calcium_mg REAL DEFAULT 0.0,
            iron_mg REAL DEFAULT 0.0,
            potassium_mg REAL DEFAULT 0.0,
            cholesterol_mg REAL DEFAULT 0.0,
            source TEXT DEFAULT 'System'
        )
        ''',
        # Persisted chat history (Sprint 6).
        '''
        CREATE TABLE IF NOT EXISTS chat_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users (id)
        )
        ''',
        # Per-user macro targets (US-07); one row per user.
        '''
        CREATE TABLE IF NOT EXISTS user_profiles (
            user_id INTEGER PRIMARY KEY,
            target_calories INTEGER DEFAULT 2000,
            target_protein_g INTEGER DEFAULT 150,
            target_carbs_g INTEGER DEFAULT 200,
            target_fat_g INTEGER DEFAULT 65,
            FOREIGN KEY (user_id) REFERENCES users (id)
        )
        ''',
        # User-named meals (Sprint 8).
        '''
        CREATE TABLE IF NOT EXISTS saved_meals (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            name TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users (id)
        )
        ''',
        # Links several foods (with amounts) to one saved meal.
        '''
        CREATE TABLE IF NOT EXISTS meal_items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            meal_id INTEGER NOT NULL,
            food_id INTEGER NOT NULL,
            amount_g REAL NOT NULL,
            FOREIGN KEY (meal_id) REFERENCES saved_meals (id) ON DELETE CASCADE,
            FOREIGN KEY (food_id) REFERENCES foods (id)
        )
        ''',
        # Case-insensitive index on food names. NOTE: a leading-wildcard LIKE
        # ('%x%') cannot use it; only prefix-style patterns might.
        'CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)',
        'CREATE INDEX IF NOT EXISTS idx_saved_meals_user ON saved_meals(user_id)',
        # meal_items rows are looked up by meal_id when listing/deleting meals.
        'CREATE INDEX IF NOT EXISTS idx_meal_items_meal ON meal_items(meal_id)',
        # Chat history is read per user in timestamp order.
        'CREATE INDEX IF NOT EXISTS idx_chat_messages_user ON chat_messages(user_id, created_at)',
    )

    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        for statement in ddl_statements:
            cursor.execute(statement)
        conn.commit()
        logger.info("Database and tables initialized successfully.")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"Error initializing database: {e}")
        raise
    finally:
        if conn:
            conn.close()
def save_user_meal(user_id: int, name: str, items: List[Dict[str, Any]]) -> Optional[int]:
    """Store ``items`` as a new named meal owned by ``user_id``.

    Each entry of ``items`` must supply ``'food_id'`` and ``'amount_g'``. The
    header row and all item rows are written in one transaction.

    Returns:
        The new ``saved_meals.id``, or ``None`` if anything failed. On failure
        the error is logged and the partial insert is rolled back.
    """
    connection = None
    try:
        connection = get_db_connection()
        cur = connection.cursor()

        # Header first: its row id is the foreign key for every item row.
        cur.execute(
            "INSERT INTO saved_meals (user_id, name) VALUES (?, ?)",
            (user_id, name)
        )
        new_meal_id = cur.lastrowid

        cur.executemany(
            "INSERT INTO meal_items (meal_id, food_id, amount_g) VALUES (?, ?, ?)",
            ((new_meal_id, entry['food_id'], entry['amount_g']) for entry in items),
        )

        connection.commit()
        return new_meal_id
    except Exception as e:
        logger.error(f"Error saving user meal: {e}")
        if connection is not None:
            connection.rollback()
        return None
    finally:
        if connection is not None:
            connection.close()
def _scaled_total(items: List[Dict[str, Any]], field: str) -> float:
    """Sum ``field`` over ``items``, each value scaled by ``amount_g / 100``.

    NULL (``None``) nutrient values or amounts count as 0, so one incomplete
    food row cannot make the whole meal listing fail.
    """
    # NOTE(review): divides by a fixed 100 g although foods has a
    # base_weight_g column (default 100.0) -- confirm nutrient values are
    # always stored per 100 g.
    return sum(
        (item[field] or 0.0) * (item['amount_g'] or 0.0) / 100.0
        for item in items
    )


def get_user_meals(user_id: int) -> List[Dict[str, Any]]:
    """Return all of ``user_id``'s saved meals, newest first, with items and macro totals.

    Each meal dict holds the ``saved_meals`` columns plus:

    * ``items``: one dict per meal item, in insertion order. It contains the
      item's ``amount_g`` and every column of the referenced ``foods`` row,
      so ``id`` is the food id.
    * ``total_calories`` / ``total_protein`` / ``total_carbs`` /
      ``total_fat``: sums of ``nutrient * amount_g / 100`` over the items.

    All items for the user are fetched with one joined query rather than one
    query per meal. The inner join omits items whose food row is missing.

    Returns:
        The list of meals. Returns ``[]`` if the user has none, or on a
        database error (which is logged).
    """
    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()

        # Meal headers, newest first. id breaks ties between meals saved in the
        # same second (CURRENT_TIMESTAMP has one-second resolution).
        cursor.execute(
            "SELECT * FROM saved_meals WHERE user_id = ? "
            "ORDER BY created_at DESC, id DESC",
            (user_id,)
        )
        meals = [dict(row) for row in cursor.fetchall()]
        if not meals:
            return []

        # A single query for every item of every meal this user owns.
        items_by_meal: Dict[int, List[Dict[str, Any]]] = {meal['id']: [] for meal in meals}
        cursor.execute('''
            SELECT mi.meal_id AS item_meal_id, mi.amount_g, f.*
            FROM meal_items mi
            JOIN saved_meals sm ON mi.meal_id = sm.id
            JOIN foods f ON mi.food_id = f.id
            WHERE sm.user_id = ?
            ORDER BY mi.id
        ''', (user_id,))
        for row in cursor.fetchall():
            item = dict(row)
            bucket = items_by_meal.get(item.pop('item_meal_id'))
            # A meal inserted after the header query has no bucket. Skip its
            # items so the result matches the headers actually returned.
            if bucket is not None:
                bucket.append(item)

        # Attach items and the totals used by the meal card summary.
        for meal in meals:
            items = items_by_meal[meal['id']]
            meal['items'] = items
            meal['total_calories'] = _scaled_total(items, 'calories')
            meal['total_protein'] = _scaled_total(items, 'protein_g')
            meal['total_carbs'] = _scaled_total(items, 'carbs_g')
            meal['total_fat'] = _scaled_total(items, 'fat_g')

        return meals
    except Exception as e:
        logger.error(f"Error fetching user meals: {e}")
        return []
    finally:
        if conn:
            conn.close()
def delete_user_meal(user_id: int, meal_id: int) -> bool:
    """Delete one of ``user_id``'s saved meals together with its items.

    Returns:
        True when the meal belonged to the user and was removed. False when
        it does not exist or belongs to someone else, or on a database error
        (logged, with the transaction rolled back).
    """
    connection = None
    try:
        connection = get_db_connection()
        cur = connection.cursor()

        # Ownership gate: never touch a meal that isn't this user's.
        cur.execute("SELECT id FROM saved_meals WHERE id = ? AND user_id = ?", (meal_id, user_id))
        owned = cur.fetchone() is not None
        if not owned:
            return False

        # Children first, then the header, in one transaction. Items are
        # removed explicitly rather than relying on ON DELETE CASCADE, which
        # SQLite only honours when foreign-key enforcement is on.
        for statement in (
            "DELETE FROM meal_items WHERE meal_id = ?",
            "DELETE FROM saved_meals WHERE id = ?",
        ):
            cur.execute(statement, (meal_id,))

        connection.commit()
        return True
    except Exception as e:
        logger.error(f"Error deleting meal: {e}")
        if connection is not None:
            connection.rollback()
        return False
    finally:
        if connection is not None:
            connection.close()
def update_user_meal(user_id: int, meal_id: int, new_name: str) -> bool:
    """Rename a saved meal, but only if ``user_id`` owns it.

    Returns:
        True if a row was renamed. False if no meal matched the id/owner pair,
        or if a database error occurred (logged and rolled back).
    """
    connection = None
    try:
        connection = get_db_connection()
        cur = connection.execute(
            "UPDATE saved_meals SET name = ? WHERE id = ? AND user_id = ?",
            (new_name, meal_id, user_id)
        )
        connection.commit()
        # The owner check lives in the WHERE clause, so zero affected rows
        # means "not found or not yours".
        return cur.rowcount > 0
    except Exception as e:
        logger.error(f"Error updating meal: {e}")
        if connection is not None:
            connection.rollback()
        return False
    finally:
        if connection is not None:
            connection.close()
def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
    """Look up a user row by exact username.

    Returns:
        All ``users`` columns as a dict. Returns ``None`` if there is no such
        user or the query failed (the error is logged).
    """
    connection = None
    try:
        connection = get_db_connection()
        match = connection.execute(
            "SELECT * FROM users WHERE username = ?", (username,)
        ).fetchone()
        if match is None:
            return None
        return dict(match)
    except Exception as e:
        logger.error(f"Database error fetching user: {e}")
        return None
    finally:
        if connection is not None:
            connection.close()
def create_user(username: str, password_hash: str) -> Optional[int]:
    """Insert a new user with an already-hashed password.

    Returns:
        The new user's id. Returns ``None`` when the insert violates a
        constraint, e.g. the username is already taken.

    Raises:
        Exception: any other database error is logged and re-raised.
    """
    connection = None
    try:
        connection = get_db_connection()
        cur = connection.execute(
            "INSERT INTO users (username, password_hash) VALUES (?, ?)",
            (username, password_hash)
        )
        new_id = cur.lastrowid
        connection.commit()
    except sqlite3.IntegrityError:
        # UNIQUE(username) collision (or another constraint violation).
        return None
    except Exception as e:
        logger.error(f"Database error during user creation: {e}")
        raise
    else:
        return new_id
    finally:
        if connection is not None:
            connection.close()
def create_session(user_id: int) -> str:
    """Create and store a new session token for ``user_id``, valid for 7 days.

    The token is ``secrets.token_urlsafe(32)``: 32 random bytes encoded as a
    43-character URL-safe string.

    Returns:
        The new token.

    Raises:
        Exception: database errors are logged and re-raised.
    """
    token = secrets.token_urlsafe(32)
    # Store the same text the sqlite3 default datetime adapter would produce
    # (datetime.isoformat(" ")). It is passed explicitly because that implicit
    # adapter is deprecated since Python 3.12.
    # NOTE(review): naive local time; the expiry check must use the same basis.
    expires_at = (datetime.now() + timedelta(days=7)).isoformat(" ")

    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
            (token, user_id, expires_at)
        )
        conn.commit()
        return token
    except Exception as e:
        logger.error(f"Error creating session: {e}")
        raise
    finally:
        if conn:
            conn.close()
def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
    """Resolve a session token to its user, if the session has not expired.

    Returns:
        The user's ``users`` row as a dict. Returns ``None`` when the token is
        unknown or expired, or the lookup failed (the error is logged).
    """
    # Same text layout the sqlite3 default datetime adapter produces
    # (datetime.isoformat(" ")), passed explicitly because that implicit
    # adapter is deprecated since Python 3.12. These ISO-8601 strings compare
    # chronologically as text.
    now = datetime.now().isoformat(" ")

    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        # Match the token and require its expiry to still be in the future.
        cursor.execute('''
            SELECT users.* FROM users
            JOIN sessions ON users.id = sessions.user_id
            WHERE sessions.token = ? AND sessions.expires_at > ?
        ''', (token, now))
        row = cursor.fetchone()
        return dict(row) if row else None
    except Exception as e:
        logger.error(f"Database error verifying token: {e}")
        return None
    finally:
        if conn:
            conn.close()
def delete_session(token: str):
    """Remove a session token, e.g. on logout.

    Best-effort: database errors are logged and swallowed, and an unknown
    token is simply a no-op.
    """
    connection = None
    try:
        connection = get_db_connection()
        connection.execute("DELETE FROM sessions WHERE token = ?", (token,))
        connection.commit()
    except Exception as e:
        logger.error(f"Error deleting session: {e}")
    finally:
        if connection is not None:
            connection.close()
def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
    """Search foods whose name contains ``query`` (LIKE: case-insensitive for ASCII).

    ``query`` is matched literally. Any LIKE wildcards it contains (``%``,
    ``_``) and the escape character itself are escaped, so a search for
    ``"100%"`` no longer matches every name beginning with "100".

    Results are ordered by:
      1. items outside the 'Baby Foods' category first,
      2. names starting with ``query`` before names merely containing it,
      3. shorter names first (usually more fundamental ingredients),
      4. alphabetical order as a final tie-breaker.

    Args:
        query: Text to look for inside food names.
        limit: Maximum number of rows to return.

    Returns:
        Matching ``foods`` rows as dicts, or ``[]`` on a database error
        (logged).
    """
    conn = None
    try:
        # Escape backslash first so the escapes added for % and _ survive.
        text = str(query)
        escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        contains_pattern = f"%{escaped}%"
        prefix_pattern = f"{escaped}%"

        conn = get_db_connection()
        cursor = conn.cursor()

        # Values are bound as parameters (no SQL injection). ESCAPE '\' makes
        # the escaped wildcards in the patterns match literally.
        cursor.execute('''
            SELECT * FROM foods
            WHERE name LIKE ? ESCAPE '\\'
            ORDER BY
                CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
                CASE WHEN name LIKE ? ESCAPE '\\' THEN 0 ELSE 1 END,
                LENGTH(name) ASC,
                name ASC
            LIMIT ?
        ''', (contains_pattern, prefix_pattern, limit))

        rows = cursor.fetchall()
        return [dict(row) for row in rows]
    except Exception as e:
        logger.error(f"Error searching foods: {e}")
        return []
    finally:
        if conn:
            conn.close()
def save_chat_message(user_id: int, role: str, content: str):
    """Append one chat message (``role`` + ``content``) to ``user_id``'s history.

    Best-effort: any failure is logged and swallowed, so callers are never
    interrupted by history persistence.
    """
    connection = None
    try:
        connection = get_db_connection()
        connection.execute(
            "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
            (user_id, role, content)
        )
        connection.commit()
    except Exception as e:
        logger.error(f"Error saving chat message: {e}")
    finally:
        if connection is not None:
            connection.close()
def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
    """Return the user's most recent ``limit`` chat messages, oldest first.

    The inner query selects the newest ``limit`` messages, and the outer query
    puts them back in chronological order for display. (Previously an
    ascending ``LIMIT`` returned the OLDEST messages instead.) ``id`` breaks
    ties between messages stored in the same second, because
    ``CURRENT_TIMESTAMP`` has one-second resolution.

    Returns:
        Dicts with ``role`` and ``content``, or ``[]`` on a database error
        (logged).
    """
    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute('''
            SELECT role, content FROM (
                SELECT id, role, content, created_at FROM chat_messages
                WHERE user_id = ?
                ORDER BY created_at DESC, id DESC
                LIMIT ?
            )
            ORDER BY created_at ASC, id ASC
        ''', (user_id, limit))
        rows = cursor.fetchall()
        return [dict(row) for row in rows]
    except Exception as e:
        logger.error(f"Error fetching chat history: {e}")
        return []
    finally:
        if conn:
            conn.close()
def get_user_profile(user_id: int) -> Optional[Dict[str, Any]]:
    """Fetch the user's macro-target profile, creating a default row if missing.

    Defaults come from the ``user_profiles`` column defaults.

    Returns:
        The profile row as a dict, or ``None`` on a database error (logged).
    """
    select_sql = "SELECT * FROM user_profiles WHERE user_id = ?"
    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute(select_sql, (user_id,))
        row = cursor.fetchone()

        if row is None:
            # INSERT OR IGNORE: if a concurrent request created the row between
            # the SELECT above and this INSERT, the primary-key conflict is
            # ignored instead of failing this call with an IntegrityError.
            cursor.execute(
                "INSERT OR IGNORE INTO user_profiles (user_id) VALUES (?)",
                (user_id,)
            )
            conn.commit()
            cursor.execute(select_sql, (user_id,))
            row = cursor.fetchone()

        return dict(row) if row else None
    except Exception as e:
        logger.error(f"Error fetching user profile: {e}")
        return None
    finally:
        if conn:
            conn.close()
def get_food_by_id(food_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one ``foods`` row by primary key.

    Returns:
        The row as a dict. Returns ``None`` if no food has that id or the
        lookup failed (the error is logged).
    """
    connection = None
    try:
        connection = get_db_connection()
        found = connection.execute(
            "SELECT * FROM foods WHERE id = ?", (food_id,)
        ).fetchone()
        return None if found is None else dict(found)
    except Exception as e:
        logger.error(f"Error fetching food by ID {food_id}: {e}")
        return None
    finally:
        if connection is not None:
            connection.close()
def get_foods_by_ids(food_ids: List[int]) -> List[Dict[str, Any]]:
    """Fetch the ``foods`` rows for the given ids in bulk.

    Ids are de-duplicated and queried in batches, so very long lists stay
    under SQLite's per-statement limit on bound parameters (999 in SQLite
    builds older than 3.32.0). Without batching they fail with "too many SQL
    variables".

    Args:
        food_ids: Ids to look up. Unknown ids are simply absent from the result.

    Returns:
        One dict per matching food, each at most once, in no guaranteed order.
        Returns ``[]`` for empty input or on a database error (logged).
    """
    if not food_ids:
        return []

    batch_size = 500
    conn = None
    try:
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated id cannot yield the same row twice across batches.
        unique_ids = list(dict.fromkeys(food_ids))

        conn = get_db_connection()
        cursor = conn.cursor()
        results: List[Dict[str, Any]] = []
        for start in range(0, len(unique_ids), batch_size):
            batch = unique_ids[start:start + batch_size]
            # Only '?' placeholders are interpolated, never caller data, so
            # the f-string cannot introduce SQL injection.
            placeholders = ', '.join(['?'] * len(batch))
            cursor.execute(f"SELECT * FROM foods WHERE id IN ({placeholders})", batch)
            results.extend(dict(row) for row in cursor.fetchall())
        return results
    except Exception as e:
        logger.error(f"Error fetching foods by IDs {food_ids}: {e}")
        return []
    finally:
        if conn:
            conn.close()
|