Browse Source

TG-28: Parse open dataset and design SQLite schema (USDA SR Legacy ingestion)

FerRo988 cách đây 3 tuần
mục cha
commit
0d46c04a30
4 tập tin đã thay đổi với 308 bổ sung và 42 xóa
  1. 93 14
      database.py
  2. 53 26
      main.py
  3. 112 0
      mega_seed_usda.py
  4. 50 2
      static/script.js

+ 93 - 14
database.py

@@ -12,14 +12,16 @@ DB_PATH = os.path.join(os.path.dirname(__file__), "localfood.db")
 
 def get_db_connection():
     # Enable higher timeout and disable thread checks for FastAPI async compatibility
-    conn = sqlite3.connect(DB_PATH, timeout=20.0, check_same_thread=False)
+    conn = sqlite3.connect(DB_PATH, timeout=30.0, check_same_thread=False)
     conn.row_factory = sqlite3.Row
     # Enable Write-Ahead Log (WAL) mode for simultaneous read/write operations
-    conn.execute('pragma journal_mode=wal')
+    conn.execute('PRAGMA journal_mode=WAL')
+    conn.execute('PRAGMA synchronous=NORMAL')
     return conn
 
 def create_tables():
     """Initialize the SQLite database with required tables"""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
@@ -67,32 +69,49 @@ def create_tables():
             source TEXT DEFAULT 'System'
         )
         ''')
+
+        # Create chat history table for Sprint 6 persistence
+        cursor.execute('''
+        CREATE TABLE IF NOT EXISTS chat_messages (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            user_id INTEGER NOT NULL,
+            role TEXT NOT NULL,
+            content TEXT NOT NULL,
+            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+            FOREIGN KEY (user_id) REFERENCES users (id)
+        )
+        ''')
         
         # Create index for rapid fuzzy search compatibility
         cursor.execute('CREATE INDEX IF NOT EXISTS idx_food_name ON foods(name COLLATE NOCASE)')
         
         conn.commit()
-        conn.close()
         logger.info("Database and tables initialized successfully.")
     except Exception as e:
         logger.error(f"Error initializing database: {e}")
         raise
+    finally:
+        if conn:
+            conn.close()
 
 def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
     """Retrieve user dictionary if they exist"""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
         cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
         row = cursor.fetchone()
-        conn.close()
         return dict(row) if row else None
     except Exception as e:
         logger.error(f"Database error fetching user: {e}")
         return None
+    finally:
+        if conn: conn.close()
 
 def create_user(username: str, password_hash: str) -> Optional[int]:
     """Creates a user securely. Returns user_id if successful, None if username exists."""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
@@ -102,19 +121,21 @@ def create_user(username: str, password_hash: str) -> Optional[int]:
         )
         user_id = cursor.lastrowid
         conn.commit()
-        conn.close()
         return user_id
     except sqlite3.IntegrityError:
         return None
     except Exception as e:
         logger.error(f"Database error during user creation: {e}")
         raise
+    finally:
+        if conn: conn.close()
 
 def create_session(user_id: int) -> str:
-    """Create a secure 32-character session token in the DB valid for 24h"""
+    """Create a secure 32-character session token in the DB valid for 7 days"""
     token = secrets.token_urlsafe(32)
-    expires_at = datetime.now() + timedelta(hours=24)
+    expires_at = datetime.now() + timedelta(days=7)
     
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
@@ -123,14 +144,16 @@ def create_session(user_id: int) -> str:
             (token, user_id, expires_at)
         )
         conn.commit()
-        conn.close()
         return token
     except Exception as e:
         logger.error(f"Error creating session: {e}")
         raise
+    finally:
+        if conn: conn.close()
 
 def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
     """Verify a session token and return the associated user data if valid and not expired"""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
@@ -141,37 +164,93 @@ def get_user_from_token(token: str) -> Optional[Dict[str, Any]]:
             WHERE sessions.token = ? AND sessions.expires_at > ?
         ''', (token, datetime.now()))
         row = cursor.fetchone()
-        conn.close()
         return dict(row) if row else None
     except Exception as e:
         logger.error(f"Database error verifying token: {e}")
         return None
+    finally:
+        if conn: conn.close()
 
 def delete_session(token: str):
     """Securely remove a session token when the user logs out"""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
         cursor.execute("DELETE FROM sessions WHERE token = ?", (token,))
         conn.commit()
-        conn.close()
     except Exception as e:
         logger.error(f"Error deleting session: {e}")
+    finally:
+        if conn: conn.close()
 
 def search_foods_by_name(query: str, limit: int = 15) -> list[Dict[str, Any]]:
