104 lines
2.2 KiB
Python
104 lines
2.2 KiB
Python
import os
|
|
import sys
|
|
|
|
import uvicorn
|
|
from fastapi import Body, FastAPI
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
# Import core LLM logic
|
|
from llm import SOURCES_DIR, generate_text, load_or_train_model
|
|
|
|
# --- Configuration ---
# Models to pre-load on startup
PRELOAD_N_GRAMS = [2, 3, 4, 5]
# Directory holding the static front-end, served at /ui and /.
UI_DIR = "ui"

# --- Globals ---
# Cache for loaded models: {n: model}
# Populated eagerly at startup and lazily by get_model_for_n on cache misses.
MODEL_CACHE = {}
|
|
|
|
|
|
# --- Pydantic Models ---
|
|
class PredictRequest(BaseModel):
    """Request payload for the /api/predict endpoint."""

    # Seed text the generator continues from.
    prompt: str
    # Sampling temperature, passed through to generate_text unclamped.
    temperature: float = 0.7
    # N-gram order to use; the endpoint clamps this to the range 2..5.
    n: int = 3
    # Number of tokens to generate; the endpoint clamps this to 1..500.
    length: int = 5
|
|
|
|
|
|
class PredictResponse(BaseModel):
    """Response payload for /api/predict: the generated text (may be empty)."""

    # Generated continuation; empty string when no model is available.
    prediction: str
|
|
|
|
|
|
# --- FastAPI App ---
|
|
# Single application instance; routes and static mounts below attach to it.
app = FastAPI()
|
|
|
|
|
|
def get_model_for_n(n: int):
    """Return the model for a given N-gram order, training it on first use.

    Looks the model up in MODEL_CACHE; on a miss, loads or trains it via
    load_or_train_model and stores the result for subsequent calls.
    (No `global` statement is needed — the cache dict is mutated, never
    rebound.)
    """
    try:
        # Fast path: already loaded.
        return MODEL_CACHE[n]
    except KeyError:
        pass

    # Slow path: first request for this N — load from disk or train.
    print(f"Loading/Training model for N={n}...")
    MODEL_CACHE[n] = load_or_train_model(SOURCES_DIR, n)
    return MODEL_CACHE[n]
|
|
|
|
|
|
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers — consider migrating when upgrading the dependency.
@app.on_event("startup")
def startup_event():
    """Warm the model cache for every configured N-gram size before serving."""
    print("Server starting up. Pre-loading models...")
    for size in PRELOAD_N_GRAMS:
        get_model_for_n(size)
    print(f"Models for N={PRELOAD_N_GRAMS} loaded. Server is ready.")
|
|
|
|
|
|
@app.post("/api/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Return a generated continuation of the request prompt.

    Declared as a sync (non-async) endpoint so FastAPI runs it in its
    threadpool: get_model_for_n may synchronously load or train a model on a
    cache miss and generate_text is CPU-bound, and the original `async def`
    awaited nothing — every request would have blocked the event loop for
    all other clients while generating.

    Returns an empty prediction when no model is available for the
    requested N.
    """
    # Clamp N to the range of supported/pre-loaded models (2..5).
    n = max(2, min(request.n, 5))
    model = get_model_for_n(n)

    if not model:
        # No model could be loaded/trained — respond with an empty prediction
        # rather than erroring (response_model coerces this dict).
        return {"prediction": ""}

    # Cap the generated length to keep responses bounded (1..500 tokens).
    length = max(1, min(request.length, 500))

    # NOTE(review): temperature is passed through unclamped, unlike n and
    # length — confirm generate_text tolerates zero/negative values.
    prediction = generate_text(
        model,
        start_prompt=request.prompt,
        length=length,
        temperature=request.temperature,
    )

    return PredictResponse(prediction=prediction)
|
|
|
|
|
|
# --- Static Files and Root ---
|
|
# Serve the front-end assets out of UI_DIR under the /ui path.
app.mount("/ui", StaticFiles(directory=UI_DIR), name="ui")
|
|
|
|
|
|
@app.get("/")
async def read_root():
    """Serve the UI's entry page (ui/index.html) at the site root."""
    index_path = os.path.join(UI_DIR, "index.html")
    return FileResponse(index_path)
|
|
|
|
|
|
def run():
    """Start the uvicorn server on all interfaces.

    The listen port is taken from the PORT environment variable,
    defaulting to 8000 when unset.
    """
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
|
|
|
# Script entry point: start the server only when executed directly,
# not when imported (e.g. by an ASGI server pointing at `app`).
if __name__ == "__main__":
    run()
|