database.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
import logging
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. # Enable higher timeout and disable thread checks for FastAPI async compatibility
  12. conn = sqlite3.connect(DB_PATH, timeout=20.0, check_same_thread=False)
  13. conn.row_factory = sqlite3.Row
  14. # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
  15. conn.execute('pragma journal_mode=wal')
  16. return conn
  17. def create_tables():
  18. """Initialize the SQLite database with required tables"""
  19. try:
  20. conn = get_db_connection()
  21. cursor = conn.cursor()
  22. # Create users table securely locally
  23. cursor.execute('''
  24. CREATE TABLE IF NOT EXISTS users (
  25. id INTEGER PRIMARY KEY AUTOINCREMENT,
  26. username TEXT UNIQUE NOT NULL,
  27. password_hash TEXT NOT NULL,
  28. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  29. )
  30. ''')
  31. # Create sessions table for database-backed tokens
  32. cursor.execute('''
  33. CREATE TABLE IF NOT EXISTS sessions (
  34. token TEXT PRIMARY KEY,
  35. user_id INTEGER NOT NULL,
  36. expires_at TIMESTAMP NOT NULL,
  37. FOREIGN KEY (user_id) REFERENCES users (id)
  38. )
  39. ''')
  40. conn.commit()
  41. conn.close()
  42. logger.info("Database and tables initialized successfully.")
  43. except Exception as e:
  44. logger.error(f"Error initializing database: {e}")
  45. raise
  46. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  47. """Retrieve user dictionary if they exist"""
  48. try:
  49. conn = get_db_connection()
  50. cursor = conn.cursor()
  51. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  52. row = cursor.fetchone()
  53. conn.close()
  54. return dict(row) if row else None
  55. except Exception as e:
  56. logger.error(f"Database error fetching user: {e}")
  57. return None
  58. def create_user(username: str, password_hash: str) -> Optional[int]:
  59. """Creates a user securely. Returns user_id if successful, None if username exists."""
  60. try:
  61. conn = get_db_connection()
  62. cursor = conn.cursor()
  63. cursor.execute(
  64. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  65. (username, password_hash)
  66. )
  67. user_id = cursor.lastrowid
  68. conn.commit()
  69. conn.close()
  70. return user_id
  71. except sqlite3.IntegrityError:
  72. return None
  73. except Exception as e:
  74. logger.error(f"Database error during user creation: {e}")
  75. raise
  76. def create_session(user_id: int) -> str:
  77. """Create a secure 32-character session token in the DB valid for 24h"""
  78. token = secrets.token_urlsafe(32)
  79. expires_at = datetime.now() + timedelta(hours=24)
  80. try:
  81. conn = get_db_connection()
  82. cursor = conn.cursor()
  83. cursor.execute(
  84. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  85. (token, user_id, expires_at)
  86. )
  87. conn.commit()
  88. conn.close()
  89. return token
  90. except Exception as e:
  91. logger.error(f"Error creating session: {e}")
  92. raise
  93. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  94. """Verify a session token and return the associated user data if valid and not expired"""
  95. try:
  96. conn = get_db_connection()
  97. cursor = conn.cursor()
  98. # Find user if token exists and hasn't expired
  99. cursor.execute('''
  100. SELECT users.* FROM users
  101. JOIN sessions ON users.id = sessions.user_id
  102. WHERE sessions.token = ? AND sessions.expires_at > ?
  103. ''', (token, datetime.now()))
  104. row = cursor.fetchone()
  105. conn.close()
  106. return dict(row) if row else None
  107. except Exception as e:
  108. logger.error(f"Database error verifying token: {e}")
  109. return None
  110. def delete_session(token: str):
  111. """Securely remove a session token when the user logs out"""
  112. try:
  113. conn = get_db_connection()
  114. cursor = conn.cursor()
  115. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  116. conn.commit()
  117. conn.close()
  118. except Exception as e:
  119. logger.error(f"Error deleting session: {e}")