-    """Securely search for foods matching a string query using fuzzy matching"""
+    """Securely search for foods matching a string query with relevance-based ordering"""
+    conn = None
     try:
         conn = get_db_connection()
         cursor = conn.cursor()
         
         # SQL Injection safe query utilizing LIKE parameterization
-        # COLLATE NOCASE search inherently supported by index on table creation
+        # We prioritize: 
+        # 1. Items NOT in 'Baby Foods' 
+        # 2. Shorter names (usually more fundamental ingredients)
+        # 3. Alphabetical order as a tie-breaker
         q = f"%{query}%"
-        cursor.execute("SELECT * FROM foods WHERE name LIKE ? LIMIT ?", (q, limit))
+        prefix_match = f"{query}%"
+        
+        cursor.execute('''
+            SELECT * FROM foods 
+            WHERE name LIKE ? 
+            ORDER BY 
+                CASE WHEN category = 'Baby Foods' THEN 1 ELSE 0 END,
+                CASE WHEN name LIKE ? THEN 0 ELSE 1 END,
+                LENGTH(name) ASC,
+                name ASC
+            LIMIT ?
+        ''', (q, prefix_match, limit))
         
         rows = cursor.fetchall()
-        conn.close()
         return [dict(row) for row in rows]
     except Exception as e:
         logger.error(f"Error searching foods: {e}")
         return []
+    finally:
+        if conn: conn.close()
+
+def save_chat_message(user_id: int, role: str, content: str):
+    """Persist a single chat message to the chat_messages table.
+
+    Args:
+        user_id: ID of the owning user (FK to users.id in chat_messages).
+        role: Author role string as stored, e.g. 'user' or 'assistant'.
+        content: Raw message text.
+
+    Errors are logged and deliberately swallowed so a failed history write
+    never interrupts an in-flight chat response.
+    """
+    conn = None
+    try:
+        conn = get_db_connection()
+        cursor = conn.cursor()
+        cursor.execute(
+            "INSERT INTO chat_messages (user_id, role, content) VALUES (?, ?, ?)",
+            (user_id, role, content)
+        )
+        conn.commit()
+    except Exception as e:
+        logger.error(f"Error saving chat message: {e}")
+    finally:
+        # Close in finally so the connection is released on both paths.
+        if conn: conn.close()
+
+def get_user_chat_history(user_id: int, limit: int = 50) -> list[Dict[str, Any]]:
+    """Return up to `limit` chat messages for a user, oldest first.
+
+    NOTE(review): the intent appears to be "most recent messages", but
+    ORDER BY created_at ASC combined with LIMIT returns the OLDEST `limit`
+    rows. To get the latest N, the query should ORDER BY DESC and the
+    result be reversed for display — confirm intended behavior.
+
+    Returns a list of {'role': ..., 'content': ...} dicts; empty list on error.
+    """
+    conn = None
+    try:
+        conn = get_db_connection()
+        cursor = conn.cursor()
+        # NOTE(review): this comment previously claimed DESC-then-reverse,
+        # but the SQL below actually orders ASC and truncates at `limit`.
+        cursor.execute('''
+            SELECT role, content FROM chat_messages 
+            WHERE user_id = ? 
+            ORDER BY created_at ASC 
+            LIMIT ?
+        ''', (user_id, limit))
+        rows = cursor.fetchall()
+        return [dict(row) for row in rows]
+    except Exception as e:
+        logger.error(f"Error fetching chat history: {e}")
+        return []
+    finally:
+        if conn: conn.close()

+ 53 - 26
main.py

@@ -4,7 +4,7 @@ import httpx
 import bcrypt
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException, Depends, Header
-from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name
+from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session, search_foods_by_name, save_chat_message, get_user_chat_history
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
@@ -86,12 +86,15 @@ def extract_food_context(messages: list) -> str | None:
     
     # Try each keyword against the local food database, collect unique results
     found_items = {}
-    for kw in keywords[:5]:  # Limit to first 5 keywords for performance
-        results = search_foods_by_name(kw, limit=3)
+    # Optimization: Only use the first 2 most relevant keywords to keep context small on CPU
+    for kw in keywords[:2]:
+        results = search_foods_by_name(kw, limit=2)
         for item in results:
-            if item['name'] not in found_items:
-                found_items[item['name']] = item
-        if len(found_items) >= 5:
+            # Truncate extremely long USDA names for performance
+            short_name = item['name'][:100] + ("..." if len(item['name']) > 100 else "")
+            if short_name not in found_items:
+                found_items[short_name] = item
+        if len(found_items) >= 3:
             break
     
     if not found_items:
