# Kreatyw/api.py
# Snapshot: 2026-01-07 10:40:22 +01:00 (112 lines, 2.4 KiB, Python)

import os
import sys

import uvicorn
from fastapi import Body, FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Import core LLM logic
from llm import SOURCES_DIR, generate_text, load_or_train_model
# --- Configuration ---
# N-gram orders whose models are loaded (or trained) once at server startup.
PRELOAD_N_GRAMS = list(range(2, 6))  # -> [2, 3, 4, 5]
# Directory with the static front-end files (index.html, api.html).
# It is a relative path, so it is resolved against the working directory.
UI_DIR = "ui"

# --- Globals ---
# Model cache keyed by N-gram order; populated lazily by get_model_for_n.
MODEL_CACHE: dict = {}
# --- Pydantic Models ---
class PredictRequest(BaseModel):
    """Request body for POST /api/predict.

    No range validation happens here; the /api/predict handler clamps
    ``n`` and ``length`` itself.
    """

    # Seed text, passed to generate_text as ``start_prompt``.
    prompt: str
    # Sampling temperature, forwarded to generate_text unchanged.
    # NOTE(review): values <= 0 are not rejected anywhere in this module —
    # confirm llm.generate_text copes with them.
    temperature: float = 1.6
    # N-gram order; the handler clamps it to [2, 5].
    n: int = 4
    # Requested output length (unit defined by generate_text — presumably
    # words; verify in llm.py); the handler clamps it to [1, 500].
    length: int = 5
class PredictResponse(BaseModel):
    """Response body for POST /api/predict."""

    # Text returned by generate_text, or "" when no usable model exists
    # for the requested N.
    prediction: str
# --- FastAPI App ---
# The ASGI application. The startup hook, the /api routes, the /ui static
# mount and the root route below are all registered on this instance, and
# run() serves it with uvicorn.
app = FastAPI()
def get_model_for_n(n: int):
    """Return the model for N-gram order ``n``, building it on first use.

    On a cache miss this logs a message, calls ``load_or_train_model`` with
    ``SOURCES_DIR`` and stores whatever comes back in ``MODEL_CACHE``. A
    falsy result is cached as well and is not retried. Later calls for the
    same ``n`` return the cached object directly.
    """
    # MODEL_CACHE is only mutated, never rebound, so no ``global`` is needed.
    if n not in MODEL_CACHE:
        print(f"Loading/Training model for N={n}...")
        MODEL_CACHE[n] = load_or_train_model(SOURCES_DIR, n)
    return MODEL_CACHE[n]
# NOTE: FastAPI documents ``on_event`` as deprecated in favour of lifespan
# handlers; it is kept here because it still works and changing it touches
# how ``app`` is constructed.
@app.on_event("startup")
def startup_event():
    """Warm MODEL_CACHE with a model for every order in PRELOAD_N_GRAMS.

    Runs once at server startup, so the first /api/predict calls do not pay
    the load/train cost.

    Any order for which the loader produced no usable (falsy) model is
    reported rather than silently announced as loaded. Such a result is
    cached, so /api/predict would answer every request for that order with
    an empty prediction.
    """
    print("Server starting up. Pre-loading models...")
    missing = []
    for order in PRELOAD_N_GRAMS:
        # Called in order, once per configured N; the result is cached.
        if not get_model_for_n(order):
            missing.append(order)
    if missing:
        print(
            f"Warning: no usable model for N={missing}; "
            "predictions for these sizes will be empty."
        )
    print(f"Models for N={PRELOAD_N_GRAMS} loaded. Server is ready.")
@app.post("/api/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    API endpoint to get the next word prediction.

    Generates a continuation of ``request.prompt`` with the N-gram model of
    order ``request.n``.

    Out-of-range inputs are clamped rather than rejected: ``n`` to [2, 5]
    (the orders pre-loaded at startup) and ``length`` to [1, 500].
    ``temperature`` is forwarded to ``generate_text`` unchanged.

    Returns:
        PredictResponse whose ``prediction`` is the generated text, or ""
        when no usable (truthy) model exists for the chosen order.
    """
    n = max(2, min(request.n, 5))
    # After startup every order in [2, 5] is cached, so this is a cheap
    # dict hit and can stay on the event loop.
    model = get_model_for_n(n)
    if not model:
        # Same response type as the success path, not a bare dict.
        return PredictResponse(prediction="")
    length = max(1, min(request.length, 500))
    # generate_text is a plain synchronous call, presumably CPU-bound and
    # running up to 500 steps. Calling it directly inside this coroutine
    # would stall the event loop, and with it every other request, including
    # static files. Run it in FastAPI's worker thread pool instead.
    # NOTE(review): this allows concurrent generate_text calls on the same
    # cached model; assumes generate_text only reads the model — confirm
    # in llm.py.
    prediction = await run_in_threadpool(
        generate_text,
        model,
        start_prompt=request.prompt,
        length=length,
        temperature=request.temperature,
    )
    return PredictResponse(prediction=prediction)
@app.get("/api")
async def api_docs():
    """Serve the hand-written API documentation page (UI_DIR/api.html)."""
    docs_page = os.path.join(UI_DIR, "api.html")
    return FileResponse(docs_page)
# --- Static Files and Root ---
# Expose the contents of UI_DIR under the /ui URL prefix. The StaticFiles
# app is built at import time from a working-directory-relative path;
# Starlette checks by default that the directory exists, so import is
# expected to fail when started outside the project root.
app.mount("/ui", StaticFiles(directory=UI_DIR), name="ui")
@app.get("/")
async def read_root():
    """Serve UI_DIR/index.html at the site root."""
    landing_page = os.path.join(UI_DIR, "index.html")
    return FileResponse(landing_page)
def run(host="0.0.0.0", port=None):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. Defaults to "0.0.0.0" (all IPv4 interfaces).
        port: TCP port to listen on. When omitted, it is read from the PORT
            environment variable; an unset or blank PORT falls back to 8000.

    Raises:
        ValueError: If PORT is set to a value that is not an integer.
    """
    if port is None:
        raw_port = os.environ.get("PORT", "").strip()
        try:
            # A blank PORT (e.g. PORT="") means "use the default" instead of
            # crashing inside int("").
            port = int(raw_port) if raw_port else 8000
        except ValueError as err:
            raise ValueError(
                f"PORT environment variable must be an integer, got {raw_port!r}"
            ) from err
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run()