database.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
import logging
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
  7. logger = logging.getLogger(__name__)
  8. # Locate db correctly in the same directory
  9. DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
  10. def get_db_connection():
  11. conn = sqlite3.connect(DB_PATH)
  12. conn.row_factory = sqlite3.Row
  13. return conn
  14. def create_tables():
  15. """Initialize the SQLite database with required tables"""
  16. try:
  17. conn = get_db_connection()
  18. cursor = conn.cursor()
  19. # Create users table securely locally
  20. cursor.execute('''
  21. CREATE TABLE IF NOT EXISTS users (
  22. id INTEGER PRIMARY KEY AUTOINCREMENT,
  23. username TEXT UNIQUE NOT NULL,
  24. password_hash TEXT NOT NULL,
  25. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  26. )
  27. ''')
  28. # Create sessions table for database-backed tokens
  29. cursor.execute('''
  30. CREATE TABLE IF NOT EXISTS sessions (
  31. token TEXT PRIMARY KEY,
  32. user_id INTEGER NOT NULL,
  33. expires_at TIMESTAMP NOT NULL,
  34. FOREIGN KEY (user_id) REFERENCES users (id)
  35. )
  36. ''')
  37. conn.commit()
  38. conn.close()
  39. logger.info("Database and tables initialized successfully.")
  40. except Exception as e:
  41. logger.error(f"Error initializing database: {e}")
  42. raise
  43. def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
  44. """Retrieve user dictionary if they exist"""
  45. try:
  46. conn = get_db_connection()
  47. cursor = conn.cursor()
  48. cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
  49. row = cursor.fetchone()
  50. conn.close()
  51. return dict(row) if row else None
  52. except Exception as e:
  53. logger.error(f"Database error fetching user: {e}")
  54. return None
  55. def create_user(username: str, password_hash: str) -> Optional[int]:
  56. """Creates a user securely. Returns user_id if successful, None if username exists."""
  57. try:
  58. conn = get_db_connection()
  59. cursor = conn.cursor()
  60. cursor.execute(
  61. "INSERT INTO users (username, password_hash) VALUES (?, ?)",
  62. (username, password_hash)
  63. )
  64. user_id = cursor.lastrowid
  65. conn.commit()
  66. conn.close()
  67. return user_id
  68. except sqlite3.IntegrityError:
  69. return None
  70. except Exception as e:
  71. logger.error(f"Database error during user creation: {e}")
  72. raise
  73. def create_session(user_id: int) -> str:
  74. """Create a secure 32-character session token in the DB valid for 24h"""
  75. token = secrets.token_urlsafe(32)
  76. expires_at = datetime.now() + timedelta(hours=24)
  77. try:
  78. conn = get_db_connection()
  79. cursor = conn.cursor()
  80. cursor.execute(
  81. "INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)",
  82. (token, user_id, expires_at)
  83. )
  84. conn.commit()
  85. conn.close()
  86. return token
  87. except Exception as e:
  88. logger.error(f"Error creating session: {e}")
  89. raise
  90. def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
  91. """Verify a session token and return the associated user data if valid and not expired"""
  92. try:
  93. conn = get_db_connection()
  94. cursor = conn.cursor()
  95. # Find user if token exists and hasn't expired
  96. cursor.execute('''
  97. SELECT users.* FROM users
  98. JOIN sessions ON users.id = sessions.user_id
  99. WHERE sessions.token = ? AND sessions.expires_at > ?
  100. ''', (token, datetime.now()))
  101. row = cursor.fetchone()
  102. conn.close()
  103. return dict(row) if row else None
  104. except Exception as e:
  105. logger.error(f"Database error verifying token: {e}")
  106. return None
  107. def delete_session(token: str):
  108. """Securely remove a session token when the user logs out"""
  109. try:
  110. conn = get_db_connection()
  111. cursor = conn.cursor()
  112. cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
  113. conn.commit()
  114. conn.close()
  115. except Exception as e:
  116. logger.error(f"Error deleting session: {e}")