Apply black and isort auto formatting

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>
This commit is contained in:
Marek Pikuła
2023-04-03 12:11:45 +00:00
parent 4919147483
commit 358c0086af
4 changed files with 1625 additions and 859 deletions

View File

@@ -1,159 +1,189 @@
# pylint: disable=wrong-import-order import json
import logging
import os
from datetime import date, timedelta
import requests, json, os, logging, yaml import requests
import yaml
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from datetime import timedelta, date from dateutil import parser
from dateutil import parser from flask import Flask
from flask import Flask
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
################################################################## ##################################################################
# Functions related to HEADSCALE and API KEYS # Functions related to HEADSCALE and API KEYS
################################################################## ##################################################################
def get_url(inpage=False): def get_url(inpage=False):
if not inpage: if not inpage:
return os.environ['HS_SERVER'] return os.environ["HS_SERVER"]
config_file = "" config_file = ""
try: try:
config_file = open("/etc/headscale/config.yml", "r") config_file = open("/etc/headscale/config.yml", "r")
app.logger.info("Opening /etc/headscale/config.yml") app.logger.info("Opening /etc/headscale/config.yml")
except: except:
config_file = open("/etc/headscale/config.yaml", "r") config_file = open("/etc/headscale/config.yaml", "r")
app.logger.info("Opening /etc/headscale/config.yaml") app.logger.info("Opening /etc/headscale/config.yaml")
config_yaml = yaml.safe_load(config_file) config_yaml = yaml.safe_load(config_file)
if "server_url" in config_yaml: if "server_url" in config_yaml:
return str(config_yaml["server_url"]) return str(config_yaml["server_url"])
app.logger.warning("Failed to find server_url in the config. Falling back to ENV variable") app.logger.warning(
return os.environ['HS_SERVER'] "Failed to find server_url in the config. Falling back to ENV variable"
)
return os.environ["HS_SERVER"]
def set_api_key(api_key): def set_api_key(api_key):
# User-set encryption key # User-set encryption key
encryption_key = os.environ['KEY'] encryption_key = os.environ["KEY"]
# Key file on the filesystem for persistent storage # Key file on the filesystem for persistent storage
key_file = open("/data/key.txt", "wb+") key_file = open("/data/key.txt", "wb+")
# Preparing the Fernet class with the key # Preparing the Fernet class with the key
fernet = Fernet(encryption_key) fernet = Fernet(encryption_key)
# Encrypting the key # Encrypting the key
encrypted_key = fernet.encrypt(api_key.encode()) encrypted_key = fernet.encrypt(api_key.encode())
# Return true if the file wrote correctly # Return true if the file wrote correctly
return True if key_file.write(encrypted_key) else False return True if key_file.write(encrypted_key) else False
def get_api_key(): def get_api_key():
if not os.path.exists("/data/key.txt"): return False if not os.path.exists("/data/key.txt"):
return False
# User-set encryption key # User-set encryption key
encryption_key = os.environ['KEY'] encryption_key = os.environ["KEY"]
# Key file on the filesystem for persistent storage # Key file on the filesystem for persistent storage
key_file = open("/data/key.txt", "rb+") key_file = open("/data/key.txt", "rb+")
# The encrypted key read from the file # The encrypted key read from the file
enc_api_key = key_file.read() enc_api_key = key_file.read()
if enc_api_key == b'': return "NULL" if enc_api_key == b"":
return "NULL"
# Preparing the Fernet class with the key # Preparing the Fernet class with the key
fernet = Fernet(encryption_key) fernet = Fernet(encryption_key)
# Decrypting the key # Decrypting the key
decrypted_key = fernet.decrypt(enc_api_key).decode() decrypted_key = fernet.decrypt(enc_api_key).decode()
return decrypted_key return decrypted_key
def test_api_key(url, api_key): def test_api_key(url, api_key):
response = requests.get( response = requests.get(
str(url)+"/api/v1/apikey", str(url) + "/api/v1/apikey",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.status_code return response.status_code
# Expires an API key # Expires an API key
def expire_key(url, api_key): def expire_key(url, api_key):
payload = {'prefix':str(api_key[0:10])} payload = {"prefix": str(api_key[0:10])}
json_payload=json.dumps(payload) json_payload = json.dumps(payload)
app.logger.debug("Sending the payload '"+str(json_payload)+"' to the headscale server") app.logger.debug(
"Sending the payload '" + str(json_payload) + "' to the headscale server"
)
response = requests.post( response = requests.post(
str(url)+"/api/v1/apikey/expire", str(url) + "/api/v1/apikey/expire",
data=json_payload, data=json_payload,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.status_code return response.status_code
# Checks if the key needs to be renewed # Checks if the key needs to be renewed
# If it does, renews the key, then expires the old key # If it does, renews the key, then expires the old key
def renew_api_key(url, api_key): def renew_api_key(url, api_key):
# 0 = Key has been updated or key is not in need of an update # 0 = Key has been updated or key is not in need of an update
# 1 = Key has failed validity check or has failed to write the API key # 1 = Key has failed validity check or has failed to write the API key
# Check when the key expires and compare it to todays date: # Check when the key expires and compare it to todays date:
key_info = get_api_key_info(url, api_key) key_info = get_api_key_info(url, api_key)
expiration_time = key_info["expiration"] expiration_time = key_info["expiration"]
today_date = date.today() today_date = date.today()
expire = parser.parse(expiration_time) expire = parser.parse(expiration_time)
expire_fmt = str(expire.year) + "-" + str(expire.month).zfill(2) + "-" + str(expire.day).zfill(2) expire_fmt = (
expire_date = date.fromisoformat(expire_fmt) str(expire.year)
delta = expire_date - today_date + "-"
tmp = today_date + timedelta(days=90) + str(expire.month).zfill(2)
new_expiration_date = str(tmp)+"T00:00:00.000000Z" + "-"
+ str(expire.day).zfill(2)
)
expire_date = date.fromisoformat(expire_fmt)
delta = expire_date - today_date
tmp = today_date + timedelta(days=90)
new_expiration_date = str(tmp) + "T00:00:00.000000Z"
# If the delta is less than 5 days, renew the key: # If the delta is less than 5 days, renew the key:
if delta < timedelta(days=5): if delta < timedelta(days=5):
app.logger.warning("Key is about to expire. Delta is "+str(delta)) app.logger.warning("Key is about to expire. Delta is " + str(delta))
payload = {'expiration':str(new_expiration_date)} payload = {"expiration": str(new_expiration_date)}
json_payload=json.dumps(payload) json_payload = json.dumps(payload)
app.logger.debug("Sending the payload '"+str(json_payload)+"' to the headscale server") app.logger.debug(
"Sending the payload '" + str(json_payload) + "' to the headscale server"
)
response = requests.post( response = requests.post(
str(url)+"/api/v1/apikey", str(url) + "/api/v1/apikey",
data=json_payload, data=json_payload,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
new_key = response.json() new_key = response.json()
app.logger.debug("JSON: "+json.dumps(new_key)) app.logger.debug("JSON: " + json.dumps(new_key))
app.logger.debug("New Key is: "+new_key["apiKey"]) app.logger.debug("New Key is: " + new_key["apiKey"])
api_key_test = test_api_key(url, new_key["apiKey"]) api_key_test = test_api_key(url, new_key["apiKey"])
app.logger.debug("Testing the key: "+str(api_key_test)) app.logger.debug("Testing the key: " + str(api_key_test))
# Test if the new key works: # Test if the new key works:
if api_key_test == 200: if api_key_test == 200:
app.logger.info("The new key is valid and we are writing it to the file") app.logger.info("The new key is valid and we are writing it to the file")
if not set_api_key(new_key["apiKey"]): if not set_api_key(new_key["apiKey"]):
app.logger.error("We failed writing the new key!") app.logger.error("We failed writing the new key!")
return False # Key write failed return False # Key write failed
app.logger.info("Key validated and written. Moving to expire the key.") app.logger.info("Key validated and written. Moving to expire the key.")
expire_key(url, api_key) expire_key(url, api_key)
return True # Key updated and validated return True # Key updated and validated
else: else:
app.logger.error("Testing the API key failed.") app.logger.error("Testing the API key failed.")
return False # The API Key test failed return False # The API Key test failed
else: return True # No work is required else:
return True # No work is required
# Gets information about the current API key # Gets information about the current API key
def get_api_key_info(url, api_key): def get_api_key_info(url, api_key):
app.logger.info("Getting API key information") app.logger.info("Getting API key information")
response = requests.get( response = requests.get(
str(url)+"/api/v1/apikey", str(url) + "/api/v1/apikey",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
json_response = response.json() json_response = response.json()
# Find the current key in the array: # Find the current key in the array:
key_prefix = str(api_key[0:10]) key_prefix = str(api_key[0:10])
app.logger.info("Looking for valid API Key...") app.logger.info("Looking for valid API Key...")
for key in json_response["apiKeys"]: for key in json_response["apiKeys"]:
@@ -163,19 +193,25 @@ def get_api_key_info(url, api_key):
app.logger.error("Could not find a valid key in Headscale. Need a new API key.") app.logger.error("Could not find a valid key in Headscale. Need a new API key.")
return "Key not found" return "Key not found"
################################################################## ##################################################################
# Functions related to MACHINES # Functions related to MACHINES
################################################################## ##################################################################
# register a new machine # register a new machine
def register_machine(url, api_key, machine_key, user): def register_machine(url, api_key, machine_key, user):
app.logger.info("Registering machine %s to user %s", str(machine_key), str(user)) app.logger.info("Registering machine %s to user %s", str(machine_key), str(user))
response = requests.post( response = requests.post(
str(url)+"/api/v1/machine/register?user="+str(user)+"&key="+str(machine_key), str(url)
+ "/api/v1/machine/register?user="
+ str(user)
+ "&key="
+ str(machine_key),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
@@ -184,81 +220,86 @@ def register_machine(url, api_key, machine_key, user):
def set_machine_tags(url, api_key, machine_id, tags_list): def set_machine_tags(url, api_key, machine_id, tags_list):
app.logger.info("Setting machine_id %s tag %s", str(machine_id), str(tags_list)) app.logger.info("Setting machine_id %s tag %s", str(machine_id), str(tags_list))
response = requests.post( response = requests.post(
str(url)+"/api/v1/machine/"+str(machine_id)+"/tags", str(url) + "/api/v1/machine/" + str(machine_id) + "/tags",
data=tags_list, data=tags_list,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Moves machine_id to user "new_user" # Moves machine_id to user "new_user"
def move_user(url, api_key, machine_id, new_user): def move_user(url, api_key, machine_id, new_user):
app.logger.info("Moving machine_id %s to user %s", str(machine_id), str(new_user)) app.logger.info("Moving machine_id %s to user %s", str(machine_id), str(new_user))
response = requests.post( response = requests.post(
str(url)+"/api/v1/machine/"+str(machine_id)+"/user?user="+str(new_user), str(url) + "/api/v1/machine/" + str(machine_id) + "/user?user=" + str(new_user),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
def update_route(url, api_key, route_id, current_state): def update_route(url, api_key, route_id, current_state):
action = "disable" if current_state == "True" else "enable" action = "disable" if current_state == "True" else "enable"
app.logger.info("Updating Route %s: Action: %s", str(route_id), str(action)) app.logger.info("Updating Route %s: Action: %s", str(route_id), str(action))
# Debug # Debug
app.logger.debug("URL: "+str(url)) app.logger.debug("URL: " + str(url))
app.logger.debug("Route ID: "+str(route_id)) app.logger.debug("Route ID: " + str(route_id))
app.logger.debug("Current State: "+str(current_state)) app.logger.debug("Current State: " + str(current_state))
app.logger.debug("Action to take: "+str(action)) app.logger.debug("Action to take: " + str(action))
response = requests.post( response = requests.post(
str(url)+"/api/v1/routes/"+str(route_id)+"/"+str(action), str(url) + "/api/v1/routes/" + str(route_id) + "/" + str(action),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Get all machines on the Headscale network # Get all machines on the Headscale network
def get_machines(url, api_key): def get_machines(url, api_key):
app.logger.info("Getting machine information") app.logger.info("Getting machine information")
response = requests.get( response = requests.get(
str(url)+"/api/v1/machine", str(url) + "/api/v1/machine",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Get machine with "machine_id" on the Headscale network # Get machine with "machine_id" on the Headscale network
def get_machine_info(url, api_key, machine_id): def get_machine_info(url, api_key, machine_id):
app.logger.info("Getting information for machine ID %s", str(machine_id)) app.logger.info("Getting information for machine ID %s", str(machine_id))
response = requests.get( response = requests.get(
str(url)+"/api/v1/machine/"+str(machine_id), str(url) + "/api/v1/machine/" + str(machine_id),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Delete a machine from Headscale # Delete a machine from Headscale
def delete_machine(url, api_key, machine_id): def delete_machine(url, api_key, machine_id):
app.logger.info("Deleting machine %s", str(machine_id)) app.logger.info("Deleting machine %s", str(machine_id))
response = requests.delete( response = requests.delete(
str(url)+"/api/v1/machine/"+str(machine_id), str(url) + "/api/v1/machine/" + str(machine_id),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -267,15 +308,16 @@ def delete_machine(url, api_key, machine_id):
app.logger.error("Deleting machine failed! %s", str(response.json())) app.logger.error("Deleting machine failed! %s", str(response.json()))
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Rename "machine_id" with name "new_name" # Rename "machine_id" with name "new_name"
def rename_machine(url, api_key, machine_id, new_name): def rename_machine(url, api_key, machine_id, new_name):
app.logger.info("Renaming machine %s", str(machine_id)) app.logger.info("Renaming machine %s", str(machine_id))
response = requests.post( response = requests.post(
str(url)+"/api/v1/machine/"+str(machine_id)+"/rename/"+str(new_name), str(url) + "/api/v1/machine/" + str(machine_id) + "/rename/" + str(new_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -284,15 +326,16 @@ def rename_machine(url, api_key, machine_id, new_name):
app.logger.error("Machine rename failed! %s", str(response.json())) app.logger.error("Machine rename failed! %s", str(response.json()))
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Gets routes for the passed machine_id # Gets routes for the passed machine_id
def get_machine_routes(url, api_key, machine_id): def get_machine_routes(url, api_key, machine_id):
app.logger.info("Getting routes for machine %s", str(machine_id)) app.logger.info("Getting routes for machine %s", str(machine_id))
response = requests.get( response = requests.get(
str(url)+"/api/v1/machine/"+str(machine_id)+"/routes", str(url) + "/api/v1/machine/" + str(machine_id) + "/routes",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
if response.status_code == 200: if response.status_code == 200:
app.logger.info("Routes obtained") app.logger.info("Routes obtained")
@@ -300,42 +343,47 @@ def get_machine_routes(url, api_key, machine_id):
app.logger.error("Failed to get routes: %s", str(response.json())) app.logger.error("Failed to get routes: %s", str(response.json()))
return response.json() return response.json()
# Gets routes for the entire tailnet # Gets routes for the entire tailnet
def get_routes(url, api_key): def get_routes(url, api_key):
app.logger.info("Getting routes") app.logger.info("Getting routes")
response = requests.get( response = requests.get(
str(url)+"/api/v1/routes", str(url) + "/api/v1/routes",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
################################################################## ##################################################################
# Functions related to USERS # Functions related to USERS
################################################################## ##################################################################
# Get all users in use # Get all users in use
def get_users(url, api_key): def get_users(url, api_key):
app.logger.info("Getting Users") app.logger.info("Getting Users")
response = requests.get( response = requests.get(
str(url)+"/api/v1/user", str(url) + "/api/v1/user",
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Rename "old_name" with name "new_name" # Rename "old_name" with name "new_name"
def rename_user(url, api_key, old_name, new_name): def rename_user(url, api_key, old_name, new_name):
app.logger.info("Renaming user %s to %s.", str(old_name), str(new_name)) app.logger.info("Renaming user %s to %s.", str(old_name), str(new_name))
response = requests.post( response = requests.post(
str(url)+"/api/v1/user/"+str(old_name)+"/rename/"+str(new_name), str(url) + "/api/v1/user/" + str(old_name) + "/rename/" + str(new_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -344,15 +392,16 @@ def rename_user(url, api_key, old_name, new_name):
app.logger.error("Renaming User failed!") app.logger.error("Renaming User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Delete a user from Headscale # Delete a user from Headscale
def delete_user(url, api_key, user_name): def delete_user(url, api_key, user_name):
app.logger.info("Deleting a User: %s", str(user_name)) app.logger.info("Deleting a User: %s", str(user_name))
response = requests.delete( response = requests.delete(
str(url)+"/api/v1/user/"+str(user_name), str(url) + "/api/v1/user/" + str(user_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -361,17 +410,18 @@ def delete_user(url, api_key, user_name):
app.logger.error("Deleting User failed!") app.logger.error("Deleting User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Add a user from Headscale # Add a user from Headscale
def add_user(url, api_key, data): def add_user(url, api_key, data):
app.logger.info("Adding user: %s", str(data)) app.logger.info("Adding user: %s", str(data))
response = requests.post( response = requests.post(
str(url)+"/api/v1/user", str(url) + "/api/v1/user",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -380,34 +430,37 @@ def add_user(url, api_key, data):
app.logger.error("Adding User failed!") app.logger.error("Adding User failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
################################################################## ##################################################################
# Functions related to PREAUTH KEYS in USERS # Functions related to PREAUTH KEYS in USERS
################################################################## ##################################################################
# Get all PreAuth keys associated with a user "user_name" # Get all PreAuth keys associated with a user "user_name"
def get_preauth_keys(url, api_key, user_name): def get_preauth_keys(url, api_key, user_name):
app.logger.info("Getting PreAuth Keys in User %s", str(user_name)) app.logger.info("Getting PreAuth Keys in User %s", str(user_name))
response = requests.get( response = requests.get(
str(url)+"/api/v1/preauthkey?user="+str(user_name), str(url) + "/api/v1/preauthkey?user=" + str(user_name),
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
return response.json() return response.json()
# Add a preauth key to the user "user_name" given the booleans "ephemeral"
# Add a preauth key to the user "user_name" given the booleans "ephemeral"
# and "reusable" with the expiration date "date" contained in the JSON payload "data" # and "reusable" with the expiration date "date" contained in the JSON payload "data"
def add_preauth_key(url, api_key, data): def add_preauth_key(url, api_key, data):
app.logger.info("Adding PreAuth Key: %s", str(data)) app.logger.info("Adding PreAuth Key: %s", str(data))
response = requests.post( response = requests.post(
str(url)+"/api/v1/preauthkey", str(url) + "/api/v1/preauthkey",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
if response.status_code == 200: if response.status_code == 200:
@@ -416,19 +469,20 @@ def add_preauth_key(url, api_key, data):
app.logger.error("Adding PreAuth Key failed!") app.logger.error("Adding PreAuth Key failed!")
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}
# Expire a pre-auth key. data is {"user": "string", "key": "string"} # Expire a pre-auth key. data is {"user": "string", "key": "string"}
def expire_preauth_key(url, api_key, data): def expire_preauth_key(url, api_key, data):
app.logger.info("Expiring PreAuth Key...") app.logger.info("Expiring PreAuth Key...")
response = requests.post( response = requests.post(
str(url)+"/api/v1/preauthkey/expire", str(url) + "/api/v1/preauthkey/expire",
data=data, data=data,
headers={ headers={
'Accept': 'application/json', "Accept": "application/json",
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': 'Bearer '+str(api_key) "Authorization": "Bearer " + str(api_key),
} },
) )
status = "True" if response.status_code == 200 else "False" status = "True" if response.status_code == 200 else "False"
app.logger.debug("expire_preauth_key - Return: "+str(response.json())) app.logger.debug("expire_preauth_key - Return: " + str(response.json()))
app.logger.debug("expire_preauth_key - Status: "+str(status)) app.logger.debug("expire_preauth_key - Status: " + str(status))
return {"status": status, "body": response.json()} return {"status": status, "body": response.json()}

257
helper.py
View File

@@ -1,68 +1,120 @@
# pylint: disable=wrong-import-order import logging
import os
import os, headscale, requests, logging import requests
from flask import Flask from flask import Flask
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() import headscale
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
def pretty_print_duration(duration, delta_type=""): def pretty_print_duration(duration, delta_type=""):
""" Prints a duration in human-readable formats """ """Prints a duration in human-readable formats"""
days, seconds = duration.days, duration.seconds days, seconds = duration.days, duration.seconds
hours = (days * 24 + seconds // 3600) hours = days * 24 + seconds // 3600
mins = (seconds % 3600) // 60 mins = (seconds % 3600) // 60
secs = seconds % 60 secs = seconds % 60
if delta_type == "expiry": if delta_type == "expiry":
if days > 730: return "in greater than two years" if days > 730:
if days > 365: return "in greater than a year" return "in greater than two years"
if days > 0 : return "in "+ str(days ) + " days" if days > 1 else "in "+ str(days ) + " day" if days > 365:
if hours > 0 : return "in "+ str(hours) + " hours" if hours > 1 else "in "+ str(hours) + " hour" return "in greater than a year"
if mins > 0 : return "in "+ str(mins ) + " minutes" if mins > 1 else "in "+ str(mins ) + " minute" if days > 0:
return "in "+ str(secs ) + " seconds" if secs >= 1 or secs == 0 else "in "+ str(secs ) + " second" return (
if days > 730: return "over two years ago" "in " + str(days) + " days" if days > 1 else "in " + str(days) + " day"
if days > 365: return "over a year ago" )
if days > 0 : return str(days ) + " days ago" if days > 1 else str(days ) + " day ago" if hours > 0:
if hours > 0 : return str(hours) + " hours ago" if hours > 1 else str(hours) + " hour ago" return (
if mins > 0 : return str(mins ) + " minutes ago" if mins > 1 else str(mins ) + " minute ago" "in " + str(hours) + " hours"
return str(secs ) + " seconds ago" if secs >= 1 or secs == 0 else str(secs ) + " second ago" if hours > 1
else "in " + str(hours) + " hour"
)
if mins > 0:
return (
"in " + str(mins) + " minutes"
if mins > 1
else "in " + str(mins) + " minute"
)
return (
"in " + str(secs) + " seconds"
if secs >= 1 or secs == 0
else "in " + str(secs) + " second"
)
if days > 730:
return "over two years ago"
if days > 365:
return "over a year ago"
if days > 0:
return str(days) + " days ago" if days > 1 else str(days) + " day ago"
if hours > 0:
return str(hours) + " hours ago" if hours > 1 else str(hours) + " hour ago"
if mins > 0:
return str(mins) + " minutes ago" if mins > 1 else str(mins) + " minute ago"
return (
str(secs) + " seconds ago"
if secs >= 1 or secs == 0
else str(secs) + " second ago"
)
def text_color_duration(duration): def text_color_duration(duration):
""" Prints a color based on duratioin (imported as seconds) """ """Prints a color based on duratioin (imported as seconds)"""
days, seconds = duration.days, duration.seconds days, seconds = duration.days, duration.seconds
hours = (days * 24 + seconds // 3600) hours = days * 24 + seconds // 3600
mins = ((seconds % 3600) // 60) mins = (seconds % 3600) // 60
secs = (seconds % 60) secs = seconds % 60
if days > 30: return "grey-text " if days > 30:
if days > 14: return "red-text text-darken-2 " return "grey-text "
if days > 5: return "deep-orange-text text-lighten-1" if days > 14:
if days > 1: return "deep-orange-text text-lighten-1" return "red-text text-darken-2 "
if hours > 12: return "orange-text " if days > 5:
if hours > 1: return "orange-text text-lighten-2" return "deep-orange-text text-lighten-1"
if hours == 1: return "yellow-text " if days > 1:
if mins > 15: return "yellow-text text-lighten-2" return "deep-orange-text text-lighten-1"
if mins > 5: return "green-text text-lighten-3" if hours > 12:
if secs > 30: return "green-text text-lighten-2" return "orange-text "
if hours > 1:
return "orange-text text-lighten-2"
if hours == 1:
return "yellow-text "
if mins > 15:
return "yellow-text text-lighten-2"
if mins > 5:
return "green-text text-lighten-3"
if secs > 30:
return "green-text text-lighten-2"
return "green-text " return "green-text "
def key_check():
""" Checks the validity of a Headsclae API key and renews it if it's nearing expiration """
api_key = headscale.get_api_key()
url = headscale.get_url()
# Test the API key. If the test fails, return a failure. def key_check():
"""Checks the validity of a Headsclae API key and renews it if it's nearing expiration"""
api_key = headscale.get_api_key()
url = headscale.get_url()
# Test the API key. If the test fails, return a failure.
# AKA, if headscale returns Unauthorized, fail: # AKA, if headscale returns Unauthorized, fail:
app.logger.info("Testing API key validity.") app.logger.info("Testing API key validity.")
status = headscale.test_api_key(url, api_key) status = headscale.test_api_key(url, api_key)
if status != 200: if status != 200:
app.logger.info("Got a non-200 response from Headscale. Test failed (Response: %i)", status) app.logger.info(
"Got a non-200 response from Headscale. Test failed (Response: %i)",
status,
)
return False return False
else: else:
app.logger.info("Key check passed.") app.logger.info("Key check passed.")
@@ -70,8 +122,9 @@ def key_check():
headscale.renew_api_key(url, api_key) headscale.renew_api_key(url, api_key)
return True return True
def get_color(import_id, item_type = ""):
""" Sets colors for users/namespaces """ def get_color(import_id, item_type=""):
"""Sets colors for users/namespaces"""
# Define the colors... Seems like a good number to start with # Define the colors... Seems like a good number to start with
if item_type == "failover": if item_type == "failover":
colors = [ colors = [
@@ -122,54 +175,58 @@ def get_color(import_id, item_type = ""):
index = import_id % len(colors) index = import_id % len(colors)
return colors[index] return colors[index]
def format_message(error_type, title, message): def format_message(error_type, title, message):
""" Defines a generic 'collection' as error/warning/info messages """ """Defines a generic 'collection' as error/warning/info messages"""
content = """ content = """
<ul class="collection"> <ul class="collection">
<li class="collection-item avatar"> <li class="collection-item avatar">
""" """
match error_type.lower(): match error_type.lower():
case "warning": case "warning":
icon = """<i class="material-icons circle yellow">priority_high</i>""" icon = """<i class="material-icons circle yellow">priority_high</i>"""
title = """<span class="title">Warning - """+title+"""</span>""" title = """<span class="title">Warning - """ + title + """</span>"""
case "success": case "success":
icon = """<i class="material-icons circle green">check</i>""" icon = """<i class="material-icons circle green">check</i>"""
title = """<span class="title">Success - """+title+"""</span>""" title = """<span class="title">Success - """ + title + """</span>"""
case "error": case "error":
icon = """<i class="material-icons circle red">warning</i>""" icon = """<i class="material-icons circle red">warning</i>"""
title = """<span class="title">Error - """+title+"""</span>""" title = """<span class="title">Error - """ + title + """</span>"""
case "information": case "information":
icon = """<i class="material-icons circle grey">help</i>""" icon = """<i class="material-icons circle grey">help</i>"""
title = """<span class="title">Information - """+title+"""</span>""" title = """<span class="title">Information - """ + title + """</span>"""
content = content+icon+title+message content = content + icon + title + message
content = content+""" content = (
content
+ """
</li> </li>
</ul> </ul>
""" """
)
return content return content
def access_checks(): def access_checks():
""" Checks various items before each page load to ensure permissions are correct """ """Checks various items before each page load to ensure permissions are correct"""
url = headscale.get_url() url = headscale.get_url()
# Return an error message if things fail. # Return an error message if things fail.
# Return a formatted error message for EACH fail. # Return a formatted error message for EACH fail.
checks_passed = True # Default to true. Set to false when any checks fail. checks_passed = True # Default to true. Set to false when any checks fail.
data_readable = False # Checks R permissions of /data data_readable = False # Checks R permissions of /data
data_writable = False # Checks W permissions of /data data_writable = False # Checks W permissions of /data
data_executable = False # Execute on directories allows file access data_executable = False # Execute on directories allows file access
file_readable = False # Checks R permissions of /data/key.txt file_readable = False # Checks R permissions of /data/key.txt
file_writable = False # Checks W permissions of /data/key.txt file_writable = False # Checks W permissions of /data/key.txt
file_exists = False # Checks if /data/key.txt exists file_exists = False # Checks if /data/key.txt exists
config_readable = False # Checks if the headscale configuration file is readable config_readable = False # Checks if the headscale configuration file is readable
# Check 1: Check: the Headscale server is reachable: # Check 1: Check: the Headscale server is reachable:
server_reachable = False server_reachable = False
response = requests.get(str(url)+"/health") response = requests.get(str(url) + "/health")
if response.status_code == 200: if response.status_code == 200:
server_reachable = True server_reachable = True
else: else:
@@ -177,35 +234,43 @@ def access_checks():
app.logger.critical("Headscale URL: Response 200: FAILED") app.logger.critical("Headscale URL: Response 200: FAILED")
# Check: /data is rwx for 1000:1000: # Check: /data is rwx for 1000:1000:
if os.access('/data/', os.R_OK): data_readable = True if os.access("/data/", os.R_OK):
data_readable = True
else: else:
app.logger.critical("/data READ: FAILED") app.logger.critical("/data READ: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/', os.W_OK): data_writable = True if os.access("/data/", os.W_OK):
data_writable = True
else: else:
app.logger.critical("/data WRITE: FAILED") app.logger.critical("/data WRITE: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/', os.X_OK): data_executable = True if os.access("/data/", os.X_OK):
data_executable = True
else: else:
app.logger.critical("/data EXEC: FAILED") app.logger.critical("/data EXEC: FAILED")
checks_passed = False checks_passed = False
# Check: /data/key.txt exists and is rw: # Check: /data/key.txt exists and is rw:
if os.access('/data/key.txt', os.F_OK): if os.access("/data/key.txt", os.F_OK):
file_exists = True file_exists = True
if os.access('/data/key.txt', os.R_OK): file_readable = True if os.access("/data/key.txt", os.R_OK):
file_readable = True
else: else:
app.logger.critical("/data/key.txt READ: FAILED") app.logger.critical("/data/key.txt READ: FAILED")
checks_passed = False checks_passed = False
if os.access('/data/key.txt', os.W_OK): file_writable = True if os.access("/data/key.txt", os.W_OK):
file_writable = True
else: else:
app.logger.critical("/data/key.txt WRITE: FAILED") app.logger.critical("/data/key.txt WRITE: FAILED")
checks_passed = False checks_passed = False
else: app.logger.error("/data/key.txt EXIST: FAILED - NO ERROR") else:
app.logger.error("/data/key.txt EXIST: FAILED - NO ERROR")
# Check: /etc/headscale/config.yaml is readable: # Check: /etc/headscale/config.yaml is readable:
if os.access('/etc/headscale/config.yaml', os.R_OK): config_readable = True if os.access("/etc/headscale/config.yaml", os.R_OK):
elif os.access('/etc/headscale/config.yml', os.R_OK): config_readable = True config_readable = True
elif os.access("/etc/headscale/config.yml", os.R_OK):
config_readable = True
else: else:
app.logger.error("/etc/headscale/config.y(a)ml: READ: FAILED") app.logger.error("/etc/headscale/config.y(a)ml: READ: FAILED")
checks_passed = False checks_passed = False
@@ -218,11 +283,17 @@ def access_checks():
# Generate the message: # Generate the message:
if not server_reachable: if not server_reachable:
app.logger.critical("Server is unreachable") app.logger.critical("Server is unreachable")
message = """ message = (
<p>Your headscale server is either unreachable or not properly configured. """
<p>Your headscale server is either unreachable or not properly configured.
Please ensure your configuration is correct (Check for 200 status on Please ensure your configuration is correct (Check for 200 status on
"""+url+"""/api/v1 failed. Response: """+str(response.status_code)+""".)</p>
""" """
+ url
+ """/api/v1 failed. Response: """
+ str(response.status_code)
+ """.)</p>
"""
)
message_html += format_message("Error", "Headscale unreachable", message) message_html += format_message("Error", "Headscale unreachable", message)
@@ -234,7 +305,9 @@ def access_checks():
is named "config.yaml" or "config.yml"</p> is named "config.yaml" or "config.yml"</p>
""" """
message_html += format_message("Error", "/etc/headscale/config.yaml not readable", message) message_html += format_message(
"Error", "/etc/headscale/config.yaml not readable", message
)
if not data_writable: if not data_writable:
app.logger.critical("/data folder is not writable") app.logger.critical("/data folder is not writable")
@@ -266,7 +339,6 @@ def access_checks():
message_html += format_message("Error", "/data not executable", message) message_html += format_message("Error", "/data not executable", message)
if file_exists: if file_exists:
# If it doesn't exist, we assume the user hasn't created it yet. # If it doesn't exist, we assume the user hasn't created it yet.
# Just redirect to the settings page to enter an API Key # Just redirect to the settings page to enter an API Key
@@ -278,7 +350,9 @@ def access_checks():
by UID/GID 1000:1000.</p> by UID/GID 1000:1000.</p>
""" """
message_html += format_message("Error", "/data/key.txt not writable", message) message_html += format_message(
"Error", "/data/key.txt not writable", message
)
if not file_readable: if not file_readable:
app.logger.critical("/data/key.txt is not readable") app.logger.critical("/data/key.txt is not readable")
@@ -288,14 +362,19 @@ def access_checks():
by UID/GID 1000:1000.</p> by UID/GID 1000:1000.</p>
""" """
message_html += format_message("Error", "/data/key.txt not readable", message) message_html += format_message(
"Error", "/data/key.txt not readable", message
)
return message_html return message_html
def load_checks(): def load_checks():
""" Bundles all the checks into a single function to call easier """ """Bundles all the checks into a single function to call easier"""
# General error checks. See the function for more info: # General error checks. See the function for more info:
if access_checks() != "Pass": return 'error_page' if access_checks() != "Pass":
return "error_page"
# If the API key fails, redirect to the settings page: # If the API key fails, redirect to the settings page:
if not key_check(): return 'settings_page' if not key_check():
return "settings_page"
return "Pass" return "Pass"

File diff suppressed because it is too large Load Diff

640
server.py
View File

@@ -1,35 +1,53 @@
# pylint: disable=wrong-import-order import json
import logging
import os
import secrets
from datetime import datetime
from functools import wraps
import headscale, helper, json, os, pytz, renderer, secrets, requests, logging import pytz
from functools import wraps import requests
from datetime import datetime from dateutil import parser
from flask import Flask, escape, Markup, redirect, render_template, request, url_for from flask import Flask, Markup, escape, redirect, render_template, request, url_for
from dateutil import parser from flask_executor import Executor
from flask_executor import Executor
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
import headscale
import helper
import renderer
# Global vars # Global vars
# Colors: https://materializecss.com/color.html # Colors: https://materializecss.com/color.html
COLOR = os.environ["COLOR"].replace('"', '').lower() COLOR = os.environ["COLOR"].replace('"', "").lower()
COLOR_NAV = COLOR+" darken-1" COLOR_NAV = COLOR + " darken-1"
COLOR_BTN = COLOR+" darken-3" COLOR_BTN = COLOR + " darken-3"
AUTH_TYPE = os.environ["AUTH_TYPE"].replace('"', '').lower() AUTH_TYPE = os.environ["AUTH_TYPE"].replace('"', "").lower()
LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', '').upper() LOG_LEVEL = os.environ["LOG_LEVEL"].replace('"', "").upper()
# If LOG_LEVEL is DEBUG, enable Flask debugging: # If LOG_LEVEL is DEBUG, enable Flask debugging:
DEBUG_STATE = True if LOG_LEVEL == "DEBUG" else False DEBUG_STATE = True if LOG_LEVEL == "DEBUG" else False
# Initiate the Flask application and logging: # Initiate the Flask application and logging:
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
match LOG_LEVEL: match LOG_LEVEL:
case "DEBUG" : app.logger.setLevel(logging.DEBUG) case "DEBUG":
case "INFO" : app.logger.setLevel(logging.INFO) app.logger.setLevel(logging.DEBUG)
case "WARNING" : app.logger.setLevel(logging.WARNING) case "INFO":
case "ERROR" : app.logger.setLevel(logging.ERROR) app.logger.setLevel(logging.INFO)
case "CRITICAL": app.logger.setLevel(logging.CRITICAL) case "WARNING":
app.logger.setLevel(logging.WARNING)
case "ERROR":
app.logger.setLevel(logging.ERROR)
case "CRITICAL":
app.logger.setLevel(logging.CRITICAL)
executor = Executor(app) executor = Executor(app)
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)
app.logger.info("Headscale-WebUI Version: "+os.environ["APP_VERSION"]+" / "+os.environ["GIT_BRANCH"]) app.logger.info(
"Headscale-WebUI Version: "
+ os.environ["APP_VERSION"]
+ " / "
+ os.environ["GIT_BRANCH"]
)
app.logger.info("LOG LEVEL SET TO %s", str(LOG_LEVEL)) app.logger.info("LOG LEVEL SET TO %s", str(LOG_LEVEL))
app.logger.info("DEBUG STATE: %s", str(DEBUG_STATE)) app.logger.info("DEBUG STATE: %s", str(DEBUG_STATE))
@@ -37,23 +55,23 @@ app.logger.info("DEBUG STATE: %s", str(DEBUG_STATE))
# Set Authentication type. Currently "OIDC" and "BASIC" # Set Authentication type. Currently "OIDC" and "BASIC"
######################################################################################## ########################################################################################
if AUTH_TYPE == "oidc": if AUTH_TYPE == "oidc":
# Currently using: flask-providers-oidc - https://pypi.org/project/flask-providers-oidc/ # Currently using: flask-providers-oidc - https://pypi.org/project/flask-providers-oidc/
# #
# https://gist.github.com/thomasdarimont/145dc9aa857b831ff2eff221b79d179a/ # https://gist.github.com/thomasdarimont/145dc9aa857b831ff2eff221b79d179a/
# https://www.authelia.com/integration/openid-connect/introduction/ # https://www.authelia.com/integration/openid-connect/introduction/
# https://github.com/steinarvk/flask_oidc_demo # https://github.com/steinarvk/flask_oidc_demo
app.logger.info("Loading OIDC libraries and configuring app...") app.logger.info("Loading OIDC libraries and configuring app...")
DOMAIN_NAME = os.environ["DOMAIN_NAME"] DOMAIN_NAME = os.environ["DOMAIN_NAME"]
BASE_PATH = os.environ["SCRIPT_NAME"] if os.environ["SCRIPT_NAME"] != "/" else "" BASE_PATH = os.environ["SCRIPT_NAME"] if os.environ["SCRIPT_NAME"] != "/" else ""
OIDC_SECRET = os.environ["OIDC_CLIENT_SECRET"] OIDC_SECRET = os.environ["OIDC_CLIENT_SECRET"]
OIDC_CLIENT_ID = os.environ["OIDC_CLIENT_ID"] OIDC_CLIENT_ID = os.environ["OIDC_CLIENT_ID"]
OIDC_AUTH_URL = os.environ["OIDC_AUTH_URL"] OIDC_AUTH_URL = os.environ["OIDC_AUTH_URL"]
# Construct client_secrets.json: # Construct client_secrets.json:
response = requests.get(str(OIDC_AUTH_URL)) response = requests.get(str(OIDC_AUTH_URL))
oidc_info = response.json() oidc_info = response.json()
app.logger.debug("JSON Dumps for OIDC_INFO: "+json.dumps(oidc_info)) app.logger.debug("JSON Dumps for OIDC_INFO: " + json.dumps(oidc_info))
client_secrets = json.dumps( client_secrets = json.dumps(
{ {
@@ -75,20 +93,23 @@ if AUTH_TYPE == "oidc":
with open("/app/instance/secrets.json", "r+") as secrets_json: with open("/app/instance/secrets.json", "r+") as secrets_json:
app.logger.debug("/app/instances/secrets.json:") app.logger.debug("/app/instances/secrets.json:")
app.logger.debug(secrets_json.read()) app.logger.debug(secrets_json.read())
app.config.update({ app.config.update(
'SECRET_KEY': secrets.token_urlsafe(32), {
'TESTING': DEBUG_STATE, "SECRET_KEY": secrets.token_urlsafe(32),
'DEBUG': DEBUG_STATE, "TESTING": DEBUG_STATE,
'OIDC_CLIENT_SECRETS': '/app/instance/secrets.json', "DEBUG": DEBUG_STATE,
'OIDC_ID_TOKEN_COOKIE_SECURE': True, "OIDC_CLIENT_SECRETS": "/app/instance/secrets.json",
'OIDC_REQUIRE_VERIFIED_EMAIL': False, "OIDC_ID_TOKEN_COOKIE_SECURE": True,
'OIDC_USER_INFO_ENABLED': True, "OIDC_REQUIRE_VERIFIED_EMAIL": False,
'OIDC_OPENID_REALM': 'Headscale-WebUI', "OIDC_USER_INFO_ENABLED": True,
'OIDC_SCOPES': ['openid', 'profile', 'email'], "OIDC_OPENID_REALM": "Headscale-WebUI",
'OIDC_INTROSPECTION_AUTH_METHOD': 'client_secret_post' "OIDC_SCOPES": ["openid", "profile", "email"],
}) "OIDC_INTROSPECTION_AUTH_METHOD": "client_secret_post",
}
)
from flask_oidc import OpenIDConnect from flask_oidc import OpenIDConnect
oidc = OpenIDConnect(app) oidc = OpenIDConnect(app)
elif AUTH_TYPE == "basic": elif AUTH_TYPE == "basic":
@@ -96,159 +117,56 @@ elif AUTH_TYPE == "basic":
app.logger.info("Loading basic auth libraries and configuring app...") app.logger.info("Loading basic auth libraries and configuring app...")
from flask_basicauth import BasicAuth from flask_basicauth import BasicAuth
app.config['BASIC_AUTH_USERNAME'] = os.environ["BASIC_AUTH_USER"].replace('"', '') app.config["BASIC_AUTH_USERNAME"] = os.environ["BASIC_AUTH_USER"].replace('"', "")
app.config['BASIC_AUTH_PASSWORD'] = os.environ["BASIC_AUTH_PASS"] app.config["BASIC_AUTH_PASSWORD"] = os.environ["BASIC_AUTH_PASS"]
app.config['BASIC_AUTH_FORCE'] = True app.config["BASIC_AUTH_FORCE"] = True
basic_auth = BasicAuth(app) basic_auth = BasicAuth(app)
######################################################################################## ########################################################################################
# Set Authentication type - Dynamically load function decorators # Set Authentication type - Dynamically load function decorators
# https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi # https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi
######################################################################################## ########################################################################################
# Make a fake decorator for oidc.require_login # Make a fake decorator for oidc.require_login
# If anyone knows a better way of doing this, please let me know. # If anyone knows a better way of doing this, please let me know.
class OpenIDConnect(): class OpenIDConnect:
def require_login(self, view_func): def require_login(self, view_func):
@wraps(view_func) @wraps(view_func)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated return decorated
oidc = OpenIDConnect() oidc = OpenIDConnect()
else: else:
######################################################################################## ########################################################################################
# Set Authentication type - Dynamically load function decorators # Set Authentication type - Dynamically load function decorators
# https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi # https://stackoverflow.com/questions/17256602/assertionerror-view-function-mapping-is-overwriting-an-existing-endpoint-functi
######################################################################################## ########################################################################################
# Make a fake decorator for oidc.require_login # Make a fake decorator for oidc.require_login
# If anyone knows a better way of doing this, please let me know. # If anyone knows a better way of doing this, please let me know.
class OpenIDConnect(): class OpenIDConnect:
def require_login(self, view_func): def require_login(self, view_func):
@wraps(view_func) @wraps(view_func)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated return decorated
oidc = OpenIDConnect() oidc = OpenIDConnect()
######################################################################################## ########################################################################################
# / pages - User-facing pages # / pages - User-facing pages
######################################################################################## ########################################################################################
@app.route('/') @app.route("/")
@app.route('/overview') @app.route("/overview")
@oidc.require_login @oidc.require_login
def overview_page(): def overview_page():
# Some basic sanity checks: # Some basic sanity checks:
pass_checks = str(helper.load_checks()) pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks)) if pass_checks != "Pass":
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
return render_template('overview.html',
render_page = renderer.render_overview(),
COLOR_NAV = COLOR_NAV,
COLOR_BTN = COLOR_BTN,
OIDC_NAV_DROPDOWN = OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE
)
@app.route('/routes', methods=('GET', 'POST'))
@oidc.require_login
def routes_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
return render_template('routes.html',
render_page = renderer.render_routes(),
COLOR_NAV = COLOR_NAV,
COLOR_BTN = COLOR_BTN,
OIDC_NAV_DROPDOWN = OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE
)
@app.route('/machines', methods=('GET', 'POST'))
@oidc.require_login
def machines_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_machines_cards()
return render_template('machines.html',
cards = cards,
headscale_server = headscale.get_url(True),
COLOR_NAV = COLOR_NAV,
COLOR_BTN = COLOR_BTN,
OIDC_NAV_DROPDOWN = OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE,
INPAGE_SEARCH = INPAGE_SEARCH
)
@app.route('/users', methods=('GET', 'POST'))
@oidc.require_login
def users_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass": return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_users_cards()
return render_template('users.html',
cards = cards,
COLOR_NAV = COLOR_NAV,
COLOR_BTN = COLOR_BTN,
OIDC_NAV_DROPDOWN = OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE,
INPAGE_SEARCH = INPAGE_SEARCH
)
@app.route('/settings', methods=('GET', 'POST'))
@oidc.require_login
def settings_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass" and pass_checks != "settings_page":
return redirect(url_for(pass_checks)) return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons: # Check if OIDC is enabled. If it is, display the buttons:
@@ -256,41 +174,170 @@ def settings_page():
OIDC_NAV_MOBILE = Markup("") OIDC_NAV_MOBILE = Markup("")
if AUTH_TYPE == "oidc": if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email") email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username") user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name") name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name) OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name) OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
GIT_COMMIT_LINK = Markup("<a href='https://github.com/iFargle/headscale-webui/commit/"+os.environ["GIT_COMMIT"]+"'>"+str(os.environ["GIT_COMMIT"])[0:7]+"</a>") return render_template(
"overview.html",
return render_template('settings.html', render_page=renderer.render_overview(),
url = headscale.get_url(), COLOR_NAV=COLOR_NAV,
COLOR_NAV = COLOR_NAV, COLOR_BTN=COLOR_BTN,
COLOR_BTN = COLOR_BTN, OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_DROPDOWN = OIDC_NAV_DROPDOWN, OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
OIDC_NAV_MOBILE = OIDC_NAV_MOBILE,
BUILD_DATE = os.environ["BUILD_DATE"],
APP_VERSION = os.environ["APP_VERSION"],
GIT_COMMIT = GIT_COMMIT_LINK,
GIT_BRANCH = os.environ["GIT_BRANCH"],
HS_VERSION = os.environ["HS_VERSION"]
) )
@app.route('/error')
@app.route("/routes", methods=("GET", "POST"))
@oidc.require_login
def routes_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
return render_template(
"routes.html",
render_page=renderer.render_routes(),
COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
)
@app.route("/machines", methods=("GET", "POST"))
@oidc.require_login
def machines_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_machines_cards()
return render_template(
"machines.html",
cards=cards,
headscale_server=headscale.get_url(True),
COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
INPAGE_SEARCH=INPAGE_SEARCH,
)
@app.route("/users", methods=("GET", "POST"))
@oidc.require_login
def users_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
INPAGE_SEARCH = Markup(renderer.render_search())
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
cards = renderer.render_users_cards()
return render_template(
"users.html",
cards=cards,
COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
INPAGE_SEARCH=INPAGE_SEARCH,
)
@app.route("/settings", methods=("GET", "POST"))
@oidc.require_login
def settings_page():
# Some basic sanity checks:
pass_checks = str(helper.load_checks())
if pass_checks != "Pass" and pass_checks != "settings_page":
return redirect(url_for(pass_checks))
# Check if OIDC is enabled. If it is, display the buttons:
OIDC_NAV_DROPDOWN = Markup("")
OIDC_NAV_MOBILE = Markup("")
if AUTH_TYPE == "oidc":
email_address = oidc.user_getfield("email")
user_name = oidc.user_getfield("preferred_username")
name = oidc.user_getfield("name")
OIDC_NAV_DROPDOWN = renderer.oidc_nav_dropdown(user_name, email_address, name)
OIDC_NAV_MOBILE = renderer.oidc_nav_mobile(user_name, email_address, name)
GIT_COMMIT_LINK = Markup(
"<a href='https://github.com/iFargle/headscale-webui/commit/"
+ os.environ["GIT_COMMIT"]
+ "'>"
+ str(os.environ["GIT_COMMIT"])[0:7]
+ "</a>"
)
return render_template(
"settings.html",
url=headscale.get_url(),
COLOR_NAV=COLOR_NAV,
COLOR_BTN=COLOR_BTN,
OIDC_NAV_DROPDOWN=OIDC_NAV_DROPDOWN,
OIDC_NAV_MOBILE=OIDC_NAV_MOBILE,
BUILD_DATE=os.environ["BUILD_DATE"],
APP_VERSION=os.environ["APP_VERSION"],
GIT_COMMIT=GIT_COMMIT_LINK,
GIT_BRANCH=os.environ["GIT_BRANCH"],
HS_VERSION=os.environ["HS_VERSION"],
)
@app.route("/error")
@oidc.require_login @oidc.require_login
def error_page(): def error_page():
if helper.access_checks() == "Pass": if helper.access_checks() == "Pass":
return redirect(url_for('overview_page')) return redirect(url_for("overview_page"))
return render_template('error.html', return render_template("error.html", ERROR_MESSAGE=Markup(helper.access_checks()))
ERROR_MESSAGE = Markup(helper.access_checks())
)
@app.route('/logout')
@app.route("/logout")
def logout_page(): def logout_page():
if AUTH_TYPE == "oidc": if AUTH_TYPE == "oidc":
oidc.logout() oidc.logout()
return redirect(url_for('overview_page')) return redirect(url_for("overview_page"))
######################################################################################## ########################################################################################
# /api pages # /api pages
######################################################################################## ########################################################################################
@@ -299,227 +346,262 @@ def logout_page():
# Headscale API Key Endpoints # Headscale API Key Endpoints
######################################################################################## ########################################################################################
@app.route('/api/test_key', methods=('GET', 'POST'))
@app.route("/api/test_key", methods=("GET", "POST"))
@oidc.require_login @oidc.require_login
def test_key_page(): def test_key_page():
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
url = headscale.get_url() url = headscale.get_url()
# Test the API key. If the test fails, return a failure. # Test the API key. If the test fails, return a failure.
status = headscale.test_api_key(url, api_key) status = headscale.test_api_key(url, api_key)
if status != 200: return "Unauthenticated" if status != 200:
return "Unauthenticated"
renewed = headscale.renew_api_key(url, api_key) renewed = headscale.renew_api_key(url, api_key)
app.logger.warning("The below statement will be TRUE if the key has been renewed, ") app.logger.warning("The below statement will be TRUE if the key has been renewed, ")
app.logger.warning("or DOES NOT need renewal. False in all other cases") app.logger.warning("or DOES NOT need renewal. False in all other cases")
app.logger.warning("Renewed: "+str(renewed)) app.logger.warning("Renewed: " + str(renewed))
# The key works, let's renew it if it needs it. If it does, re-read the api_key from the file: # The key works, let's renew it if it needs it. If it does, re-read the api_key from the file:
if renewed: api_key = headscale.get_api_key() if renewed:
api_key = headscale.get_api_key()
key_info = headscale.get_api_key_info(url, api_key) key_info = headscale.get_api_key_info(url, api_key)
# Set the current timezone and local time # Set the current timezone and local time
timezone = pytz.timezone(os.environ["TZ"] if os.environ["TZ"] else "UTC") timezone = pytz.timezone(os.environ["TZ"] if os.environ["TZ"] else "UTC")
local_time = timezone.localize(datetime.now()) local_time = timezone.localize(datetime.now())
# Format the dates for easy readability # Format the dates for easy readability
creation_parse = parser.parse(key_info['createdAt']) creation_parse = parser.parse(key_info["createdAt"])
creation_local = creation_parse.astimezone(timezone) creation_local = creation_parse.astimezone(timezone)
creation_delta = local_time - creation_local creation_delta = local_time - creation_local
creation_print = helper.pretty_print_duration(creation_delta) creation_print = helper.pretty_print_duration(creation_delta)
creation_time = str(creation_local.strftime('%A %m/%d/%Y, %H:%M:%S'))+" "+str(timezone)+" ("+str(creation_print)+")" creation_time = (
str(creation_local.strftime("%A %m/%d/%Y, %H:%M:%S"))
+ " "
+ str(timezone)
+ " ("
+ str(creation_print)
+ ")"
)
expiration_parse = parser.parse(key_info['expiration']) expiration_parse = parser.parse(key_info["expiration"])
expiration_local = expiration_parse.astimezone(timezone) expiration_local = expiration_parse.astimezone(timezone)
expiration_delta = expiration_local - local_time expiration_delta = expiration_local - local_time
expiration_print = helper.pretty_print_duration(expiration_delta, "expiry") expiration_print = helper.pretty_print_duration(expiration_delta, "expiry")
expiration_time = str(expiration_local.strftime('%A %m/%d/%Y, %H:%M:%S'))+" "+str(timezone)+" ("+str(expiration_print)+")" expiration_time = (
str(expiration_local.strftime("%A %m/%d/%Y, %H:%M:%S"))
+ " "
+ str(timezone)
+ " ("
+ str(expiration_print)
+ ")"
)
key_info['expiration'] = expiration_time key_info["expiration"] = expiration_time
key_info['createdAt'] = creation_time key_info["createdAt"] = creation_time
message = json.dumps(key_info) message = json.dumps(key_info)
return message return message
@app.route('/api/save_key', methods=['POST'])
@app.route("/api/save_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def save_key_page(): def save_key_page():
json_response = request.get_json() json_response = request.get_json()
api_key = json_response['api_key'] api_key = json_response["api_key"]
url = headscale.get_url() url = headscale.get_url()
file_written = headscale.set_api_key(api_key) file_written = headscale.set_api_key(api_key)
message = '' message = ""
if file_written: if file_written:
# Re-read the file and get the new API key and test it # Re-read the file and get the new API key and test it
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
test_status = headscale.test_api_key(url, api_key) test_status = headscale.test_api_key(url, api_key)
if test_status == 200: if test_status == 200:
key_info = headscale.get_api_key_info(url, api_key) key_info = headscale.get_api_key_info(url, api_key)
expiration = key_info['expiration'] expiration = key_info["expiration"]
message = "Key: '"+api_key+"', Expiration: "+expiration message = "Key: '" + api_key + "', Expiration: " + expiration
# If the key was saved successfully, test it: # If the key was saved successfully, test it:
return "Key saved and tested: "+message return "Key saved and tested: " + message
else: return "Key failed testing. Check your key" else:
else: return "Key did not save properly. Check logs" return "Key failed testing. Check your key"
else:
return "Key did not save properly. Check logs"
######################################################################################## ########################################################################################
# Machine API Endpoints # Machine API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/update_route', methods=['POST']) @app.route("/api/update_route", methods=["POST"])
@oidc.require_login @oidc.require_login
def update_route_page(): def update_route_page():
json_response = request.get_json() json_response = request.get_json()
route_id = escape(json_response['route_id']) route_id = escape(json_response["route_id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
current_state = json_response['current_state'] current_state = json_response["current_state"]
return headscale.update_route(url, api_key, route_id, current_state) return headscale.update_route(url, api_key, route_id, current_state)
@app.route('/api/machine_information', methods=['POST'])
@app.route("/api/machine_information", methods=["POST"])
@oidc.require_login @oidc.require_login
def machine_information_page(): def machine_information_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.get_machine_info(url, api_key, machine_id) return headscale.get_machine_info(url, api_key, machine_id)
@app.route('/api/delete_machine', methods=['POST'])
@app.route("/api/delete_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def delete_machine_page(): def delete_machine_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.delete_machine(url, api_key, machine_id) return headscale.delete_machine(url, api_key, machine_id)
@app.route('/api/rename_machine', methods=['POST'])
@app.route("/api/rename_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def rename_machine_page(): def rename_machine_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
new_name = escape(json_response['new_name']) new_name = escape(json_response["new_name"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.rename_machine(url, api_key, machine_id, new_name) return headscale.rename_machine(url, api_key, machine_id, new_name)
@app.route('/api/move_user', methods=['POST'])
@app.route("/api/move_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def move_user_page(): def move_user_page():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
new_user = escape(json_response['new_user']) new_user = escape(json_response["new_user"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.move_user(url, api_key, machine_id, new_user) return headscale.move_user(url, api_key, machine_id, new_user)
@app.route('/api/set_machine_tags', methods=['POST'])
@app.route("/api/set_machine_tags", methods=["POST"])
@oidc.require_login @oidc.require_login
def set_machine_tags(): def set_machine_tags():
json_response = request.get_json() json_response = request.get_json()
machine_id = escape(json_response['id']) machine_id = escape(json_response["id"])
machine_tags = json_response['tags_list'] machine_tags = json_response["tags_list"]
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.set_machine_tags(url, api_key, machine_id, machine_tags) return headscale.set_machine_tags(url, api_key, machine_id, machine_tags)
@app.route('/api/register_machine', methods=['POST'])
@app.route("/api/register_machine", methods=["POST"])
@oidc.require_login @oidc.require_login
def register_machine(): def register_machine():
json_response = request.get_json() json_response = request.get_json()
machine_key = escape(json_response['key']) machine_key = escape(json_response["key"])
user = escape(json_response['user']) user = escape(json_response["user"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.register_machine(url, api_key, machine_key, user) return headscale.register_machine(url, api_key, machine_key, user)
######################################################################################## ########################################################################################
# User API Endpoints # User API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/rename_user', methods=['POST']) @app.route("/api/rename_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def rename_user_page(): def rename_user_page():
json_response = request.get_json() json_response = request.get_json()
old_name = escape(json_response['old_name']) old_name = escape(json_response["old_name"])
new_name = escape(json_response['new_name']) new_name = escape(json_response["new_name"])
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.rename_user(url, api_key, old_name, new_name) return headscale.rename_user(url, api_key, old_name, new_name)
@app.route('/api/add_user', methods=['POST'])
@app.route("/api/add_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def add_user(): def add_user():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
json_string = '{"name": "'+user_name+'"}' json_string = '{"name": "' + user_name + '"}'
return headscale.add_user(url, api_key, json_string) return headscale.add_user(url, api_key, json_string)
@app.route('/api/delete_user', methods=['POST'])
@app.route("/api/delete_user", methods=["POST"])
@oidc.require_login @oidc.require_login
def delete_user(): def delete_user():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.delete_user(url, api_key, user_name) return headscale.delete_user(url, api_key, user_name)
@app.route('/api/get_users', methods=['POST'])
@app.route("/api/get_users", methods=["POST"])
@oidc.require_login @oidc.require_login
def get_users_page(): def get_users_page():
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.get_users(url, api_key) return headscale.get_users(url, api_key)
######################################################################################## ########################################################################################
# Pre-Auth Key API Endpoints # Pre-Auth Key API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/add_preauth_key', methods=['POST']) @app.route("/api/add_preauth_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def add_preauth_key(): def add_preauth_key():
json_response = json.dumps(request.get_json()) json_response = json.dumps(request.get_json())
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.add_preauth_key(url, api_key, json_response) return headscale.add_preauth_key(url, api_key, json_response)
@app.route('/api/expire_preauth_key', methods=['POST'])
@app.route("/api/expire_preauth_key", methods=["POST"])
@oidc.require_login @oidc.require_login
def expire_preauth_key(): def expire_preauth_key():
json_response = json.dumps(request.get_json()) json_response = json.dumps(request.get_json())
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.expire_preauth_key(url, api_key, json_response) return headscale.expire_preauth_key(url, api_key, json_response)
@app.route('/api/build_preauthkey_table', methods=['POST'])
@app.route("/api/build_preauthkey_table", methods=["POST"])
@oidc.require_login @oidc.require_login
def build_preauth_key_table(): def build_preauth_key_table():
json_response = request.get_json() json_response = request.get_json()
user_name = str(escape(json_response['name'])) user_name = str(escape(json_response["name"]))
return renderer.build_preauth_key_table(user_name) return renderer.build_preauth_key_table(user_name)
######################################################################################## ########################################################################################
# Route API Endpoints # Route API Endpoints
######################################################################################## ########################################################################################
@app.route('/api/get_routes', methods=['POST']) @app.route("/api/get_routes", methods=["POST"])
@oidc.require_login @oidc.require_login
def get_route_info(): def get_route_info():
url = headscale.get_url() url = headscale.get_url()
api_key = headscale.get_api_key() api_key = headscale.get_api_key()
return headscale.get_routes(url, api_key) return headscale.get_routes(url, api_key)
@@ -528,5 +610,5 @@ def get_route_info():
######################################################################################## ########################################################################################
# Main thread # Main thread
######################################################################################## ########################################################################################
if __name__ == '__main__': if __name__ == "__main__":
app.run(host="0.0.0.0", debug=DEBUG_STATE) app.run(host="0.0.0.0", debug=DEBUG_STATE)