SLM
This commit is contained in:
188
llm.py
Normal file
188
llm.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import re
|
||||
import hashlib
|
||||
import pickle
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
# Directory scanned for .txt training corpora.
SOURCES_DIR = "sources"
# Directory holding the pickled model caches plus the sources checksum.
CACHE_DIR = "models"
N_GRAM = 3  # Default N-gram for standalone script use
|
||||
|
||||
def get_dir_checksum(directory):
    """
    Return an MD5 hex digest covering every .txt file in *directory*.

    File names and file contents are folded into the digest in sorted-name
    order, so a rename, addition, removal, or content change all produce a
    different result.  Returns None when the directory does not exist.
    """
    if not os.path.exists(directory):
        return None

    digest = hashlib.md5()
    txt_names = sorted(name for name in os.listdir(directory) if name.endswith('.txt'))

    for name in txt_names:
        # Fold the name in first so renames are detected even when
        # contents are identical.
        digest.update(name.encode('utf-8'))
        with open(os.path.join(directory, name), "rb") as fh:
            while True:
                chunk = fh.read(4096)
                if not chunk:
                    break
                digest.update(chunk)

    return digest.hexdigest()
|
||||
|
||||
def train_model(sources_dir, n):
    """
    Train an n-gram model from every .txt file in *sources_dir*.

    Each file is read as UTF-8, wiki-style ``[[...]]`` markup is stripped,
    and the text is split on whitespace.  For every run of *n* consecutive
    words, the first ``n - 1`` words form the context and the final word is
    counted as a continuation of that context.

    Returns:
        defaultdict(Counter) mapping (n-1)-word context tuples to Counters
        of next-word frequencies.  Empty when no usable sources are found.
    """
    print(f"Training new {n}-gram model from sources...")
    model = defaultdict(Counter)

    files = [f for f in os.listdir(sources_dir) if f.endswith(".txt")]
    if not files:
        print("No source files found!")
        return model

    context_size = n - 1
    for filename in files:
        filepath = os.path.join(sources_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                text = f.read()
            # BUG FIX: the brackets must be escaped.  The previous pattern
            # r'[[.*?]]' was parsed as the character class [[.*?] followed
            # by a literal ']', so it never removed [[...]] markup.
            text = re.sub(r'\[\[.*?\]\]', '', text)
            words = text.split()

            # Too short to yield even one full n-gram.
            if len(words) < n:
                continue

            for i in range(len(words) - context_size):
                context = tuple(words[i : i + context_size])
                next_word = words[i + context_size]
                model[context][next_word] += 1

        except Exception as e:
            # BUG FIX: report the offending file instead of "(unknown)".
            print(f"Error processing {filename}: {e}")

    return model
|
||||
|
||||
def load_or_train_model(sources_dir, n):
    """
    Return the n-gram model for *sources_dir*, using the on-disk cache.

    One pickled model per N lives in CACHE_DIR alongside a single shared
    checksum of the sources.  When the checksum still matches, the cached
    model for this N is loaded; otherwise every cached model is discarded
    (they were all built from the now-stale sources), and this N is
    retrained, saved, and the checksum refreshed.
    """
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(CACHE_DIR, exist_ok=True)

    cache_file = os.path.join(CACHE_DIR, f"model_n{n}.pkl")
    checksum_file = os.path.join(CACHE_DIR, "checksum.txt")  # one checksum for all N

    current_checksum = get_dir_checksum(sources_dir)

    # Check if a model for this N exists and if the checksum matches.
    if os.path.exists(cache_file) and os.path.exists(checksum_file):
        with open(checksum_file, 'r') as f:
            saved_checksum = f.read()

        if saved_checksum == current_checksum:
            print(f"Sources unchanged. Loading model N={n} from {cache_file}...")
            with open(cache_file, 'rb') as f:
                # NOTE: pickle is only safe because we wrote this file
                # ourselves; never load pickles from untrusted sources.
                return pickle.load(f)
        else:
            print("Sources changed. Global retrain needed. Deleting old models.")
            for item in os.listdir(CACHE_DIR):
                os.remove(os.path.join(CACHE_DIR, item))

    print(f"No valid cache found for N={n}. Training...")
    model = train_model(sources_dir, n)

    print(f"Saving model to {cache_file}...")
    with open(cache_file, 'wb') as f:
        pickle.dump(model, f)
    # Refresh the shared checksum only after a successful train + save.
    with open(checksum_file, 'w') as f:
        f.write(current_checksum or "")

    return model
|
||||
|
||||
def generate_text(model, start_prompt, length=100, temperature=1.0):
    """
    Generate *length* words from an n-gram *model*.

    The starting context is chosen in order of preference:
      1. the last n-1 words of *start_prompt*, when that context is known;
      2. a random known context beginning with the prompt's last word;
      3. a uniformly random context from the model.

    *temperature* reshapes the next-word distribution via weights
    ``count ** (1 / temperature)``: values < 1 sharpen it, values > 1
    flatten it.  Returns the generated words joined by single spaces
    ("" for an empty model); *start_prompt* itself is not included.
    """
    if not model:
        return ""

    try:
        # Infer the context size (n-1) from any key of the model.
        context_size = len(next(iter(model)))
    except StopIteration:
        return ""  # Model is empty

    start_words = start_prompt.split()
    current_context = None

    # Preference 1: exact match on the prompt's trailing context.
    if len(start_words) >= context_size:
        potential_context = tuple(start_words[-context_size:])
        if potential_context in model:
            current_context = potential_context

    # Preference 2: any context that starts with the prompt's last word.
    if current_context is None and start_words:
        last_word = start_words[-1]
        candidates = [k for k in model.keys() if k[0] == last_word]
        if candidates:
            current_context = random.choice(candidates)

    # Preference 3: fall back to a uniformly random context.
    if current_context is None:
        current_context = random.choice(list(model.keys()))

    generated_words = []

    for _ in range(length):
        # Dead end: re-seed from a random context.  The membership test
        # also avoids creating empty entries when model is a defaultdict.
        if current_context not in model or not model[current_context]:
            current_context = random.choice(list(model.keys()))

        possible_next = list(model[current_context].keys())
        counts = list(model[current_context].values())

        try:
            if temperature == 1.0:
                weights = counts
            else:
                weights = [c ** (1.0 / temperature) for c in counts]
            next_word = random.choices(possible_next, weights=weights, k=1)[0]
        except (ValueError, IndexError, ZeroDivisionError, OverflowError):
            # Fallback when the weights are invalid — including the
            # previously-uncaught temperature == 0 (ZeroDivisionError) and
            # extreme exponents (OverflowError) — or no word is possible.
            current_context = random.choice(list(model.keys()))
            next_word = current_context[0]

        generated_words.append(next_word)
        # Slide the context window forward by one word.
        current_context = current_context[1:] + (next_word,)

    return " ".join(generated_words)
|
||||
|
||||
def main():
    """CLI entry point: llm.py [start_prompt] [length] [temperature]."""
    # Refuse to run without a corpus directory.
    if not os.path.isdir(SOURCES_DIR):
        print(f"Error: Directory '{SOURCES_DIR}' not found.")
        sys.exit(1)

    model = load_or_train_model(SOURCES_DIR, N_GRAM)
    print(f"Model ready. (N={N_GRAM}, Keys={len(model)})")

    # Generation defaults; positional CLI arguments override them.
    length = 100
    temperature = 1.0

    args = sys.argv[1:]
    if args:
        start_prompt = args[0]
        if len(args) >= 2:
            length = int(args[1])
        if len(args) >= 3:
            temperature = float(args[2])
    else:
        # No prompt given: seed from a random known context.
        start_prompt = " ".join(random.choice(list(model.keys())))

    print(f"\n--- Generating (Start: '{start_prompt}', Temp: {temperature}) ---\n")
    output = start_prompt + " " + generate_text(model, start_prompt, length, temperature)
    print(output)
    print("\n-------------------------------------------------------------")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user