# main.py — FastAPI chat frontend that proxies requests to a local Ollama instance.
  1. import json
  2. import logging
  3. import httpx
  4. from fastapi import FastAPI, HTTPException
  5. from fastapi.responses import HTMLResponse, StreamingResponse
  6. from fastapi.staticfiles import StaticFiles
  7. from pydantic import BaseModel
  8. from typing import List, Generator
  9. logging.basicConfig(level=logging.INFO)
  10. logger = logging.getLogger(__name__)
  11. app = FastAPI(title="LocalFoodAI Chat")
  12. OLLAMA_URL = "http://localhost:11434/api/chat"
  13. MODEL_NAME = "llama3.1:8b"
  14. # Mount static files to serve the frontend
  15. app.mount("/static", StaticFiles(directory="static"), name="static")
  16. class ChatMessage(BaseModel):
  17. role: str
  18. content: str
  19. class ChatRequest(BaseModel):
  20. messages: List[ChatMessage]
  21. @app.get("/", response_class=HTMLResponse)
  22. async def read_root():
  23. """Serve the chat interface HTML"""
  24. try:
  25. with open("static/index.html", "r", encoding="utf-8") as f:
  26. return HTMLResponse(content=f.read())
  27. except FileNotFoundError:
  28. return HTMLResponse(content="<h1>Welcome to LocalFoodAI</h1><p>static/index.html not found. Please create the frontend.</p>")
  29. @app.post("/chat")
  30. async def chat_endpoint(request: ChatRequest):
  31. """Proxy chat requests to the local Ollama instance with streaming support"""
  32. payload = {
  33. "model": MODEL_NAME,
  34. "messages": [msg.model_dump() for msg in request.messages],
  35. "stream": True # Enable streaming for a better UI experience
  36. }
  37. async def generate_response():
  38. try:
  39. async with httpx.AsyncClient() as client:
  40. async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as response:
  41. if response.status_code != 200:
  42. error_detail = await response.aread()
  43. logger.error(f"Error communicating with Ollama: {error_detail}")
  44. yield f"data: {json.dumps({'error': 'Error communicating with local LLM.'})}\n\n"
  45. return
  46. async for line in response.aiter_lines():
  47. if line:
  48. data = json.loads(line)
  49. if "message" in data and "content" in data["message"]:
  50. content = data["message"]["content"]
  51. yield f"data: {json.dumps({'content': content})}\n\n"
  52. if data.get("done"):
  53. break
  54. except Exception as e:
  55. logger.error(f"Unexpected error during stream: {e}")
  56. yield f"data: {json.dumps({'error': str(e)})}\n\n"
  57. return StreamingResponse(generate_response(), media_type="text/event-stream")
  58. if __name__ == "__main__":
  59. import uvicorn
  60. uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)