# Kreatyw/api.py
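"""FastAPI server for the Kreatyw n-gram text generator.

Pre-loads the n-gram models from llm.py on startup, exposes a /api/predict
endpoint for text generation, and serves the static UI from the ui/ directory.
"""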
import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Import core LLM logic
from llm import SOURCES_DIR, generate_text, load_or_train_model

# --- Configuration ---
# Models to pre-load on startup
PRELOAD_N_GRAMS = [2, 3, 4, 5]
UI_DIR = "ui"

# --- Globals ---
# Cache for loaded models: {n: model}
MODEL_CACHE = {}

# --- Pydantic Models ---
class PredictRequest(BaseModel):
    prompt: str
    temperature: float = 1.6
    n: int = 4
    length: int = 5


class PredictResponse(BaseModel):
    prediction: str

# --- FastAPI App ---
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_model_for_n(n: int):
    """
    Retrieves the model for a specific N from cache, or loads/trains it.
    """
    global MODEL_CACHE
    if n in MODEL_CACHE:
        return MODEL_CACHE[n]
    print(f"Loading/Training model for N={n}...")
    model = load_or_train_model(SOURCES_DIR, n)
    MODEL_CACHE[n] = model
    return model

@app.on_event("startup")
def startup_event():
    """
    On server startup, pre-load models for all specified N-grams.
    """
    print("Server starting up. Pre-loading models...")
    for n in PRELOAD_N_GRAMS:
        get_model_for_n(n)
    print(f"Models for N={PRELOAD_N_GRAMS} loaded. Server is ready.")

@app.post("/api/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    API endpoint to get the next word prediction.
    """
    n = max(2, min(request.n, 5))
    model = get_model_for_n(n)
    if not model:
        return PredictResponse(prediction="")
    length = max(1, min(request.length, 500))
    prediction = generate_text(
        model,
        start_prompt=request.prompt,
        length=length,
        temperature=request.temperature,
    )
    return PredictResponse(prediction=prediction)

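# Example request (a sketch; assumes the server is running locally on the
# default port 8000):
#   curl -X POST http://localhost:8000/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "hello world", "n": 4, "length": 5, "temperature": 1.6}'
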
@app.get("/api")
async def api_docs():
    """
    API documentation page.
    """
    return FileResponse(os.path.join(UI_DIR, "api.html"))

# --- Static Files and Root ---
app.mount("/ui", StaticFiles(directory=UI_DIR), name="ui")


@app.get("/")
async def read_root():
    return FileResponse(os.path.join(UI_DIR, "index.html"))

def run():
    # Read port from environment variable, default to 8000
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        proxy_headers=True,
        forwarded_allow_ips="*",
    )

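# Usage (a sketch; assumes this file is run directly from the Kreatyw/ directory
# and that the llm module and the ui/ directory are present):
#   python api.py            # serve on the default port 8000
#   PORT=8080 python api.py  # override the port via the PORT environment variable
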
if __name__ == "__main__":
    run()