#!/usr/bin/env python3
"""
Ollama GPU Switcher — Toggle OpenClaw agents between work mode (qwen3) and
lab mode (GPU exclusive). No LLM involved.

Reads/writes openclaw.json directly, then signals the gateway to restart.
Also manages ollama model loading/pinning via the ollama API.
"""
import json
import os
import signal
import subprocess
import tempfile
import threading

from flask import Flask, jsonify, request, send_from_directory
import requests as http_requests

app = Flask(__name__, static_folder="static")

CONFIG_PATH = os.environ.get("OPENCLAW_CONFIG", os.path.expanduser("~/.openclaw/openclaw.json"))
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://ollama.glenwood.schapira.nyc:11434")

# Agents that use ollama and compete for GPU
OLLAMA_AGENTS = ["rex", "maddy", "coder", "research"]
WORK_PRIMARY = "ollama/qwen3-128k:14b"
LAB_PRIMARY = "groq/llama-3.3-70b-versatile"

# Model loading state (tracked in-process, guarded by _loading_lock).
# status is one of: idle | loading | done | error
_loading_state = {"model": None, "status": "idle"}
_loading_lock = threading.Lock()


def read_config():
    """Read and parse the openclaw JSON config."""
    with open(CONFIG_PATH, "r") as f:
        return json.load(f)


def write_config(config):
    """Write the openclaw config atomically.

    Writes to a temp file in the same directory, then os.replace()s it over
    CONFIG_PATH so a crash mid-write can never leave a truncated config.
    """
    directory = os.path.dirname(CONFIG_PATH) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(config, f, indent=2)
            f.write("\n")
        os.replace(tmp_path, CONFIG_PATH)  # atomic on POSIX
    except Exception:
        # Best-effort cleanup of the temp file, then surface the real error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def restart_gateway():
    """Restart the openclaw gateway via CLI.

    Falls back to signalling a running gateway process with SIGUSR1 if the
    CLI invocation fails. Returns True on (apparent) success, False otherwise.
    This is deliberately best-effort: failure to restart is reported to the
    caller, never raised.
    """
    try:
        subprocess.run(["openclaw", "gateway", "restart"], timeout=10, capture_output=True)
        return True
    except Exception:
        # CLI missing or timed out — try to find the gateway and signal it.
        try:
            result = subprocess.run(
                ["pgrep", "-f", "openclaw.*gateway"],
                capture_output=True, text=True,
            )
            if result.stdout.strip():
                pid = int(result.stdout.strip().split("\n")[0])
                os.kill(pid, signal.SIGUSR1)
                return True
        except Exception:
            pass
    return False


def find_agent(config, agent_id):
    """Return the agent dict with the given id from config, or None."""
    for agent in config.get("agents", {}).get("list", []):
        if agent.get("id") == agent_id:
            return agent
    return None


def detect_mode(config):
    """Classify the current config as "work", "lab", or "mixed".

    "work" when every GPU-competing agent's primary model is an ollama model;
    "lab" when they are all groq models; "mixed" otherwise.
    """
    ollama_count = 0
    groq_count = 0
    for agent_id in OLLAMA_AGENTS:
        agent = find_agent(config, agent_id)
        if agent:
            primary = agent.get("model", {}).get("primary", "")
            # Prefix match, not substring: a model id like "groq/ollama-tuned"
            # must not be counted as an ollama model.
            if primary.startswith("ollama/"):
                ollama_count += 1
            elif primary.startswith("groq/"):
                groq_count += 1
    if ollama_count == len(OLLAMA_AGENTS):
        return "work"
    elif groq_count >= len(OLLAMA_AGENTS):
        return "lab"
    return "mixed"


