| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- import json
- import logging
- import httpx
- from contextlib import asynccontextmanager
- from fastapi import FastAPI, HTTPException
- from database import create_tables, create_user, get_user_by_username
- from passlib.context import CryptContext
- from fastapi.responses import HTMLResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- from pydantic import BaseModel
- from typing import List, Generator
# Configure the root logger once at import time; INFO and above are emitted.
logging.basicConfig(
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work for the application, then hand control to FastAPI.

    The code before ``yield`` executes once before the app starts serving;
    nothing is executed after the ``yield`` (no shutdown work).

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Startup: let the project's database helper set up its tables.
    create_tables()

    # The application serves requests while suspended here.
    yield
# ASGI application object; `lifespan` supplies the startup hook.
app = FastAPI(
    title="LocalFoodAI Chat",
    lifespan=lifespan,
)

# passlib context shared by the password helpers; bcrypt is the only scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
def get_password_hash(password):
    """Return the hash of ``password`` produced by the module's ``pwd_context``.

    Args:
        password: Plain-text password to hash.

    Returns:
        Whatever ``pwd_context.hash`` returns for the given password.
    """
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Check a plain-text password against a stored hash via ``pwd_context``.

    Args:
        plain_password: Password as supplied by the user.
        hashed_password: Previously stored hash to compare against.

    Returns:
        The result of ``pwd_context.verify`` for the two values.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Request body accepted by the registration endpoint.
# (Documented with comments rather than a docstring so the generated
# schema for this model is left exactly as it was.)
class UserCreate(BaseModel):
    # Requested account name.
    username: str

    # Plain-text password as submitted by the client.
    password: str
# Request body accepted by the login endpoint.
# (Documented with comments rather than a docstring so the generated
# schema for this model is left exactly as it was.)
class UserLogin(BaseModel):
    # Account name to look up.
    username: str

    # Plain-text password to check against the stored hash.
    password: str
# Chat endpoint of the Ollama server the /chat route forwards to.
OLLAMA_URL = "http://localhost:11434/api/chat"

# Model identifier sent in every request payload.
MODEL_NAME = "llama3.1:8b"
# Serve the frontend assets from the local "static" directory under /static.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
# One entry of a chat conversation; dumped to a dict and forwarded in the
# "messages" list of the payload built by the /chat endpoint.
class ChatMessage(BaseModel):
    # Speaker of the message (passed through unvalidated).
    role: str

    # Text of the message.
    content: str
# Request body accepted by the /chat endpoint.
class ChatRequest(BaseModel):
    # Messages to forward, in the order the client supplied them.
    messages: List[ChatMessage]
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Return the chat frontend page.

    Reads ``static/index.html`` (UTF-8) and returns its contents. When the
    file does not exist, a short placeholder page is returned instead.
    """
    try:
        with open("static/index.html", "r", encoding="utf-8") as index_file:
            page = index_file.read()
    except FileNotFoundError:
        # Frontend not built yet: answer with a hint instead of an error.
        page = "<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>"
    return HTMLResponse(content=page)
# bcrypt only uses the first 72 bytes of its input. Longer passwords are
# either silently truncated or rejected with an exception, depending on the
# installed bcrypt backend, so they are refused up front.
_BCRYPT_MAX_PASSWORD_BYTES = 72


@app.post("/api/register")
async def register_user(user: UserCreate):
    """Create a new user account.

    The username is stripped of surrounding whitespace before validation and
    storage. The password is stored hashed, exactly as submitted (it is only
    stripped for the minimum-length check).

    Args:
        user: Requested username and plain-text password.

    Returns:
        ``{"message": "User registered successfully"}`` on success.

    Raises:
        HTTPException: 400 when the username is shorter than 3 characters,
            the password is shorter than 6 characters (ignoring surrounding
            whitespace), the password exceeds 72 bytes when UTF-8 encoded,
            or ``create_user`` reports failure (treated as a duplicate
            username).
    """
    username = user.username.strip()
    if len(username) < 3:
        raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
    if len(user.password.strip()) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
    # Measured in bytes, not characters: that is the unit bcrypt's limit uses.
    if len(user.password.encode("utf-8")) > _BCRYPT_MAX_PASSWORD_BYTES:
        raise HTTPException(status_code=400, detail="Password must be at most 72 bytes")

    hashed_password = get_password_hash(user.password)
    if not create_user(username, hashed_password):
        raise HTTPException(status_code=400, detail="Username already exists")

    return {"message": "User registered successfully"}
@app.post("/api/login")
async def login_user(user: UserLogin):
    """Check a username/password pair against the stored credentials.

    The username is stripped of surrounding whitespace before lookup. Unknown
    usernames and wrong passwords produce the same 401 response.

    Args:
        user: Submitted username and plain-text password.

    Returns:
        ``{"status": "success", "username": <stored username>}`` on success.

    Raises:
        HTTPException: 401 when the user does not exist or the password does
            not match the stored hash.
    """
    db_user = get_user_by_username(user.username.strip())
    if not db_user:
        # Spend about as long as a real verification would, so response time
        # does not reveal whether the username exists.
        pwd_context.dummy_verify()
        raise HTTPException(status_code=401, detail="Invalid username or password")

    if not verify_password(user.password, db_user["password_hash"]):
        raise HTTPException(status_code=401, detail="Invalid username or password")

    return {"status": "success", "username": db_user["username"]}
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Proxy a chat request to the local Ollama instance and stream the reply.

    The conversation is forwarded to ``OLLAMA_URL`` with streaming enabled.
    Each content fragment received is re-emitted to the client as a
    server-sent event of the form ``data: {"content": "..."}``. Failures are
    reported in-band as ``data: {"error": "..."}`` events with a generic
    message; the details are written to the server log only.

    Args:
        request: The chat messages to send to the model.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    payload = {
        "model": MODEL_NAME,
        "messages": [msg.model_dump() for msg in request.messages],
        "stream": True,  # Enable streaming for a better UI experience
    }

    def sse_event(body):
        """Serialize ``body`` as a single server-sent event."""
        return f"data: {json.dumps(body)}\n\n"

    async def generate_response():
        """Yield SSE events translated from Ollama's line-delimited JSON."""
        try:
            async with httpx.AsyncClient() as client:
                async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as response:
                    if response.status_code != 200:
                        error_detail = await response.aread()
                        logger.error("Error communicating with Ollama: %s", error_detail)
                        yield sse_event({"error": "Error communicating with local LLM."})
                        return

                    async for line in response.aiter_lines():
                        if not line:
                            continue

                        # One bad line must not abort an otherwise healthy stream.
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError:
                            logger.warning("Skipping malformed line from Ollama: %r", line)
                            continue
                        if not isinstance(data, dict):
                            logger.warning("Skipping non-object line from Ollama: %r", line)
                            continue

                        # Ollama reports mid-stream failures as an object with an
                        # "error" key; surface that instead of ending silently.
                        if "error" in data:
                            logger.error("Ollama reported an error mid-stream: %s", data["error"])
                            yield sse_event({"error": "Error communicating with local LLM."})
                            return

                        message = data.get("message")
                        if isinstance(message, dict) and "content" in message:
                            yield sse_event({"content": message["content"]})

                        if data.get("done"):
                            break
        except httpx.HTTPError:
            # Connection refused, timeouts, protocol errors, ...
            logger.exception("Request to Ollama at %s failed", OLLAMA_URL)
            yield sse_event({"error": "Error communicating with local LLM."})
        except Exception:
            # Last-resort boundary: keep the traceback in the log, but do not
            # send internal exception text to the client.
            logger.exception("Unexpected error during stream")
            yield sse_event({"error": "Unexpected error while streaming the response."})

    return StreamingResponse(generate_response(), media_type="text/event-stream")
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run directly as a script.
    import uvicorn

    # Development server: loopback only, auto-reload on code changes.
    uvicorn.run(
        "main:app",
        host="127.0.0.1",
        port=8000,
        reload=True,
    )
|