@@ -104,12 +107,11 @@ def extract_food_context(messages: list) -> str | None:
         "Use ONLY the following data for specific nutritional values (per 100g serving):",
         ""
     ]
-    for item in found_items.values():
+    for name, item in found_items.items():
         line = (
-            f"- {item['name']}: {item['calories']} kcal | "
-            f"Protein: {item['protein_g']}g | Fat: {item['fat_g']}g | "
-            f"Carbs: {item['carbs_g']}g | Fiber: {item['fiber_g']}g | "
-            f"Sodium: {item['sodium_mg']}mg"
+            f"- {name}: {item['calories']} kcal | "
+            f"P: {item['protein_g']}g | F: {item['fat_g']}g | "
+            f"C: {item['carbs_g']}g"
         )
         lines.append(line)
     
@@ -176,45 +178,70 @@ async def chat_endpoint(request: ChatRequest, current_user: dict = Depends(get_c
     """Proxy chat requests to the local Ollama instance with streaming support.
     Automatically enriches prompts with verified local SQLite nutritional data.
     """
-    messages = [msg.model_dump() for msg in request.messages]
+    # Keep only the last 6 messages for context window performance on CPU
+    all_messages = [msg.model_dump() for msg in request.messages]
+    messages = all_messages[-6:]
+    
+    # Save the latest user message to DB
+    if messages and messages[-1]['role'] == 'user':
+        save_chat_message(current_user['id'], 'user', messages[-1]['content'])
     
     # --- TG-35: Local SQL RAG Enrichment ---
     db_context = extract_food_context(messages)
     if db_context:
-        logger.info(f"[RAG] Injecting local DB context for user '{current_user['username']}'")
         # Prepend as a system message so it acts as grounded knowledge
+        # We ensure it's a short, concise instruction to prevent context bloat
         messages = [{"role": "system", "content": db_context}] + messages
+
+    logger.info(f"[Chat] User '{current_user['username']}' is chatting. Context items: {'Yes' if db_context else 'No'}. Message count: {len(messages)}")
     
     payload = {
         "model": MODEL_NAME,
         "messages": messages,
-        "stream": True  # Enable streaming for a better UI experience
+        "stream": True
     }
     
     async def generate_response():
         try:
-            async with httpx.AsyncClient() as client:
-                async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as response:
+            bot_full_response = ""
+            async with httpx.AsyncClient(timeout=300.0) as client:
+                # Use a combined timeout for the entire request
+                async with client.stream("POST", OLLAMA_URL, json=payload, timeout=300.0) as response:
                     if response.status_code != 200:
                         error_detail = await response.aread()
-                        logger.error(f"Error communicating with Ollama: {error_detail}")
-                        yield f"data: {json.dumps({'error': 'Error communicating with local LLM.'})}\n\n"
+                        logger.error(f"Ollama returned error {response.status_code}: {error_detail}")
+                        yield f"data: {json.dumps({'error': f'LLM Error ({response.status_code})'})}\n\n"
                         return
 
                     async for line in response.aiter_lines():
                         if line:
-                            data = json.loads(line)
-                            if "message" in data and "content" in data["message"]:
-                                content = data["message"]["content"]
-                                yield f"data: {json.dumps({'content': content})}\n\n"
-                            if data.get("done"):
-                                break
+                            try:
+                                data = json.loads(line)
+                                if "message" in data and "content" in data["message"]:
+                                    content = data["message"]["content"]
+                                    bot_full_response += content
+                                    yield f"data: {json.dumps({'content': content})}\n\n"
+                                if data.get("done"):
+                                    break
+                            except json.JSONDecodeError:
+                                continue
+            
+            # Save final bot response to DB
+            if bot_full_response.strip():
+                save_chat_message(current_user['id'], 'assistant', bot_full_response)
+                
         except Exception as e:
-            logger.error(f"Unexpected error during stream: {e}")
-            yield f"data: {json.dumps({'error': str(e)})}\n\n"
+            logger.exception(f"Unexpected error during chat stream: {e}")
+            yield f"data: {json.dumps({'error': 'A technical error occurred while generating the response.'})}\n\n"
 
     return StreamingResponse(generate_response(), media_type="text/event-stream")
 
+@app.get("/api/chat/history")
+async def get_history(current_user: dict = Depends(get_current_user)):
+    """Return the authenticated user's stored chat messages.
+
+    Response shape: {"history": [{"role": ..., "content": ...}, ...]}.
+    Uses get_user_chat_history's default limit (50 messages); requires a
+    valid Bearer token via the get_current_user dependency.
+    """
+    history = get_user_chat_history(current_user['id'])
+    return {"history": history}
+
 @app.get("/api/food/search")
 async def search_food(q: str, current_user: dict = Depends(get_current_user)):
     """API endpoint to search for food items securely using token authentication"""

+ 112 - 0
mega_seed_usda.py

@@ -0,0 +1,112 @@
+import sqlite3
+import os
+import csv
+import sys
+
+# Define path to the unzipped SR Legacy data
+SR_PATH = "sr28data"
+DB_PATH = "localfood.db"
+
+# Nutrient IDs to extract (USDA SR28 IDs)
+NUTRIENT_MAP = {
+    '208': 'calories',
+    '203': 'protein_g',
+    '204': 'fat_g',
+    '205': 'carbs_g',
+    '291': 'fiber_g',
+    '269': 'sugar_g',
+    '307': 'sodium_mg'
+}
+
+def parse_usda_line(line):
+    """Split one USDA SR Legacy record into a list of field strings.
+
+    SR Legacy flat files are '^'-delimited, with text fields wrapped in
+    '~' markers; the '~' wrappers are stripped from both ends of each field.
+    NOTE(review): assumes '^' never occurs inside a ~-quoted field — true
+    for the SR28 distribution files, but not enforced here.
+    """
+    return [item.strip('~') for item in line.strip().split('^')]
+
+def run_seeding():
+    """Parse the USDA SR Legacy flat files and (re)seed the local foods table.
+
+    Reads FD_GROUP.txt, FOOD_DES.txt and NUT_DATA.txt from SR_PATH, joins
+    them in memory keyed by NDB number, then inserts one row per food that
+    has at least one mapped nutrient into DB_PATH.
+
+    WARNING: destructive — deletes ALL existing rows in `foods` first.
+    NOTE(review): the sqlite3 connection is not wrapped in try/finally, so
+    it leaks if an insert raises — consider a context manager.
+    """
+    print("Starting Mega-Seeding from USDA SR Legacy...")
+    
+    if not os.path.exists(SR_PATH):
+        print(f"Error: {SR_PATH} directory not found.")
+        return
+
+    # 0. Load Food Groups (ID -> Group Name)
+    # Encoding: SR Legacy ships as Latin-1; presumably iso-8859-1 matches the
+    # SR28 distribution — TODO confirm against the dataset docs.
+    food_groups = {}
+    print("Reading food groups...")
+    with open(os.path.join(SR_PATH, "FD_GROUP.txt"), "r", encoding="iso-8859-1") as f:
+        for line in f:
+            parts = parse_usda_line(line)
+            group_id = parts[0]
+            group_name = parts[1]
+            food_groups[group_id] = group_name
+
+    # 1. Load Food Descriptions (NDB_No -> Name, Group_ID)
+    # Field indices 0/1/2 assume the documented SR28 FOOD_DES column order
+    # (NDB_No, FdGrp_Cd, Long_Desc) — TODO confirm.
+    food_info = {}
+    print("Reading food descriptions...")
+    with open(os.path.join(SR_PATH, "FOOD_DES.txt"), "r", encoding="iso-8859-1") as f:
+        for line in f:
+            parts = parse_usda_line(line)
+            ndb_no = parts[0]
+            group_id = parts[1]
+            long_desc = parts[2]
+            food_info[ndb_no] = {
+                'name': long_desc,
+                # Fall back to "Unknown" for any group id missing from FD_GROUP.
+                'category': food_groups.get(group_id, "Unknown")
+            }
+
+    # 2. Load Nutrient Data
+    # Structure: ndb_no -> {nutrient_column: value}; only nutrients listed in
+    # NUTRIENT_MAP are kept, everything else in NUT_DATA is skipped.
+    nutrient_data = {}
+    print("Reading nutrient data (this may take a moment)...")
+    with open(os.path.join(SR_PATH, "NUT_DATA.txt"), "r", encoding="iso-8859-1") as f:
+        for line in f:
+            parts = parse_usda_line(line)
+            ndb_no = parts[0]
+            nutr_no = parts[1]
+            # Empty amount fields are coerced to 0.0 rather than skipped.
+            val = float(parts[2]) if parts[2] else 0.0
+            
+            if nutr_no in NUTRIENT_MAP:
+                if ndb_no not in nutrient_data:
+                    nutrient_data[ndb_no] = {}
+                nutrient_data[ndb_no][NUTRIENT_MAP[nutr_no]] = val
+
+    # 3. Insert into localfood.db
+    print(f"Ingesting into {DB_PATH}...")
+    conn = sqlite3.connect(DB_PATH)
+    cursor = conn.cursor()
+    
+    # First, clear existing foods to avoid duplicates if re-running
+    # (destructive: wipes any manually added rows as well).
+    cursor.execute("DELETE FROM foods")
+    
+    count = 0
+    for ndb_no, info in food_info.items():
+        macros = nutrient_data.get(ndb_no, {})
+        
+        # Only add if we have at least some nutritional info
+        if not macros: continue
+        
+        cursor.execute('''
+            INSERT INTO foods (
+                name, category, calories, protein_g, fat_g, carbs_g, fiber_g, sugar_g, sodium_mg, source
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        ''', (
+            info['name'], 
+            info['category'],
+            # Missing nutrients default to 0.0 so every column is populated.
+            macros.get('calories', 0.0),
+            macros.get('protein_g', 0.0),
+            macros.get('fat_g', 0.0),
+            macros.get('carbs_g', 0.0),
+            macros.get('fiber_g', 0.0),
+            macros.get('sugar_g', 0.0),
+            macros.get('sodium_mg', 0.0),
+            # Provenance tag keeps the original USDA NDB number recoverable.
+            f"USDA-{ndb_no}"
+        ))
+        count += 1
+        if count % 1000 == 0:
+            print(f"Inserted {count} items...")
+
+    # Single commit at the end keeps the whole seed in one transaction.
+    conn.commit()
+    conn.close()
+    print(f"SUCCESS: Successfully seeded {count} high-quality items into the local database!")
+
+if __name__ == "__main__":
+    run_seeding()

+ 50 - 2
static/script.js

@@ -221,7 +221,7 @@ document.addEventListener('DOMContentLoaded', () => {
         setLoggedInState(savedUser, savedToken);
     }
 
-    function setLoggedInState(username, token) {
+    async function setLoggedInState(username, token) {
         localStorage.setItem('localFoodUser', username);
         localStorage.setItem('localFoodToken', token);
         userGreeting.textContent = `Welcome, ${username}`;
@@ -233,6 +233,41 @@ document.addEventListener('DOMContentLoaded', () => {
             chatApp.classList.add('fade-in');
             userInput.focus();
         }, 500);
+
+        // Load persisted chat history from the server
+        await loadChatHistory();
+    }
+
+    /**
+     * Load the authenticated user's persisted chat history from the server
+     * and replay it into the chat UI, replacing the initial welcome message.
+     * A 401 response triggers logout; network errors are logged only.
+     */
+    async function loadChatHistory() {
+        const token = localStorage.getItem('localFoodToken');
+        if (!token) return;
+
+        try {
+            const response = await fetch('/api/chat/history', {
+                headers: { 'Authorization': `Bearer ${token}` }
+            });
+
+            // Expired/invalid token: fall back to the logged-out view.
+            if (response.status === 401) {
+                setLoggedOutState();
+                return;
+            }
+
+            if (response.ok) {
+                const data = await response.json();
+                if (data.history && data.history.length > 0) {
+                    // Clear initial welcome message if we have real history
+                    chatContainer.innerHTML = '';
+                    chatHistory = []; // Reset local state
+                    
+                    // Replay each persisted message into both the DOM and the
+                    // in-memory history used for subsequent chat requests.
+                    data.history.forEach(msg => {
+                        addMessage(msg.role, msg.content);
+                        chatHistory.push({ role: msg.role, content: msg.content });
+                    });
+                }
+            }
+        } catch (err) {
+            // Best-effort: history load failure should not block the chat UI.
+            console.error("Failed to load chat history:", err);
+        }
+    }
 
     async function setLoggedOutState() {
@@ -385,6 +420,18 @@ document.addEventListener('DOMContentLoaded', () => {
             const response = await fetch(`/api/food/search?q=${encodeURIComponent(query)}`, {
                 headers: { 'Authorization': `Bearer ${token}` }
             });
+            
+            if (response.status === 401) {
+                setLoggedOutState();
+                addMessage('system', 'Your session has expired. Please log in again to use the food search.');
+                searchDropdown.style.display = 'none';
+                return;
+            }
+
+            if (!response.ok) {
+                throw new Error(`HTTP error! status: ${response.status}`);
+            }
+
             const data = await response.json();
             
             if (data.results && data.results.length > 0) {
@@ -424,7 +471,8 @@ document.addEventListener('DOMContentLoaded', () => {
                 searchDropdown.innerHTML = '<div class="search-empty">No matching foods found.</div>';
             }
         } catch (error) {
-            searchDropdown.innerHTML = '<div class="search-empty">Error searching database.</div>';
+            console.error('Search error:', error);
+            searchDropdown.innerHTML = '<div class="search-empty">Service currently unavailable. Please try again.</div>';
         }
     };