def ollama_ps():
    """Get currently loaded models from ollama.

    Returns {"ok": bool, "models": [...]} — never raises; errors are reported
    in the payload so callers can render them.
    """
    try:
        r = http_requests.get(f"{OLLAMA_URL}/api/ps", timeout=5)
        r.raise_for_status()
        data = r.json()
        models = []
        for m in data.get("models", []):
            size_gb = m.get("size_vram", 0) / (1024 ** 3)
            models.append({
                "name": m.get("name", "unknown"),
                "size_vram_gb": round(size_gb, 1),
                "parameter_size": m.get("details", {}).get("parameter_size", ""),
                "quantization": m.get("details", {}).get("quantization_level", ""),
                "family": m.get("details", {}).get("family", ""),
                "context_length": m.get("context_length", 0),
                "expires_at": m.get("expires_at", ""),
            })
        return {"ok": True, "models": models}
    except Exception as e:
        return {"ok": False, "models": [], "error": str(e)}


def ollama_load_model(model_name, keep_alive="-1m"):
    """Load a model into VRAM and pin it. keep_alive=-1m means forever.

    Updates the shared _loading_state so /api/ollama can report progress.
    Returns True on success, False on failure (state carries the error).
    """
    global _loading_state
    with _loading_lock:
        _loading_state = {"model": model_name, "status": "loading"}
    try:
        # Use /api/generate with empty prompt to load & pin the model
        r = http_requests.post(
            f"{OLLAMA_URL}/api/generate",
            json={
                "model": model_name,
                "prompt": "",
                "keep_alive": keep_alive,
            },
            timeout=300,  # models can take a while to load
        )
        r.raise_for_status()
        with _loading_lock:
            _loading_state = {"model": model_name, "status": "done"}
        return True
    except Exception as e:
        with _loading_lock:
            _loading_state = {"model": model_name, "status": "error", "error": str(e)}
        return False


def ollama_unload_model(model_name):
    """Unload a model from VRAM (keep_alive=0). Returns True on success."""
    try:
        r = http_requests.post(
            f"{OLLAMA_URL}/api/generate",
            json={
                "model": model_name,
                "prompt": "",
                "keep_alive": "0",
            },
            timeout=30,
        )
        r.raise_for_status()
        return True
    except Exception:
        return False


def load_model_async(model_name):
    """Load model in a background daemon thread (non-blocking)."""
    t = threading.Thread(target=ollama_load_model, args=(model_name,), daemon=True)
    t.start()


# --- Routes ---

@app.route("/")
def index():
    return send_from_directory("static", "index.html")


@app.route("/api/status")
def status():
    """Report current mode, per-agent models, lab agent info, subagent default."""
    try:
        config = read_config()
        mode = detect_mode(config)
        agent_details = []
        for agent_id in OLLAMA_AGENTS:
            agent = find_agent(config, agent_id)
            if agent:
                agent_details.append({
                    "id": agent["id"],
                    "name": agent.get("name", agent["id"]),
                    "model": agent.get("model", {}).get("primary", "unknown"),
                })
        lab = find_agent(config, "lab")
        lab_info = {
            "name": lab.get("name", "Eric") if lab else "Eric",
            "model": lab.get("model", {}).get("primary", "unknown") if lab else "unknown",
        }
        subagents_primary = (
            config.get("agents", {})
            .get("defaults", {})
            .get("subagents", {})
            .get("model", {})
            .get("primary", "unknown")
        )
        return jsonify({
            "ok": True,
            "mode": mode,
            "lab": lab_info,
            "agents": agent_details,
            "subagentsPrimary": subagents_primary,
        })
    except Exception as e:
        return jsonify({"ok": False, "error": str(e)}), 500


@app.route("/api/ollama")
def ollama_status():
    """Get ollama loaded models + loading state."""
    ps = ollama_ps()
    with _loading_lock:
        loading = dict(_loading_state)  # snapshot under the lock
    ps["loading"] = loading
    return jsonify(ps)


