feat: ollama VRAM status + model loading/pinning on switch
- Show loaded models with VRAM usage bar (24GB 3090)
- On mode switch: unload old model, load+pin target model (keep_alive=-1m)
- Loading banner with spinner (polls faster at 2s while loading)
- Lab model changes also trigger model swap when in lab mode
- Manual load/unload API endpoints
This commit is contained in:
160
app.py
160
app.py
@@ -2,18 +2,21 @@
|
||||
"""
|
||||
Ollama GPU Switcher — Toggle OpenClaw agents between work mode (qwen3) and lab mode (GPU exclusive).
|
||||
No LLM involved. Reads/writes openclaw.json directly, then signals the gateway to restart.
|
||||
Also manages ollama model loading/pinning via the ollama API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import copy
|
||||
import threading
|
||||
from flask import Flask, jsonify, request, send_from_directory
|
||||
import requests as http_requests
|
||||
|
||||
app = Flask(__name__, static_folder="static")
|
||||
|
||||
CONFIG_PATH = os.environ.get("OPENCLAW_CONFIG", os.path.expanduser("~/.openclaw/openclaw.json"))
|
||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://ollama.glenwood.schapira.nyc:11434")
|
||||
|
||||
# Agents that use ollama and compete for GPU
|
||||
OLLAMA_AGENTS = ["rex", "maddy", "coder", "research"]
|
||||
@@ -21,6 +24,10 @@ OLLAMA_AGENTS = ["rex", "maddy", "coder", "research"]
|
||||
WORK_PRIMARY = "ollama/qwen3-128k:14b"
|
||||
LAB_PRIMARY = "groq/llama-3.3-70b-versatile"
|
||||
|
||||
# Model loading state (tracked in-process)
|
||||
_loading_state = {"model": None, "status": "idle"} # idle | loading | done | error
|
||||
_loading_lock = threading.Lock()
|
||||
|
||||
|
||||
def read_config():
|
||||
with open(CONFIG_PATH, "r") as f:
|
||||
@@ -39,7 +46,6 @@ def restart_gateway():
|
||||
subprocess.run(["openclaw", "gateway", "restart"], timeout=10, capture_output=True)
|
||||
return True
|
||||
except Exception:
|
||||
# Fallback: try SIGUSR1 to the gateway process
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "-f", "openclaw.*gateway"], capture_output=True, text=True)
|
||||
if result.stdout.strip():
|
||||
@@ -77,6 +83,82 @@ def detect_mode(config):
|
||||
return "mixed"
|
||||
|
||||
|
||||
def ollama_ps():
    """Return the models ollama currently holds resident in VRAM.

    Queries the ollama /api/ps endpoint and flattens each entry into a
    dict: name, VRAM footprint in GiB, parameter size, quantization level,
    family, context length, and pin expiry. Network/HTTP failures are
    reported via ok=False instead of being raised.
    """
    try:
        resp = http_requests.get(f"{OLLAMA_URL}/api/ps", timeout=5)
        resp.raise_for_status()
        loaded = []
        for entry in resp.json().get("models", []):
            details = entry.get("details", {})
            loaded.append({
                "name": entry.get("name", "unknown"),
                # size_vram is reported in bytes; convert to GiB for display
                "size_vram_gb": round(entry.get("size_vram", 0) / (1024 ** 3), 1),
                "parameter_size": details.get("parameter_size", ""),
                "quantization": details.get("quantization_level", ""),
                "family": details.get("family", ""),
                "context_length": entry.get("context_length", 0),
                "expires_at": entry.get("expires_at", ""),
            })
        return {"ok": True, "models": loaded}
    except Exception as exc:
        return {"ok": False, "models": [], "error": str(exc)}
|
||||
|
||||
|
||||
def ollama_load_model(model_name, keep_alive="-1m"):
    """Load *model_name* into VRAM and pin it there.

    keep_alive="-1m" (a negative duration) tells ollama to keep the model
    resident indefinitely. Progress is mirrored into the module-level
    _loading_state so /api/ollama can report it. Returns True on success,
    False on any failure (the error text is recorded in _loading_state).
    """
    global _loading_state
    with _loading_lock:
        _loading_state = {"model": model_name, "status": "loading"}

    # An empty-prompt /api/generate call is ollama's idiom for "load this
    # model now"; keep_alive controls how long it stays pinned afterwards.
    payload = {"model": model_name, "prompt": "", "keep_alive": keep_alive}
    try:
        resp = http_requests.post(
            f"{OLLAMA_URL}/api/generate",
            json=payload,
            timeout=300,  # cold loads of large models can take minutes
        )
        resp.raise_for_status()
    except Exception as exc:
        with _loading_lock:
            _loading_state = {"model": model_name, "status": "error", "error": str(exc)}
        return False

    with _loading_lock:
        _loading_state = {"model": model_name, "status": "done"}
    return True
|
||||
|
||||
|
||||
def ollama_unload_model(model_name):
    """Evict *model_name* from VRAM.

    keep_alive="0" on an empty-prompt generate call asks ollama to drop
    the model immediately. Returns True if ollama acknowledged, False on
    any network/HTTP failure (best-effort; no exception escapes).
    """
    payload = {"model": model_name, "prompt": "", "keep_alive": "0"}
    try:
        resp = http_requests.post(
            f"{OLLAMA_URL}/api/generate",
            json=payload,
            timeout=30,
        )
        resp.raise_for_status()
    except Exception:
        return False
    return True
|
||||
|
||||
|
||||
def load_model_async(model_name):
    """Kick off ollama_load_model in a daemon thread and return immediately.

    Callers poll /api/ollama (which reads _loading_state) for progress.
    """
    worker = threading.Thread(
        target=ollama_load_model,
        args=(model_name,),
        daemon=True,
    )
    worker.start()
|
||||
|
||||
|
||||
# --- Routes ---
|
||||
|
||||
@app.route("/")
def index():
    """Serve the single-page UI from the configured static folder."""
    # Equivalent to send_from_directory("static", ...) since the app was
    # created with static_folder="static".
    return app.send_static_file("index.html")
|
||||
@@ -104,7 +186,6 @@ def status():
|
||||
"model": lab.get("model", {}).get("primary", "unknown") if lab else "unknown",
|
||||
}
|
||||
|
||||
# Subagents default
|
||||
subagents_primary = (
|
||||
config.get("agents", {})
|
||||
.get("defaults", {})
|
||||
@@ -124,6 +205,16 @@ def status():
|
||||
return jsonify({"ok": False, "error": str(e)}), 500
|
||||
|
||||
|
||||
@app.route("/api/ollama")
def ollama_status():
    """Report ollama's loaded models plus the in-process load-progress state."""
    status = ollama_ps()
    # Snapshot the shared loading state under the lock so the response
    # never sees a half-written update from the loader thread.
    with _loading_lock:
        status["loading"] = dict(_loading_state)
    return jsonify(status)
