|
@@ -1,14 +1,14 @@
|
|
|
import json
|
|
import json
|
|
|
import logging
|
|
import logging
|
|
|
import httpx
|
|
import httpx
|
|
|
|
|
+import bcrypt
|
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import asynccontextmanager
|
|
|
-from fastapi import FastAPI, HTTPException
|
|
|
|
|
-from database import create_tables, create_user, get_user_by_username
|
|
|
|
|
-from passlib.context import CryptContext
|
|
|
|
|
|
|
+from fastapi import FastAPI, HTTPException, Depends, Header
|
|
|
|
|
+from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session
|
|
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
-from typing import List, Generator
|
|
|
|
|
|
|
+from typing import List, Generator, Optional
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
@@ -20,13 +20,17 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
|
|
|
app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
|
|
app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
|
|
|
|
|
|
|
|
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
|
+# Use direct bcrypt for better environment compatibility
|
|
|
|
|
+def get_password_hash(password: str):
|
|
|
|
|
+ # Hash requires bytes
|
|
|
|
|
+ pwd_bytes = password.encode('utf-8')
|
|
|
|
|
+ salt = bcrypt.gensalt()
|
|
|
|
|
+ hashed = bcrypt.hashpw(pwd_bytes, salt)
|
|
|
|
|
+ return hashed.decode('utf-8')
|
|
|
|
|
|
|
|
-def get_password_hash(password):
|
|
|
|
|
- return pwd_context.hash(password)
|
|
|
|
|
-
|
|
|
|
|
-def verify_password(plain_password, hashed_password):
|
|
|
|
|
- return pwd_context.verify(plain_password, hashed_password)
|
|
|
|
|
|
|
+def verify_password(plain_password: str, hashed_password: str):
|
|
|
|
|
+ # bcrypt.checkpw handles verification
|
|
|
|
|
+ return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
|
|
|
|
|
|
|
|
class UserCreate(BaseModel):
|
|
class UserCreate(BaseModel):
|
|
|
username: str
|
|
username: str
|
|
@@ -36,6 +40,16 @@ class UserLogin(BaseModel):
|
|
|
username: str
|
|
username: str
|
|
|
password: str
|
|
password: str
|
|
|
|
|
|
|
|
|
|
+async def get_current_user(authorization: Optional[str] = Header(None)):
|
|
|
|
|
+ if not authorization or not authorization.startswith("Bearer "):
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="Authentication required")
|
|
|
|
|
+
|
|
|
|
|
+ token = authorization.split(" ")[1]
|
|
|
|
|
+ user = get_user_from_token(token)
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="Invalid or expired session")
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
OLLAMA_URL = "http://localhost:11434/api/chat"
|
|
OLLAMA_URL = "http://localhost:11434/api/chat"
|
|
|
MODEL_NAME = "llama3.1:8b"
|
|
MODEL_NAME = "llama3.1:8b"
|
|
|
|
|
|
|
@@ -66,11 +80,13 @@ async def register_user(user: UserCreate):
|
|
|
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
|
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
|
|
|
|
|
|
|
hashed_password = get_password_hash(user.password)
|
|
hashed_password = get_password_hash(user.password)
|
|
|
- success = create_user(user.username.strip(), hashed_password)
|
|
|
|
|
- if not success:
|
|
|
|
|
|
|
+ user_id = create_user(user.username.strip(), hashed_password)
|
|
|
|
|
+ if not user_id:
|
|
|
raise HTTPException(status_code=400, detail="Username already exists")
|
|
raise HTTPException(status_code=400, detail="Username already exists")
|
|
|
|
|
|
|
|
- return {"message": "User registered successfully"}
|
|
|
|
|
|
|
+ # Auto-login after registration
|
|
|
|
|
+ token = create_session(user_id)
|
|
|
|
|
+ return {"message": "User registered successfully", "token": token, "username": user.username.strip()}
|
|
|
|
|
|
|
|
@app.post("/api/login")
|
|
@app.post("/api/login")
|
|
|
async def login_user(user: UserLogin):
|
|
async def login_user(user: UserLogin):
|
|
@@ -81,10 +97,18 @@ async def login_user(user: UserLogin):
|
|
|
if not verify_password(user.password, db_user["password_hash"]):
|
|
if not verify_password(user.password, db_user["password_hash"]):
|
|
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
|
|
|
|
|
|
|
- return {"status": "success", "username": db_user["username"]}
|
|
|
|
|
|
|
+ token = create_session(db_user["id"])
|
|
|
|
|
+ return {"status": "success", "username": db_user["username"], "token": token}
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/api/logout")
|
|
|
|
|
+async def logout(authorization: Optional[str] = Header(None)):
|
|
|
|
|
+ if authorization and authorization.startswith("Bearer "):
|
|
|
|
|
+ token = authorization.split(" ")[1]
|
|
|
|
|
+ delete_session(token)
|
|
|
|
|
+ return {"message": "Logged out successfully"}
|
|
|
|
|
|
|
|
@app.post("/chat")
|
|
@app.post("/chat")
|
|
|
-async def chat_endpoint(request: ChatRequest):
|
|
|
|
|
|
|
+async def chat_endpoint(request: ChatRequest, current_user: dict = Depends(get_current_user)):
|
|
|
"""Proxy chat requests to the local Ollama instance with streaming support"""
|
|
"""Proxy chat requests to the local Ollama instance with streaming support"""
|
|
|
payload = {
|
|
payload = {
|
|
|
"model": MODEL_NAME,
|
|
"model": MODEL_NAME,
|