SLM
api.py (new file, 92 lines added)
@@ -0,0 +1,92 @@
import os
import uvicorn
from fastapi import FastAPI, Body
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
import sys

# Import core LLM logic
from llm import load_or_train_model, generate_text, SOURCES_DIR

# --- Configuration ---
# Models to pre-load on startup
PRELOAD_N_GRAMS = [2, 3, 4, 5]
UI_DIR = "ui"

# --- Globals ---
# Cache for loaded models: {n: model}
MODEL_CACHE = {}

# --- Pydantic Models ---
class PredictRequest(BaseModel):
    prompt: str
    temperature: float = 0.7
    n: int = 3
    length: int = 5

class PredictResponse(BaseModel):
    prediction: str

# --- FastAPI App ---
app = FastAPI()

def get_model_for_n(n: int):
    """
    Retrieves the model for a specific N from cache, or loads/trains it.
    """
    global MODEL_CACHE
    if n in MODEL_CACHE:
        return MODEL_CACHE[n]

    print(f"Loading/Training model for N={n}...")
    model = load_or_train_model(SOURCES_DIR, n)
    MODEL_CACHE[n] = model
    return model

@app.on_event("startup")
def startup_event():
    """
    On server startup, pre-load models for all specified N-grams.
    """
    print("Server starting up. Pre-loading models...")
    for n in PRELOAD_N_GRAMS:
        get_model_for_n(n)
    print(f"Models for N={PRELOAD_N_GRAMS} loaded. Server is ready.")

@app.post("/api/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    API endpoint to get the next word prediction.
    """
    n = max(2, min(request.n, 5))
    model = get_model_for_n(n)

    if not model:
        return {"prediction": ""}

    length = max(1, min(request.length, 500))

    prediction = generate_text(
        model,
        start_prompt=request.prompt,
        length=length,
        temperature=request.temperature
    )

    return PredictResponse(prediction=prediction)

# --- Static Files and Root ---
app.mount("/ui", StaticFiles(directory=UI_DIR), name="ui")

@app.get("/")
async def read_root():
    return FileResponse(os.path.join(UI_DIR, "index.html"))

def run():
    # Read port from environment variable, default to 8000
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)

if __name__ == "__main__":
    run()
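For context (not part of the commit), a call to the /api/predict endpoint defined above could look like the minimal client sketch below. It assumes the server from api.py is running locally on the default port 8000 and that the third-party requests package is installed; the example prompt value is made up.

# Hypothetical client sketch for the /api/predict endpoint.
# Assumes the server from api.py is running on localhost:8000 (the default
# port) and that the third-party `requests` package is available.
import requests

payload = {
    "prompt": "once upon a",   # text to continue
    "temperature": 0.7,        # sampling temperature (PredictRequest default)
    "n": 3,                    # n-gram order; the server clamps it to [2, 5]
    "length": 5,               # words to generate; the server clamps it to [1, 500]
}

resp = requests.post("http://localhost:8000/api/predict", json=payload)
resp.raise_for_status()
print(resp.json()["prediction"])

Since run() reads the PORT environment variable before calling uvicorn, starting the server with a different port (for example PORT=8080) would require pointing the client at that port instead.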