@app.route("/api/switch", methods=["POST"])
def switch():
    """Switch all GPU-competing agents (and the subagents default) between modes.

    Body: {"mode": "work" | "lab"}. Patches the config, restarts the gateway,
    then swaps the loaded ollama model to match the new mode.
    """
    try:
        # silent=True: a missing/wrong content type yields None instead of a
        # framework-generated error page, keeping our JSON error contract.
        data = request.get_json(silent=True) or {}
        target_mode = data.get("mode", "work")
        if target_mode == "lab":
            new_primary = LAB_PRIMARY
            target_ollama_model = None  # lab model is managed separately
        elif target_mode == "work":
            new_primary = WORK_PRIMARY
            target_ollama_model = "qwen3-128k:14b"
        else:
            return jsonify({"ok": False, "error": f"Unknown mode: {target_mode}"}), 400

        config = read_config()

        # In lab mode, the model to pin comes from the lab agent's own config
        # (only if it is itself an ollama model).
        if target_mode == "lab":
            lab = find_agent(config, "lab")
            if lab:
                lab_model = lab.get("model", {}).get("primary", "")
                if lab_model.startswith("ollama/"):
                    target_ollama_model = lab_model[len("ollama/"):]

        # Patch each agent's primary model
        for agent_id in OLLAMA_AGENTS:
            agent = find_agent(config, agent_id)
            if agent:
                if "model" not in agent:
                    agent["model"] = {}
                agent["model"]["primary"] = new_primary

        # Patch subagents default
        config.setdefault("agents", {}).setdefault("defaults", {}).setdefault("subagents", {}).setdefault("model", {})
        config["agents"]["defaults"]["subagents"]["model"]["primary"] = new_primary

        write_config(config)
        restarted = restart_gateway()

        # Unload current models and load the target model
        if target_ollama_model:
            # First unload anything currently loaded (except the target)
            ps = ollama_ps()
            for m in ps.get("models", []):
                if m["name"] != target_ollama_model:
                    ollama_unload_model(m["name"])
            # Load and pin the target model async
            load_model_async(target_ollama_model)

        return jsonify({
            "ok": True,
            "mode": target_mode,
            "restarted": restarted,
            "loading_model": target_ollama_model,
        })
    except Exception as e:
        return jsonify({"ok": False, "error": str(e)}), 500


@app.route("/api/lab-model", methods=["POST"])
def set_lab_model():
    """Set the lab agent's primary model; if already in lab mode, hot-swap it."""
    try:
        data = request.get_json(silent=True) or {}
        model = data.get("model", "")
        if not model:
            return jsonify({"ok": False, "error": "No model specified"}), 400

        config = read_config()
        lab = find_agent(config, "lab")
        if not lab:
            return jsonify({"ok": False, "error": "Lab agent not found"}), 404
        if "model" not in lab:
            lab["model"] = {}
        lab["model"]["primary"] = model

        write_config(config)
        restarted = restart_gateway()

        # If currently in lab mode, load the new model
        mode = detect_mode(config)
        ollama_model_name = None
        if mode == "lab" and model.startswith("ollama/"):
            ollama_model_name = model[len("ollama/"):]
            # Unload old models first
            ps = ollama_ps()
            for m in ps.get("models", []):
                if m["name"] != ollama_model_name:
                    ollama_unload_model(m["name"])
            load_model_async(ollama_model_name)

        return jsonify({
            "ok": True,
            "model": model,
            "restarted": restarted,
            "loading_model": ollama_model_name,
        })
    except Exception as e:
        return jsonify({"ok": False, "error": str(e)}), 500


@app.route("/api/ollama/load", methods=["POST"])
def load_model():
    """Manually load/pin a model."""
    data = request.get_json(silent=True) or {}
    model = data.get("model", "")
    if not model:
        return jsonify({"ok": False, "error": "No model specified"}), 400
    load_model_async(model)
    return jsonify({"ok": True, "loading": model})


@app.route("/api/ollama/unload", methods=["POST"])
def unload_model():
    """Manually unload a model."""
    data = request.get_json(silent=True) or {}
    model = data.get("model", "")
    if not model:
        return jsonify({"ok": False, "error": "No model specified"}), 400
    result = ollama_unload_model(model)
    return jsonify({"ok": result, "unloaded": model})


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8585))
    print(f"🔀 Ollama GPU Switcher running on http://0.0.0.0:{port}")
    print(f"📄 Config: {CONFIG_PATH}")
    print(f"🦙 Ollama: {OLLAMA_URL}")
    app.run(host="0.0.0.0", port=port, debug=False)