# main.py — LocalFoodAI Chat backend (FastAPI + local Ollama proxy)
  1. import json
  2. import logging
  3. import httpx
  4. from contextlib import asynccontextmanager
  5. from fastapi import FastAPI, HTTPException
  6. from database import create_tables, create_user, get_user_by_username
  7. from passlib.context import CryptContext
  8. from fastapi.responses import HTMLResponse, StreamingResponse
  9. from fastapi.staticfiles import StaticFiles
  10. from pydantic import BaseModel
  11. from typing import List, Generator
  12. logging.basicConfig(level=logging.INFO)
  13. logger = logging.getLogger(__name__)
  14. @asynccontextmanager
  15. async def lifespan(app: FastAPI):
  16. create_tables()
  17. yield
  18. app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
  19. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  20. def get_password_hash(password):
  21. return pwd_context.hash(password)
  22. def verify_password(plain_password, hashed_password):
  23. return pwd_context.verify(plain_password, hashed_password)
  24. class UserCreate(BaseModel):
  25. username: str
  26. password: str
  27. class UserLogin(BaseModel):
  28. username: str
  29. password: str
  30. OLLAMA_URL = "http://localhost:11434/api/chat"
  31. MODEL_NAME = "llama3.1:8b"
  32. # Mount static files to serve the frontend
  33. app.mount("/static", StaticFiles(directory="static"), name="static")
  34. class ChatMessage(BaseModel):
  35. role: str
  36. content: str
  37. class ChatRequest(BaseModel):
  38. messages: List[ChatMessage]
  39. @app.get("/", response_class=HTMLResponse)
  40. async def read_root():
  41. """Serve the chat interface HTML"""
  42. try:
  43. with open("static/index.html", "r", encoding="utf-8") as f:
  44. return HTMLResponse(content=f.read())
  45. except FileNotFoundError:
  46. return HTMLResponse(content="<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>")
  47. @app.post("/api/register")
  48. async def register_user(user: UserCreate):
  49. if len(user.username.strip()) < 3:
  50. raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
  51. if len(user.password.strip()) < 6:
  52. raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
  53. hashed_password = get_password_hash(user.password)
  54. success = create_user(user.username.strip(), hashed_password)
  55. if not success:
  56. raise HTTPException(status_code=400, detail="Username already exists")
  57. return {"message": "User registered successfully"}
  58. @app.post("/api/login")
  59. async def login_user(user: UserLogin):
  60. db_user = get_user_by_username(user.username.strip())
  61. if not db_user:
  62. raise HTTPException(status_code=401, detail="Invalid username or password")
  63. if not verify_password(user.password, db_user["password_hash"]):
  64. raise HTTPException(status_code=401, detail="Invalid username or password")
  65. return {"status": "success", "username": db_user["username"]}
  66. @app.post("/chat")
  67. async def chat_endpoint(request: ChatRequest):
  68. """Proxy chat requests to the local Ollama instance with streaming support"""
  69. payload = {
  70. "model": MODEL_NAME,
  71. "messages": [msg.model_dump() for msg in request.messages],
  72. "stream": True # Enable streaming for a better UI experience
  73. }
  74. async def generate_response():
  75. try:
  76. async with httpx.AsyncClient() as client:
  77. async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as response:
  78. if response.status_code != 200:
  79. error_detail = await response.aread()
  80. logger.error(f"Error communicating with Ollama: {error_detail}")
  81. yield f"data: {json.dumps({'error': 'Error communicating with local LLM.'})}\n\n"
  82. return
  83. async for line in response.aiter_lines():
  84. if line:
  85. data = json.loads(line)
  86. if "message" in data and "content" in data["message"]:
  87. content = data["message"]["content"]
  88. yield f"data: {json.dumps({'content': content})}\n\n"
  89. if data.get("done"):
  90. break
  91. except Exception as e:
  92. logger.error(f"Unexpected error during stream: {e}")
  93. yield f"data: {json.dumps({'error': str(e)})}\n\n"
  94. return StreamingResponse(generate_response(), media_type="text/event-stream")
  95. if __name__ == "__main__":
  96. import uvicorn
  97. uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)