# main.py — LocalFoodAI chat backend
  1. import json
  2. import logging
  3. import httpx
  4. import bcrypt
  5. from contextlib import asynccontextmanager
  6. from fastapi import FastAPI, HTTPException, Depends, Header
  7. from database import create_tables, create_user, get_user_by_username, create_session, get_user_from_token, delete_session
  8. from fastapi.responses import HTMLResponse, StreamingResponse
  9. from fastapi.staticfiles import StaticFiles
  10. from pydantic import BaseModel
  11. from typing import List, Generator, Optional
  12. logging.basicConfig(level=logging.INFO)
  13. logger = logging.getLogger(__name__)
  14. @asynccontextmanager
  15. async def lifespan(app: FastAPI):
  16. create_tables()
  17. yield
  18. app = FastAPI(title="LocalFoodAI Chat", lifespan=lifespan)
  19. # Use direct bcrypt for better environment compatibility
  20. def get_password_hash(password: str):
  21. # Hash requires bytes
  22. pwd_bytes = password.encode('utf-8')
  23. salt = bcrypt.gensalt()
  24. hashed = bcrypt.hashpw(pwd_bytes, salt)
  25. return hashed.decode('utf-8')
  26. def verify_password(plain_password: str, hashed_password: str):
  27. # bcrypt.checkpw handles verification
  28. return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
  29. class UserCreate(BaseModel):
  30. username: str
  31. password: str
  32. class UserLogin(BaseModel):
  33. username: str
  34. password: str
  35. async def get_current_user(authorization: Optional[str] = Header(None)):
  36. if not authorization or not authorization.startswith("Bearer "):
  37. raise HTTPException(status_code=401, detail="Authentication required")
  38. token = authorization.split(" ")[1]
  39. user = get_user_from_token(token)
  40. if not user:
  41. raise HTTPException(status_code=401, detail="Invalid or expired session")
  42. return user
  43. OLLAMA_URL = "http://localhost:11434/api/chat"
  44. MODEL_NAME = "llama3.1:8b"
  45. # Mount static files to serve the frontend
  46. app.mount("/static", StaticFiles(directory="static"), name="static")
  47. class ChatMessage(BaseModel):
  48. role: str
  49. content: str
  50. class ChatRequest(BaseModel):
  51. messages: List[ChatMessage]
  52. @app.get("/", response_class=HTMLResponse)
  53. async def read_root():
  54. """Serve the chat interface HTML"""
  55. try:
  56. with open("static/index.html", "r", encoding="utf-8") as f:
  57. return HTMLResponse(content=f.read())
  58. except FileNotFoundError:
  59. return HTMLResponse(content="<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>")
  60. @app.post("/api/register")
  61. async def register_user(user: UserCreate):
  62. if len(user.username.strip()) < 3:
  63. raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
  64. if len(user.password.strip()) < 6:
  65. raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
  66. hashed_password = get_password_hash(user.password)
  67. user_id = create_user(user.username.strip(), hashed_password)
  68. if not user_id:
  69. raise HTTPException(status_code=400, detail="Username already exists")
  70. # Auto-login after registration
  71. token = create_session(user_id)
  72. return {"message": "User registered successfully", "token": token, "username": user.username.strip()}
  73. @app.post("/api/login")
  74. async def login_user(user: UserLogin):
  75. db_user = get_user_by_username(user.username.strip())
  76. if not db_user:
  77. raise HTTPException(status_code=401, detail="Invalid username or password")
  78. if not verify_password(user.password, db_user["password_hash"]):
  79. raise HTTPException(status_code=401, detail="Invalid username or password")
  80. token = create_session(db_user["id"])
  81. return {"status": "success", "username": db_user["username"], "token": token}
  82. @app.post("/api/logout")
  83. async def logout(authorization: Optional[str] = Header(None)):
  84. if authorization and authorization.startswith("Bearer "):
  85. token = authorization.split(" ")[1]
  86. delete_session(token)
  87. return {"message": "Logged out successfully"}
  88. @app.post("/chat")
  89. async def chat_endpoint(request: ChatRequest, current_user: dict = Depends(get_current_user)):
  90. """Proxy chat requests to the local Ollama instance with streaming support"""
  91. payload = {
  92. "model": MODEL_NAME,
  93. "messages": [msg.model_dump() for msg in request.messages],
  94. "stream": True # Enable streaming for a better UI experience
  95. }
  96. async def generate_response():
  97. try:
  98. async with httpx.AsyncClient() as client:
  99. async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as response:
  100. if response.status_code != 200:
  101. error_detail = await response.aread()
  102. logger.error(f"Error communicating with Ollama: {error_detail}")
  103. yield f"data: {json.dumps({'error': 'Error communicating with local LLM.'})}\n\n"
  104. return
  105. async for line in response.aiter_lines():
  106. if line:
  107. data = json.loads(line)
  108. if "message" in data and "content" in data["message"]:
  109. content = data["message"]["content"]
  110. yield f"data: {json.dumps({'content': content})}\n\n"
  111. if data.get("done"):
  112. break
  113. except Exception as e:
  114. logger.error(f"Unexpected error during stream: {e}")
  115. yield f"data: {json.dumps({'error': str(e)})}\n\n"
  116. return StreamingResponse(generate_response(), media_type="text/event-stream")
  117. if __name__ == "__main__":
  118. import uvicorn
  119. uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)