import os
import sys

import uvicorn
from fastapi import Body, FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Import core LLM logic
from llm import SOURCES_DIR, generate_text, load_or_train_model

# --- Configuration ---
# Models to pre-load on startup
PRELOAD_N_GRAMS = [2, 3, 4, 5]
UI_DIR = "ui"

# --- Globals ---
# Cache for loaded models: {n: model}
MODEL_CACHE = {}


# --- Pydantic Models ---
class PredictRequest(BaseModel):
    # Request payload for /api/predict.
    prompt: str
    temperature: float = 1.6
    n: int = 4       # n-gram order; clamped to [2, 5] by the endpoint
    length: int = 5  # number of words to generate; clamped to [1, 500]


class PredictResponse(BaseModel):
    # Response payload for /api/predict.
    prediction: str


# --- FastAPI App ---
app = FastAPI()


def get_model_for_n(n: int):
    """Return the model for n-gram order *n*, loading/training on first use.

    The result is memoized in the module-level MODEL_CACHE. No `global`
    statement is needed: the cache dict is mutated in place, never rebound.
    """
    if n in MODEL_CACHE:
        return MODEL_CACHE[n]
    print(f"Loading/Training model for N={n}...")
    model = load_or_train_model(SOURCES_DIR, n)
    MODEL_CACHE[n] = model
    return model


@app.on_event("startup")
def startup_event():
    """On server startup, pre-load models for all specified N-grams."""
    print("Server starting up. Pre-loading models...")
    for n in PRELOAD_N_GRAMS:
        get_model_for_n(n)
    print(f"Models for N={PRELOAD_N_GRAMS} loaded. Server is ready.")


@app.post("/api/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """API endpoint to get the next word prediction.

    Clamps `n` to the pre-loaded range [2, 5] and `length` to [1, 500]
    before generating text with the cached n-gram model.
    """
    n = max(2, min(request.n, 5))
    model = get_model_for_n(n)
    if not model:
        # Bug fix: return the declared response model on the empty-model
        # path too, instead of a bare dict, for consistency with the
        # success path and the endpoint's response_model.
        return PredictResponse(prediction="")
    length = max(1, min(request.length, 500))
    prediction = generate_text(
        model,
        start_prompt=request.prompt,
        length=length,
        temperature=request.temperature,
    )
    return PredictResponse(prediction=prediction)


@app.get("/api")
async def api_docs():
    """API documentation page."""
    return FileResponse(os.path.join(UI_DIR, "api.html"))


# --- Static Files and Root ---
app.mount("/ui", StaticFiles(directory=UI_DIR), name="ui")


@app.get("/")
async def read_root():
    """Serve the UI landing page."""
    return FileResponse(os.path.join(UI_DIR, "index.html"))


def run():
    """Start the uvicorn server.

    Reads the port from the PORT environment variable, defaulting to 8000.
    """
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    run()