|
||||
|
||||
|
||||
@app.route("/api/switch", methods=["POST"])
|
||||
def switch():
|
||||
try:
|
||||
@@ -132,13 +223,23 @@ def switch():
|
||||
|
||||
if target_mode == "lab":
|
||||
new_primary = LAB_PRIMARY
|
||||
target_ollama_model = None # lab model is managed separately
|
||||
elif target_mode == "work":
|
||||
new_primary = WORK_PRIMARY
|
||||
target_ollama_model = "qwen3-128k:14b"
|
||||
else:
|
||||
return jsonify({"ok": False, "error": f"Unknown mode: {target_mode}"}), 400
|
||||
|
||||
config = read_config()
|
||||
|
||||
# Determine which ollama model to load based on mode
|
||||
if target_mode == "lab":
|
||||
lab = find_agent(config, "lab")
|
||||
if lab:
|
||||
lab_model = lab.get("model", {}).get("primary", "")
|
||||
if "ollama/" in lab_model:
|
||||
target_ollama_model = lab_model.replace("ollama/", "")
|
||||
|
||||
# Patch each agent's primary model
|
||||
for agent_id in OLLAMA_AGENTS:
|
||||
agent = find_agent(config, agent_id)
|
||||
@@ -154,10 +255,21 @@ def switch():
|
||||
write_config(config)
|
||||
restarted = restart_gateway()
|
||||
|
||||
# Unload current models and load the target model
|
||||
if target_ollama_model:
|
||||
# First unload anything currently loaded
|
||||
ps = ollama_ps()
|
||||
for m in ps.get("models", []):
|
||||
if m["name"] != target_ollama_model:
|
||||
ollama_unload_model(m["name"])
|
||||
# Load and pin the target model async
|
||||
load_model_async(target_ollama_model)
|
||||
|
||||
return jsonify({
|
||||
"ok": True,
|
||||
"mode": target_mode,
|
||||
"restarted": restarted,
|
||||
"loading_model": target_ollama_model,
|
||||
})
|
||||
except Exception as e:
|
||||
return jsonify({"ok": False, "error": str(e)}), 500
|
||||
@@ -183,13 +295,53 @@ def set_lab_model():
|
||||
write_config(config)
|
||||
restarted = restart_gateway()
|
||||
|
||||
return jsonify({"ok": True, "model": model, "restarted": restarted})
|
||||
# If currently in lab mode, load the new model
|
||||
mode = detect_mode(config)
|
||||
ollama_model_name = None
|
||||
if mode == "lab" and "ollama/" in model:
|
||||
ollama_model_name = model.replace("ollama/", "")
|
||||
# Unload old models first
|
||||
ps = ollama_ps()
|
||||
for m in ps.get("models", []):
|
||||
if m["name"] != ollama_model_name:
|
||||
ollama_unload_model(m["name"])
|
||||
load_model_async(ollama_model_name)
|
||||
|
||||
return jsonify({
|
||||
"ok": True,
|
||||
"model": model,
|
||||
"restarted": restarted,
|
||||
"loading_model": ollama_model_name,
|
||||
})
|
||||
except Exception as e:
|
||||
return jsonify({"ok": False, "error": str(e)}), 500
|
||||
|
||||
|
||||
@app.route("/api/ollama/load", methods=["POST"])
def load_model():
    """Manually load/pin a model.

    Body: {"model": "<name>"}. The load runs in a background thread;
    poll /api/ollama for progress. Returns 400 when no model is given.
    """
    # silent=True: a malformed or non-JSON body becomes None instead of a
    # werkzeug HTML 400 page, so the client always gets a JSON response.
    data = request.get_json(silent=True) or {}
    model = data.get("model", "")
    if not model:
        return jsonify({"ok": False, "error": "No model specified"}), 400
    load_model_async(model)
    return jsonify({"ok": True, "loading": model})
|
||||
|
||||
|
||||
@app.route("/api/ollama/unload", methods=["POST"])
def unload_model():
    """Manually unload a model from VRAM.

    Body: {"model": "<name>"}. Blocks until ollama acknowledges (or the
    30s timeout inside ollama_unload_model); "ok" reflects success.
    Returns 400 when no model is given.
    """
    # silent=True: a malformed or non-JSON body becomes None instead of a
    # werkzeug HTML 400 page, so the client always gets a JSON response.
    data = request.get_json(silent=True) or {}
    model = data.get("model", "")
    if not model:
        return jsonify({"ok": False, "error": "No model specified"}), 400
    result = ollama_unload_model(model)
    return jsonify({"ok": result, "unloaded": model})
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: print a startup banner, then run the Flask dev server.
    port = int(os.environ.get("PORT", 8585))
    banner = (
        f"🔀 Ollama GPU Switcher running on http://0.0.0.0:{port}",
        f"📄 Config: {CONFIG_PATH}",
        f"🦙 Ollama: {OLLAMA_URL}",
    )
    for line in banner:
        print(line)
    app.run(host="0.0.0.0", port=port, debug=False)
|
||||
|
||||
Reference in New Issue
Block a user